From 2cd09e6528736308a1cf296ce7574b63185aa314 Mon Sep 17 00:00:00 2001 From: zoujiwen Date: Fri, 13 Jun 2025 17:53:42 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20kg=5Flab=5F6.13?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 6.13 上传 支持库环境 2. 上传main2.py 支持neo4j库检索 --- kg_lab_6.13/main2.py | 98 ++++++++++++++++++++++++++++++++++++ kg_lab_6.13/requirements.txt | 6 +++ 2 files changed, 104 insertions(+) create mode 100644 kg_lab_6.13/main2.py create mode 100644 kg_lab_6.13/requirements.txt diff --git a/kg_lab_6.13/main2.py b/kg_lab_6.13/main2.py new file mode 100644 index 0000000..9ea00d3 --- /dev/null +++ b/kg_lab_6.13/main2.py @@ -0,0 +1,98 @@ +from chains_lab import Problem_rewrite +from vector_lab import intersection_of_three_lists +from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2 +import json + +# 初始化 +problem_rewrite = Problem_rewrite() + +from utils import extract_concrete_info, extract_query_prefix_list + +from chains_lab import question_answer, question_answer_calculation + +qa_chains = question_answer() + +calculation_chains = question_answer_calculation() + +from chains_lab import booway_cypher_chain + +# 加载数据 +with open('./data/data.json', 'r', encoding='utf-8') as file: + data = json.load(file) + +print("📥 请输入查询内容,输入 'exit' 可退出程序。\n") + +while True: + input_str = input("🔍 输入问题:") + + if input_str.lower() == 'exit': + print("👋 已退出。") + break + + try: + results = intersection_of_three_lists(input_str) + if not results: + print("⚠️ 无法从向量中获取候选项。") + continue + + retriever = results[0] + + print(f"➡️ 匹配向量检索结果:{retriever}") + + # 重写问题,提取关键词 + keywords = problem_rewrite.invoke({ + "query": input_str, + "retriever": retriever + }) + + print(f"🧠 提取关键词:{keywords}") + + # 预映射为图数据库结构 + input_neo4j = pre_mapping2(keywords, data) + + print(f"📊 图谱相关知识输出:\n{input_neo4j}\n") + + question = input_neo4j + input_str = input_str + + if isinstance(question, str): + + response = booway_cypher_chain.invoke(question) + + generated_cypher = response.get("intermediate_steps")[0] + + # print(str(generated_cypher)) + + temp = response.get("result") + + # print(temp) + + finally_result = qa_chains.invoke({"query":input_str, + "retriever_keywords":keywords, + "retriever_info":temp}) + + print(f"📊 图谱最终检索输出:\n{finally_result}\n") + + + elif isinstance(question, list): + ques = extract_query_prefix_list(question) + ques_info = extract_concrete_info(question) + + retriever_info = [] + for i in ques: + response = booway_cypher_chain.invoke(ques) + # generated_cypher = response.get("intermediate_steps")[0] + temp = response.get("result") + retriever_info.append(temp) + + retriever_keywords = ques_info[0] + calculation = ques_info[-1] + + finally_result = calculation_chains.invoke({"query":question, + "retriever_keywords":retriever_keywords, + "calculation":calculation, + "retriever_info":retriever_info}) + print(f"📊 图谱最终检索输出:\n{finally_result}\n") + + except Exception as e: + print(f"❌ 发生错误:{e}") \ No newline at end of file diff --git a/kg_lab_6.13/requirements.txt b/kg_lab_6.13/requirements.txt new file mode 100644 index 0000000..a5d71ad --- /dev/null +++ b/kg_lab_6.13/requirements.txt @@ -0,0 +1,6 @@ +langchain==0.3.19 +langchain_core==0.3.49 +langchain_community==0.3.18 +langchain_huggingface==0.1.2 +faiss-gpu-cu12==1.9.0.post1 +neo4j==5.27.0