上传文件至 kg_lab_6.13

6.18 更新数据配置路径统一,和前端demo
This commit is contained in:
2025-06-18 16:02:47 +08:00
parent fad7c5de4a
commit 517691c2d6
5 changed files with 185 additions and 20 deletions
+5 -3
View File
@@ -4,6 +4,8 @@ from langchain_core.prompts import ChatPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser
from extraction_info import neo4j_url, neo4j_username, neo4j_password
qwen_llm = ChatOpenAI(
openai_api_base="https://api.siliconflow.cn/v1",
model_name="Qwen/Qwen2.5-72B-Instruct",
@@ -111,9 +113,9 @@ from langchain_community.graphs import Neo4jGraph
graph = Neo4jGraph(
url="bolt://172.20.0.145:7687",
username="neo4j",
password="password",
url = neo4j_url,
username = neo4j_username,
password = neo4j_password,
)
graph.refresh_schema()
+35
View File
@@ -0,0 +1,35 @@
# import os
# BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# info_data_json = os.path.join(BASE_DIR, 'data/data.json')
# info_data_txt = os.path.join(BASE_DIR, 'data/data.txt')
# info_faiss_archived = os.path.join(BASE_DIR, 'data/faiss_data/data')
# # Neo4j
# neo4j_url = "bolt://172.20.0.145:7687"
# neo4j_username = "neo4j"
# neo4j_password = "password"
import os
import json
# 获取当前目录
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
# 读取 config.json 配置文件
config_path = os.path.join(BASE_DIR, 'config.json')
with open(config_path, 'r', encoding='utf-8') as f:
config = json.load(f)
# 解析路径配置
info_data_json = os.path.join(BASE_DIR, config["info_data_json"])
info_data_txt = os.path.join(BASE_DIR, config["info_data_txt"])
info_faiss_archived = os.path.join(BASE_DIR, config["info_faiss_archived"])
# 解析 Neo4j 配置
neo4j_url = config["neo4j"]["url"]
neo4j_username = config["neo4j"]["username"]
neo4j_password = config["neo4j"]["password"]
+18 -14
View File
@@ -1,8 +1,13 @@
from chains_lab import Problem_rewrite
from chains_lab import booway_cypher_chain
from chains_lab import question_answer, question_answer_calculation
from vector_lab import intersection_of_three_lists
from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
import json
from extraction_info import info_data_json
# 样例
# input_str1 = "杆塔总基数是多少?"
# input_str2 = "单回路长度是多少?"
@@ -13,21 +18,15 @@ import json
# input_str7 = "计算一下本体工程机械费"
# input_str8 = "项目建设技术服务费合计"
# 初始化
# 初始化chians
problem_rewrite = Problem_rewrite()
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
from chains_lab import question_answer, question_answer_calculation
qa_chains = question_answer()
calculation_chains = question_answer_calculation()
from chains_lab import booway_cypher_chain
# 加载数据
with open('./data/data.json', 'r', encoding='utf-8') as file:
with open(info_data_json, 'r', encoding='utf-8') as file:
data = json.load(file)
print("📥 请输入查询内容,输入 'exit' 可退出程序。\n")
@@ -91,11 +90,16 @@ while True:
print(ques)
retriever_info = []
for idx, i in enumerate(ques):
response = booway_cypher_chain.invoke(i)
temp = response.get("result")
retriever_info.append(temp)
for i in ques[:12]:
try:
response = booway_cypher_chain.invoke(i)
temp = response.get("result")
# todo: 重复筛选策略
retriever_info.append(temp)
except Exception as e:
print(f"处理问题时出错: {e}")
retriever_info.append(None)
retriever_keywords = ques_info[0]
calculation = ques_info[-1]
+122
View File
@@ -0,0 +1,122 @@
import streamlit as st
import json
from chains_lab import Problem_rewrite, booway_cypher_chain
from chains_lab import question_answer, question_answer_calculation
from vector_lab import intersection_of_three_lists
from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
from extraction_info import info_data_json
# 初始化 chains
problem_rewrite = Problem_rewrite()
qa_chains = question_answer()
calculation_chains = question_answer_calculation()
# 加载数据
with open(info_data_json, 'r', encoding='utf-8') as file:
data = json.load(file)
# Streamlit 页面设置
st.set_page_config(page_title="图谱问答系统", layout="wide")
st.title("💬 图谱问答系统")
st.markdown("请在下方输入你的查询问题:")
# 输入框
# input_str1 = "杆塔总基数是多少?"
# input_str2 = "单回路长度是多少?"
# input_str3 = "计算一下角钢塔的塔材装材费"
# input_str4 = "计算一下土石方总量"
# input_str5 = "板式塔基的各类基础数量占总塔基数比例是多少?"
# input_str6 = "基础混凝土总量是多少"
# input_str7 = "计算一下本体工程机械费"
# input_str8 = "项目建设技术服务费合计"
input_str = st.text_input("🔍 输入问题", placeholder="例如:计算一下角钢塔的塔材装材费?")
# 执行按钮
if st.button("🚀 开始查询") and input_str.strip():
try:
results = intersection_of_three_lists(input_str)
if not results:
st.warning("⚠️ 无法从向量中获取候选项。")
else:
retriever = results[0]
st.markdown(f"➡️ **匹配向量检索结果**`{retriever}`")
# 重写问题,提取关键词
keywords = problem_rewrite.invoke({
"query": input_str,
"retriever": retriever
})
st.markdown(f"🧠 **提取关键词** `{keywords}`")
# 预映射图数据库结构
input_neo4j = pre_mapping(keywords, data)
st.markdown("📊 **图谱相关知识输出**")
st.code(input_neo4j, language='json')
if isinstance(input_neo4j, str):
response = booway_cypher_chain.invoke(input_neo4j)
cypher_query = response.get("intermediate_steps")[0]
cypher_result = response.get("result")
# print(cypher_query)
# 可视化 Cypher 查询语句
with st.expander("🔐 Cypher 查询语句"):
st.code(cypher_query['query'], language='cypher')
# 问答环节
final_result = qa_chains.invoke({
"query": input_str,
"retriever_keywords": keywords,
"retriever_info": cypher_result
})
st.success("📊 **图谱最终检索输出:**")
st.write(final_result)
elif isinstance(input_neo4j, list):
ques = extract_query_prefix_list(input_neo4j)
temp = extract_concrete_info(input_neo4j)
ques_info = split_chinese_bracketed_phrases(temp[0])
st.markdown("🔎 **问题拆分结果:**")
st.write(ques)
retriever_info = []
cypher_info = []
for i in ques[:12]:
try:
response = booway_cypher_chain.invoke(i)
retriever_info.append(response.get("result"))
cypher_info.append(response.get("intermediate_steps")[0])
except Exception as e:
st.error(f"处理问题时出错:{e}")
retriever_info.append(None)
retriever_keywords = ques_info[0]
calculation = ques_info[-1]
with st.expander("🔐 Cypher 查询语句"):
for idx, cypher in enumerate(cypher_info):
st.code(cypher['query'], language='cypher')
st.markdown("🧮 **表达式输出:**")
st.write(calculation)
final_result = calculation_chains.invoke({
"query": input_neo4j,
"retriever_keywords": retriever_keywords,
"calculation": calculation,
"retriever_info": retriever_info
})
st.success("📊 **图谱最终检索输出:**")
st.write(final_result)
except Exception as e:
st.error(f"❌ 发生错误:{e}")
# streamlit run streamlit_ceshi.py --server.port 2336
+5 -3
View File
@@ -7,6 +7,8 @@ import requests
import httpx
import logging
from extraction_info import info_data_txt, info_faiss_archived
class SiliconFlowEmbeddings(Embeddings):
"""SiliconFlow嵌入模型封装"""
def __init__(self, api_key: str, model: str = "bge-m3"):
@@ -39,15 +41,15 @@ class SiliconFlowEmbeddings(Embeddings):
# embeddings = Embedding(url="http://10.1.16.39:9995/v1", api_key="xxx", model_name="bge-m3")
embeddings = SiliconFlowEmbeddings(api_key="xxx")
with open("./data/data.txt", 'r', encoding='utf-8') as file:
with open(info_data_txt, 'r', encoding='utf-8') as file:
txt_list = [line.strip() for line in file]
# embedding_path = "/data/Z_LLM_data/Embed_data/bge-m3"
# embeddings = HuggingFaceEmbeddings(model_name=embedding_path)
faiss_archived = "./data/faiss_data/data"
# faiss_archived = "./data/faiss_data/data"
vectorstore_txt_faiss = FAISS.from_texts(txt_list, embeddings)
vectorstore_txt_faiss.save_local(faiss_archived)
vectorstore_txt_faiss.save_local(info_faiss_archived)
retriever_txt_faiss1 = vectorstore_txt_faiss.as_retriever(search_kwargs={"k":3})
retriever_txt_faiss2 = vectorstore_txt_faiss.as_retriever(