From 517691c2d60c1133aefaba45558fcade3d04e4a3 Mon Sep 17 00:00:00 2001 From: zoujiwen Date: Wed, 18 Jun 2025 16:02:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20kg=5Flab=5F6.13?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 6.18 更新数据配置路径统一,和前端demo --- kg_lab_6.13/chains_lab.py | 8 ++- kg_lab_6.13/extraction_info.py | 35 ++++++++++ kg_lab_6.13/main2.py | 32 +++++---- kg_lab_6.13/streamlit_ceshi.py | 122 +++++++++++++++++++++++++++++++++ kg_lab_6.13/vector_lab.py | 8 ++- 5 files changed, 185 insertions(+), 20 deletions(-) create mode 100644 kg_lab_6.13/streamlit_ceshi.py diff --git a/kg_lab_6.13/chains_lab.py b/kg_lab_6.13/chains_lab.py index 18a6ad9..de000e3 100644 --- a/kg_lab_6.13/chains_lab.py +++ b/kg_lab_6.13/chains_lab.py @@ -4,6 +4,8 @@ from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts.prompt import PromptTemplate from langchain_core.output_parsers import JsonOutputParser +from extraction_info import neo4j_url, neo4j_username, neo4j_password + qwen_llm = ChatOpenAI( openai_api_base="https://api.siliconflow.cn/v1", model_name="Qwen/Qwen2.5-72B-Instruct", @@ -111,9 +113,9 @@ from langchain_community.graphs import Neo4jGraph graph = Neo4jGraph( - url="bolt://172.20.0.145:7687", - username="neo4j", - password="password", + url = neo4j_url, + username = neo4j_username, + password = neo4j_password, ) graph.refresh_schema() diff --git a/kg_lab_6.13/extraction_info.py b/kg_lab_6.13/extraction_info.py index e69de29..6b3e1f6 100644 --- a/kg_lab_6.13/extraction_info.py +++ b/kg_lab_6.13/extraction_info.py @@ -0,0 +1,35 @@ +# import os + +# BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + +# info_data_json = os.path.join(BASE_DIR, 'data/data.json') +# info_data_txt = os.path.join(BASE_DIR, 'data/data.txt') +# info_faiss_archived = os.path.join(BASE_DIR, 'data/faiss_data/data') + +# # Neo4j +# neo4j_url = "bolt://172.20.0.145:7687" +# neo4j_username = "neo4j" +# neo4j_password = "password" + + +import os +import json + +# 获取当前目录 +BASE_DIR = os.path.dirname(os.path.abspath(__file__)) + +# 读取 config.json 配置文件 +config_path = os.path.join(BASE_DIR, 'config.json') +with open(config_path, 'r', encoding='utf-8') as f: + config = json.load(f) + +# 解析路径配置 +info_data_json = os.path.join(BASE_DIR, config["info_data_json"]) +info_data_txt = os.path.join(BASE_DIR, config["info_data_txt"]) +info_faiss_archived = os.path.join(BASE_DIR, config["info_faiss_archived"]) + +# 解析 Neo4j 配置 +neo4j_url = config["neo4j"]["url"] +neo4j_username = config["neo4j"]["username"] +neo4j_password = config["neo4j"]["password"] + diff --git a/kg_lab_6.13/main2.py b/kg_lab_6.13/main2.py index 0b54070..3234d4d 100644 --- a/kg_lab_6.13/main2.py +++ b/kg_lab_6.13/main2.py @@ -1,8 +1,13 @@ from chains_lab import Problem_rewrite +from chains_lab import booway_cypher_chain +from chains_lab import question_answer, question_answer_calculation from vector_lab import intersection_of_three_lists from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2 +from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases import json +from extraction_info import info_data_json + # 样例 # input_str1 = "杆塔总基数是多少?" # input_str2 = "单回路长度是多少?" @@ -13,21 +18,15 @@ import json # input_str7 = "计算一下本体工程机械费" # input_str8 = "项目建设技术服务费合计" -# 初始化 +# 初始化chians problem_rewrite = Problem_rewrite() - -from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases - -from chains_lab import question_answer, question_answer_calculation - qa_chains = question_answer() - calculation_chains = question_answer_calculation() -from chains_lab import booway_cypher_chain + # 加载数据 -with open('./data/data.json', 'r', encoding='utf-8') as file: +with open(info_data_json, 'r', encoding='utf-8') as file: data = json.load(file) print("📥 请输入查询内容,输入 'exit' 可退出程序。\n") @@ -91,11 +90,16 @@ while True: print(ques) retriever_info = [] - for idx, i in enumerate(ques): - response = booway_cypher_chain.invoke(i) - temp = response.get("result") - retriever_info.append(temp) - + for i in ques[:12]: + try: + response = booway_cypher_chain.invoke(i) + temp = response.get("result") + # todo: 重复筛选策略 + retriever_info.append(temp) + except Exception as e: + print(f"处理问题时出错: {e}") + retriever_info.append(None) + retriever_keywords = ques_info[0] calculation = ques_info[-1] diff --git a/kg_lab_6.13/streamlit_ceshi.py b/kg_lab_6.13/streamlit_ceshi.py new file mode 100644 index 0000000..e5d8487 --- /dev/null +++ b/kg_lab_6.13/streamlit_ceshi.py @@ -0,0 +1,122 @@ +import streamlit as st +import json + +from chains_lab import Problem_rewrite, booway_cypher_chain +from chains_lab import question_answer, question_answer_calculation +from vector_lab import intersection_of_three_lists +from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2 +from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases +from extraction_info import info_data_json + +# 初始化 chains +problem_rewrite = Problem_rewrite() +qa_chains = question_answer() +calculation_chains = question_answer_calculation() + +# 加载数据 +with open(info_data_json, 'r', encoding='utf-8') as file: + data = json.load(file) + +# Streamlit 页面设置 +st.set_page_config(page_title="图谱问答系统", layout="wide") +st.title("💬 图谱问答系统") +st.markdown("请在下方输入你的查询问题:") + +# 输入框 +# input_str1 = "杆塔总基数是多少?" +# input_str2 = "单回路长度是多少?" +# input_str3 = "计算一下角钢塔的塔材装材费" +# input_str4 = "计算一下土石方总量" +# input_str5 = "板式塔基的各类基础数量占总塔基数比例是多少?" +# input_str6 = "基础混凝土总量是多少" +# input_str7 = "计算一下本体工程机械费" +# input_str8 = "项目建设技术服务费合计" +input_str = st.text_input("🔍 输入问题", placeholder="例如:计算一下角钢塔的塔材装材费?") + +# 执行按钮 +if st.button("🚀 开始查询") and input_str.strip(): + try: + results = intersection_of_three_lists(input_str) + if not results: + st.warning("⚠️ 无法从向量中获取候选项。") + else: + retriever = results[0] + st.markdown(f"➡️ **匹配向量检索结果**:`{retriever}`") + + # 重写问题,提取关键词 + keywords = problem_rewrite.invoke({ + "query": input_str, + "retriever": retriever + }) + + st.markdown(f"🧠 **提取关键词**: `{keywords}`") + + # 预映射图数据库结构 + input_neo4j = pre_mapping(keywords, data) + + st.markdown("📊 **图谱相关知识输出**:") + st.code(input_neo4j, language='json') + + if isinstance(input_neo4j, str): + response = booway_cypher_chain.invoke(input_neo4j) + cypher_query = response.get("intermediate_steps")[0] + cypher_result = response.get("result") + + # print(cypher_query) + # 可视化 Cypher 查询语句 + with st.expander("🔐 Cypher 查询语句"): + st.code(cypher_query['query'], language='cypher') + + # 问答环节 + final_result = qa_chains.invoke({ + "query": input_str, + "retriever_keywords": keywords, + "retriever_info": cypher_result + }) + + st.success("📊 **图谱最终检索输出:**") + st.write(final_result) + + elif isinstance(input_neo4j, list): + ques = extract_query_prefix_list(input_neo4j) + temp = extract_concrete_info(input_neo4j) + ques_info = split_chinese_bracketed_phrases(temp[0]) + + st.markdown("🔎 **问题拆分结果:**") + st.write(ques) + + retriever_info = [] + cypher_info = [] + for i in ques[:12]: + try: + response = booway_cypher_chain.invoke(i) + retriever_info.append(response.get("result")) + cypher_info.append(response.get("intermediate_steps")[0]) + except Exception as e: + st.error(f"处理问题时出错:{e}") + retriever_info.append(None) + + retriever_keywords = ques_info[0] + calculation = ques_info[-1] + + with st.expander("🔐 Cypher 查询语句"): + for idx, cypher in enumerate(cypher_info): + st.code(cypher['query'], language='cypher') + + st.markdown("🧮 **表达式输出:**") + st.write(calculation) + + final_result = calculation_chains.invoke({ + "query": input_neo4j, + "retriever_keywords": retriever_keywords, + "calculation": calculation, + "retriever_info": retriever_info + }) + + st.success("📊 **图谱最终检索输出:**") + st.write(final_result) + + except Exception as e: + st.error(f"❌ 发生错误:{e}") + +# streamlit run streamlit_ceshi.py --server.port 2336 diff --git a/kg_lab_6.13/vector_lab.py b/kg_lab_6.13/vector_lab.py index 79da3cb..1c26f98 100644 --- a/kg_lab_6.13/vector_lab.py +++ b/kg_lab_6.13/vector_lab.py @@ -7,6 +7,8 @@ import requests import httpx import logging +from extraction_info import info_data_txt, info_faiss_archived + class SiliconFlowEmbeddings(Embeddings): """SiliconFlow嵌入模型封装""" def __init__(self, api_key: str, model: str = "bge-m3"): @@ -39,15 +41,15 @@ class SiliconFlowEmbeddings(Embeddings): # embeddings = Embedding(url="http://10.1.16.39:9995/v1", api_key="xxx", model_name="bge-m3") embeddings = SiliconFlowEmbeddings(api_key="xxx") -with open("./data/data.txt", 'r', encoding='utf-8') as file: +with open(info_data_txt, 'r', encoding='utf-8') as file: txt_list = [line.strip() for line in file] # embedding_path = "/data/Z_LLM_data/Embed_data/bge-m3" # embeddings = HuggingFaceEmbeddings(model_name=embedding_path) -faiss_archived = "./data/faiss_data/data" +# faiss_archived = "./data/faiss_data/data" vectorstore_txt_faiss = FAISS.from_texts(txt_list, embeddings) -vectorstore_txt_faiss.save_local(faiss_archived) +vectorstore_txt_faiss.save_local(info_faiss_archived) retriever_txt_faiss1 = vectorstore_txt_faiss.as_retriever(search_kwargs={"k":3}) retriever_txt_faiss2 = vectorstore_txt_faiss.as_retriever(