首次提交:上传本地文件夹
This commit is contained in:
@@ -0,0 +1,238 @@
|
||||
import gradio as gr
|
||||
from graph.graph_rag import GraphRAG
|
||||
import traceback
|
||||
import time
|
||||
import sys
|
||||
import json
|
||||
|
||||
# 添加意图识别和槽位提取相关导入
|
||||
from manager.manager_intention import ParserIntentJson
|
||||
from utils.utils import extract_names_from_json
|
||||
from chains_lab.unified_chain import UnifiedNLUChain
|
||||
|
||||
# 初始化GraphRAG系统
|
||||
print("正在初始化GraphRAG系统...")
|
||||
rag = GraphRAG(
|
||||
neo4j_uri="bolt://10.1.6.34:7687",
|
||||
neo4j_auth=("neo4j", "password")
|
||||
)
|
||||
print("GraphRAG系统初始化完成")
|
||||
|
||||
# 初始化意图识别和槽位提取系统
|
||||
print("正在初始化意图识别和槽位提取系统...")
|
||||
parser = ParserIntentJson("./data/dynamic_structure/intent_json.json")
|
||||
suffix_file_path = "./data/booway_knowledge_base/keywords_kg/suffix_keywords.json"
|
||||
suffix_fields = extract_names_from_json(suffix_file_path)
|
||||
suffix_fields.extend(['gec5', 'bczc2', 'xzwb2', 'BPQ', 'BPY'])
|
||||
unified_nlu_chain = UnifiedNLUChain(suffix_fields, parser)
|
||||
print("意图识别和槽位提取系统初始化完成")
|
||||
|
||||
def process_query(query):
|
||||
"""处理用户查询并返回结果"""
|
||||
if not query.strip():
|
||||
return "请输入您的问题"
|
||||
|
||||
try:
|
||||
print(f"正在处理查询: {query}")
|
||||
|
||||
# 步骤0: 意图识别和槽位提取
|
||||
print("步骤0: 开始意图识别和槽位提取...")
|
||||
sys.stdout.flush()
|
||||
|
||||
start_time = time.time()
|
||||
nlu_result = unified_nlu_chain.invoke({"query": query})
|
||||
# 直接使用nlu_result中的full_response,它已经是一个字典
|
||||
nlu_response = nlu_result["full_response"]
|
||||
print(f"意图识别和槽位提取完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 检查nlu_response是否已经是字典
|
||||
try:
|
||||
# 如果nlu_response是字符串,则解析为字典
|
||||
if isinstance(nlu_response, str):
|
||||
nlu_data = json.loads(nlu_response)
|
||||
else:
|
||||
# 如果已经是字典,直接使用
|
||||
nlu_data = nlu_response
|
||||
|
||||
print(f"识别到的意图: {nlu_data.get('意图', {})}")
|
||||
|
||||
# 提取功能需求槽位
|
||||
functionality = None
|
||||
slots = {}
|
||||
|
||||
# 从二级意图中提取功能需求
|
||||
if '意图' in nlu_data and '二级意图' in nlu_data['意图']:
|
||||
second_intent = nlu_data['意图']['二级意图']
|
||||
if 'slot_lv2' in second_intent:
|
||||
slot_lv2 = second_intent['slot_lv2']
|
||||
functionality = slot_lv2.get('functionality_required')
|
||||
|
||||
# 收集所有槽位
|
||||
for slot_name, slot_value in slot_lv2.items():
|
||||
if slot_value and slot_value != '未知':
|
||||
slots[slot_name] = slot_value
|
||||
|
||||
print(f"提取的功能需求: {functionality}")
|
||||
print(f"提取的所有槽位: {slots}")
|
||||
|
||||
# 如果没有提取到功能需求,使用原始查询
|
||||
if not functionality:
|
||||
search_query = query
|
||||
print("未提取到功能需求,使用原始查询进行检索")
|
||||
else:
|
||||
search_query = functionality
|
||||
print(f"使用提取的功能需求进行检索: {search_query}")
|
||||
|
||||
# 保存提取的槽位信息
|
||||
extracted_slots = slots
|
||||
except (json.JSONDecodeError, TypeError) as e:
|
||||
print(f"解析NLU结果失败: {str(e)},使用原始查询进行检索")
|
||||
print(f"NLU响应类型: {type(nlu_response)}")
|
||||
search_query = query
|
||||
extracted_slots = {}
|
||||
nlu_data = {"意图": {"一级意图": {"name": "未知"}, "二级意图": {"name": "未知"}}}
|
||||
|
||||
# 添加详细的步骤日志
|
||||
print("步骤1: 开始检索相关信息...")
|
||||
sys.stdout.flush() # 确保日志立即显示
|
||||
|
||||
# 在process_query函数中修改检索部分
|
||||
start_time = time.time()
|
||||
# 将槽位信息传递给检索器
|
||||
retrieved_info = rag.retriever.retrieve(search_query, top_k=5, slots=extracted_slots)
|
||||
print(f"检索完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
print(f"检索到 {len(retrieved_info)} 条相关信息")
|
||||
|
||||
print("步骤2: 开始生成回答...")
|
||||
sys.stdout.flush()
|
||||
|
||||
start_time = time.time()
|
||||
# 将意图和槽位信息也传递给生成器
|
||||
response = rag.generator.generate_response(query, retrieved_info, nlu_data)
|
||||
print(f"生成完成,耗时: {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 构建响应
|
||||
result = {
|
||||
"query": query,
|
||||
"nlu_data": nlu_data,
|
||||
"retrieved_info": retrieved_info,
|
||||
"response": response
|
||||
}
|
||||
|
||||
print("步骤3: 构建最终响应...")
|
||||
final_response = f"### 回答\n{result['response']}\n\n"
|
||||
|
||||
# 添加意图和槽位信息
|
||||
final_response += "### 意图识别和槽位提取\n"
|
||||
if '意图' in nlu_data:
|
||||
if '一级意图' in nlu_data['意图']:
|
||||
final_response += f"一级意图: {nlu_data['意图']['一级意图'].get('name', '未知')}\n"
|
||||
if '二级意图' in nlu_data['意图']:
|
||||
final_response += f"二级意图: {nlu_data['意图']['二级意图'].get('name', '未知')}\n"
|
||||
|
||||
# 添加槽位信息
|
||||
final_response += "槽位:\n"
|
||||
if '二级意图' in nlu_data['意图'] and 'slot_lv2' in nlu_data['意图']['二级意图']:
|
||||
for slot_name, slot_value in nlu_data['意图']['二级意图']['slot_lv2'].items():
|
||||
if slot_value and slot_value != '未知':
|
||||
final_response += f"- {slot_name}: {slot_value}\n"
|
||||
else:
|
||||
final_response += "未识别到意图\n"
|
||||
final_response += "\n"
|
||||
|
||||
# 添加检索到的信息 - 修改部分
|
||||
final_response += "### 检索到的相关信息\n"
|
||||
for i, info in enumerate(result['retrieved_info'], 1):
|
||||
similarity = info.get('similarity', 'N/A')
|
||||
if isinstance(similarity, float):
|
||||
similarity = f"{similarity:.2f}"
|
||||
|
||||
# 获取节点详细信息
|
||||
node = info.get('node', {})
|
||||
|
||||
# 获取节点类型
|
||||
node_type = "未知类型"
|
||||
if "labels" in node and node["labels"]:
|
||||
if isinstance(node["labels"], list) and len(node["labels"]) > 0:
|
||||
node_type = node["labels"][0]
|
||||
elif isinstance(node["labels"], str):
|
||||
node_type = node["labels"]
|
||||
|
||||
# 获取节点名称
|
||||
name = node.get("original_name", "") or node.get("display_name", "") or node.get("name", "未知名称")
|
||||
|
||||
# 获取节点描述
|
||||
description = node.get("描述", "无描述")
|
||||
|
||||
# 获取节点路径(如果有)
|
||||
path = node.get("path", "")
|
||||
if not path and "path_to_root" in node:
|
||||
if isinstance(node["path_to_root"], list):
|
||||
path = " > ".join(node["path_to_root"])
|
||||
else:
|
||||
path = str(node["path_to_root"])
|
||||
|
||||
# 构建显示信息
|
||||
node_info = f"{i}. 【{node_type}】{name} (相似度: {similarity})\n"
|
||||
if path:
|
||||
node_info += f" 路径: {path}\n"
|
||||
if description:
|
||||
# 如果描述太长,截断显示
|
||||
if len(description) > 200:
|
||||
node_info += f" 描述: {description[:200]}...\n"
|
||||
else:
|
||||
node_info += f" 描述: {description}\n"
|
||||
|
||||
# 添加原始文本表示
|
||||
node_info += f" 文本表示: {info['text']}\n"
|
||||
|
||||
final_response += node_info + "\n"
|
||||
|
||||
print("处理完成,返回结果")
|
||||
return final_response
|
||||
except Exception as e:
|
||||
error_msg = f"处理查询时出错: {str(e)}\n\n"
|
||||
error_msg += traceback.format_exc()
|
||||
print(error_msg)
|
||||
return f"### 错误\n```\n{error_msg}\n```"
|
||||
|
||||
# 创建Gradio界面
|
||||
with gr.Blocks(title="知识图谱问答系统") as demo:
|
||||
gr.Markdown("基于知识图谱的检索增强生成(RAG)系统")
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
query_input = gr.Textbox(
|
||||
label="请输入您的问题",
|
||||
placeholder="例如:配网D3软件的工程量计算功能是什么?",
|
||||
lines=1,
|
||||
# 可以增加高度,但保持单行行为
|
||||
scale=2,
|
||||
# 可以设置最小高度
|
||||
min_width=400
|
||||
)
|
||||
# 更新提示文本
|
||||
gr.Markdown("*按回车键直接提交问题*")
|
||||
submit_btn = gr.Button("提交")
|
||||
|
||||
with gr.Column():
|
||||
output = gr.Markdown(label="回答")
|
||||
|
||||
# 设置回车键触发提交
|
||||
query_input.submit(fn=process_query, inputs=query_input, outputs=output)
|
||||
# 保留原有的按钮点击提交
|
||||
submit_btn.click(fn=process_query, inputs=query_input, outputs=output)
|
||||
|
||||
|
||||
# 启动应用
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
print("正在启动Gradio界面...")
|
||||
demo.launch(server_name="0.0.0.0",
|
||||
server_port=7860,
|
||||
share=False,
|
||||
debug=True
|
||||
)
|
||||
finally:
|
||||
print("正在关闭资源...")
|
||||
rag.close()
|
||||
Reference in New Issue
Block a user