上传文件至 kg_lab_6.13
6.18 更新:统一数据配置路径,并新增前端 demo
This commit is contained in:
@@ -4,6 +4,8 @@ from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.output_parsers import JsonOutputParser
|
||||
|
||||
from extraction_info import neo4j_url, neo4j_username, neo4j_password
|
||||
|
||||
qwen_llm = ChatOpenAI(
|
||||
openai_api_base="https://api.siliconflow.cn/v1",
|
||||
model_name="Qwen/Qwen2.5-72B-Instruct",
|
||||
@@ -111,9 +113,9 @@ from langchain_community.graphs import Neo4jGraph
|
||||
|
||||
|
||||
graph = Neo4jGraph(
|
||||
url="bolt://172.20.0.145:7687",
|
||||
username="neo4j",
|
||||
password="password",
|
||||
url = neo4j_url,
|
||||
username = neo4j_username,
|
||||
password = neo4j_password,
|
||||
)
|
||||
|
||||
graph.refresh_schema()
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
# import os
|
||||
|
||||
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
# info_data_json = os.path.join(BASE_DIR, 'data/data.json')
|
||||
# info_data_txt = os.path.join(BASE_DIR, 'data/data.txt')
|
||||
# info_faiss_archived = os.path.join(BASE_DIR, 'data/faiss_data/data')
|
||||
|
||||
# # Neo4j
|
||||
# neo4j_url = "bolt://172.20.0.145:7687"
|
||||
# neo4j_username = "neo4j"
|
||||
# neo4j_password = "password"
|
||||
|
||||
|
||||
import os
|
||||
import json
|
||||
|
||||
# Directory containing this module; the paths below are joined onto it.
BASE_DIR = os.path.dirname(os.path.abspath(__file__))

# Read config.json, which is expected to sit next to this module.
config_path = os.path.join(BASE_DIR, 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
    config = json.load(f)

# Data-file locations: each config value is joined onto BASE_DIR.
info_data_json, info_data_txt, info_faiss_archived = (
    os.path.join(BASE_DIR, config[key])
    for key in ("info_data_json", "info_data_txt", "info_faiss_archived")
)

# Neo4j connection settings, read from the "neo4j" section of the config.
_neo4j_section = config["neo4j"]
neo4j_url = _neo4j_section["url"]
neo4j_username = _neo4j_section["username"]
neo4j_password = _neo4j_section["password"]
|
||||
|
||||
|
||||
+18
-14
@@ -1,8 +1,13 @@
|
||||
from chains_lab import Problem_rewrite
|
||||
from chains_lab import booway_cypher_chain
|
||||
from chains_lab import question_answer, question_answer_calculation
|
||||
from vector_lab import intersection_of_three_lists
|
||||
from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2
|
||||
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
|
||||
import json
|
||||
|
||||
from extraction_info import info_data_json
|
||||
|
||||
# 样例
|
||||
# input_str1 = "杆塔总基数是多少?"
|
||||
# input_str2 = "单回路长度是多少?"
|
||||
@@ -13,21 +18,15 @@ import json
|
||||
# input_str7 = "计算一下本体工程机械费"
|
||||
# input_str8 = "项目建设技术服务费合计"
|
||||
|
||||
# 初始化
|
||||
# 初始化 chains
|
||||
problem_rewrite = Problem_rewrite()
|
||||
|
||||
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
|
||||
|
||||
from chains_lab import question_answer, question_answer_calculation
|
||||
|
||||
qa_chains = question_answer()
|
||||
|
||||
calculation_chains = question_answer_calculation()
|
||||
|
||||
from chains_lab import booway_cypher_chain
|
||||
|
||||
|
||||
# 加载数据
|
||||
with open('./data/data.json', 'r', encoding='utf-8') as file:
|
||||
with open(info_data_json, 'r', encoding='utf-8') as file:
|
||||
data = json.load(file)
|
||||
|
||||
print("📥 请输入查询内容,输入 'exit' 可退出程序。\n")
|
||||
@@ -91,11 +90,16 @@ while True:
|
||||
print(ques)
|
||||
|
||||
retriever_info = []
|
||||
for idx, i in enumerate(ques):
|
||||
response = booway_cypher_chain.invoke(i)
|
||||
temp = response.get("result")
|
||||
retriever_info.append(temp)
|
||||
|
||||
for i in ques[:12]:
|
||||
try:
|
||||
response = booway_cypher_chain.invoke(i)
|
||||
temp = response.get("result")
|
||||
# todo: 重复筛选策略
|
||||
retriever_info.append(temp)
|
||||
except Exception as e:
|
||||
print(f"处理问题时出错: {e}")
|
||||
retriever_info.append(None)
|
||||
|
||||
|
||||
retriever_keywords = ques_info[0]
|
||||
calculation = ques_info[-1]
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
import streamlit as st
|
||||
import json
|
||||
|
||||
from chains_lab import Problem_rewrite, booway_cypher_chain
|
||||
from chains_lab import question_answer, question_answer_calculation
|
||||
from vector_lab import intersection_of_three_lists
|
||||
from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2
|
||||
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
|
||||
from extraction_info import info_data_json
|
||||
|
||||
# Build the chains used by the page.
# NOTE(review): Streamlit re-executes the script on each interaction, so these
# constructors and the JSON load below presumably run again every time --
# consider caching them if start-up cost matters.
problem_rewrite = Problem_rewrite()
qa_chains = question_answer()
calculation_chains = question_answer_calculation()

# Load the JSON data that is later handed to pre_mapping().
with open(info_data_json, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Upper bound on how many split sub-questions are sent to the Cypher chain.
MAX_SUB_QUESTIONS = 12

# Streamlit page setup.
st.set_page_config(page_title="图谱问答系统", layout="wide")
st.title("💬 图谱问答系统")
st.markdown("请在下方输入你的查询问题:")


def _first_cypher_step(response):
    """Return the first ``intermediate_steps`` entry of *response*, or None.

    Returns None when the key is absent or empty so that a missing Cypher
    trace does not abort the answer itself.
    """
    steps = response.get("intermediate_steps")
    return steps[0] if steps else None


def _show_cypher(steps):
    """Render the ``'query'`` field of each step inside a collapsible panel."""
    with st.expander("🔐 Cypher 查询语句"):
        for step in steps:
            st.code(step['query'], language='cypher')


def _answer_single(query, keywords, cypher_input):
    """Run one Cypher-chain lookup and answer *query* with the QA chain.

    Args:
        query: The question typed by the user.
        keywords: Output of the problem-rewrite chain.
        cypher_input: The string produced by pre_mapping(), sent to the
            Cypher chain as-is.
    """
    response = booway_cypher_chain.invoke(cypher_input)
    cypher_step = _first_cypher_step(response)
    cypher_result = response.get("result")

    _show_cypher([cypher_step] if cypher_step is not None else [])

    final_result = qa_chains.invoke({
        "query": query,
        "retriever_keywords": keywords,
        "retriever_info": cypher_result
    })

    st.success("📊 **图谱最终检索输出:**")
    st.write(final_result)


def _answer_calculation(mapped_items):
    """Answer a query that pre_mapping() expanded into a list.

    Each sub-question (at most MAX_SUB_QUESTIONS) is sent to the Cypher chain;
    the collected results are then passed to the calculation chain.

    Args:
        mapped_items: The list produced by pre_mapping().
    """
    ques = extract_query_prefix_list(mapped_items)
    concrete_info = extract_concrete_info(mapped_items)
    ques_info = split_chinese_bracketed_phrases(concrete_info[0])

    st.markdown("🔎 **问题拆分结果:**")
    st.write(ques)

    retriever_info = []
    cypher_info = []
    for sub_question in ques[:MAX_SUB_QUESTIONS]:
        try:
            response = booway_cypher_chain.invoke(sub_question)
            result = response.get("result")
            cypher_step = _first_cypher_step(response)
        except Exception as e:
            st.error(f"处理问题时出错:{e}")
            retriever_info.append(None)
        else:
            # Append only after everything succeeded, so each sub-question
            # contributes exactly one retriever_info entry.
            retriever_info.append(result)
            if cypher_step is not None:
                cypher_info.append(cypher_step)

    # The first and last split phrases are used as the keywords and the
    # calculation expression respectively.
    retriever_keywords = ques_info[0]
    calculation = ques_info[-1]

    _show_cypher(cypher_info)

    st.markdown("🧮 **表达式输出:**")
    st.write(calculation)

    final_result = calculation_chains.invoke({
        "query": mapped_items,
        "retriever_keywords": retriever_keywords,
        "calculation": calculation,
        "retriever_info": retriever_info
    })

    st.success("📊 **图谱最终检索输出:**")
    st.write(final_result)


# Input box. Sample questions (kept verbatim):
# input_str1 = "杆塔总基数是多少?"
# input_str2 = "单回路长度是多少?"
# input_str3 = "计算一下角钢塔的塔材装材费"
# input_str4 = "计算一下土石方总量"
# input_str5 = "板式塔基的各类基础数量占总塔基数比例是多少?"
# input_str6 = "基础混凝土总量是多少"
# input_str7 = "计算一下本体工程机械费"
# input_str8 = "项目建设技术服务费合计"
input_str = st.text_input("🔍 输入问题", placeholder="例如:计算一下角钢塔的塔材装材费?")

# Run the pipeline when the button is pressed and the input is non-blank.
if st.button("🚀 开始查询") and input_str.strip():
    try:
        results = intersection_of_three_lists(input_str)
        if not results:
            st.warning("⚠️ 无法从向量中获取候选项。")
        else:
            retriever = results[0]
            st.markdown(f"➡️ **匹配向量检索结果**:`{retriever}`")

            # Rewrite the question and extract keywords.
            keywords = problem_rewrite.invoke({
                "query": input_str,
                "retriever": retriever
            })
            st.markdown(f"🧠 **提取关键词**: `{keywords}`")

            # Pre-map the keywords against the loaded data.
            input_neo4j = pre_mapping(keywords, data)
            st.markdown("📊 **图谱相关知识输出**:")
            st.code(input_neo4j, language='json')

            # A string means a single lookup; a list means several
            # sub-questions followed by a calculation.
            if isinstance(input_neo4j, str):
                _answer_single(input_str, keywords, input_neo4j)
            elif isinstance(input_neo4j, list):
                _answer_calculation(input_neo4j)
    except Exception as e:
        # Top-level boundary of the page: report the failure in the UI.
        st.error(f"❌ 发生错误:{e}")
|
||||
|
||||
# streamlit run streamlit_ceshi.py --server.port 2336
|
||||
@@ -7,6 +7,8 @@ import requests
|
||||
import httpx
|
||||
import logging
|
||||
|
||||
from extraction_info import info_data_txt, info_faiss_archived
|
||||
|
||||
class SiliconFlowEmbeddings(Embeddings):
|
||||
"""SiliconFlow嵌入模型封装"""
|
||||
def __init__(self, api_key: str, model: str = "bge-m3"):
|
||||
@@ -39,15 +41,15 @@ class SiliconFlowEmbeddings(Embeddings):
|
||||
# embeddings = Embedding(url="http://10.1.16.39:9995/v1", api_key="xxx", model_name="bge-m3")
|
||||
embeddings = SiliconFlowEmbeddings(api_key="xxx")
|
||||
|
||||
with open("./data/data.txt", 'r', encoding='utf-8') as file:
|
||||
with open(info_data_txt, 'r', encoding='utf-8') as file:
|
||||
txt_list = [line.strip() for line in file]
|
||||
|
||||
# embedding_path = "/data/Z_LLM_data/Embed_data/bge-m3"
|
||||
# embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
|
||||
|
||||
faiss_archived = "./data/faiss_data/data"
|
||||
# faiss_archived = "./data/faiss_data/data"
|
||||
vectorstore_txt_faiss = FAISS.from_texts(txt_list, embeddings)
|
||||
vectorstore_txt_faiss.save_local(faiss_archived)
|
||||
vectorstore_txt_faiss.save_local(info_faiss_archived)
|
||||
|
||||
retriever_txt_faiss1 = vectorstore_txt_faiss.as_retriever(search_kwargs={"k":3})
|
||||
retriever_txt_faiss2 = vectorstore_txt_faiss.as_retriever(
|
||||
|
||||
Reference in New Issue
Block a user