517691c2d6
6.18 更新数据配置路径统一,和前端demo
123 lines
4.7 KiB
Python
123 lines
4.7 KiB
Python
import streamlit as st
|
||
import json
|
||
|
||
from chains_lab import Problem_rewrite, booway_cypher_chain
|
||
from chains_lab import question_answer, question_answer_calculation
|
||
from vector_lab import intersection_of_three_lists
|
||
from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2
|
||
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
|
||
from extraction_info import info_data_json
|
||
|
||
# 初始化 chains
|
||
problem_rewrite = Problem_rewrite()
|
||
qa_chains = question_answer()
|
||
calculation_chains = question_answer_calculation()
|
||
|
||
# 加载数据
|
||
with open(info_data_json, 'r', encoding='utf-8') as file:
|
||
data = json.load(file)
|
||
|
||
# Streamlit 页面设置
|
||
st.set_page_config(page_title="图谱问答系统", layout="wide")
|
||
st.title("💬 图谱问答系统")
|
||
st.markdown("请在下方输入你的查询问题:")
|
||
|
||
# 输入框
|
||
# input_str1 = "杆塔总基数是多少?"
|
||
# input_str2 = "单回路长度是多少?"
|
||
# input_str3 = "计算一下角钢塔的塔材装材费"
|
||
# input_str4 = "计算一下土石方总量"
|
||
# input_str5 = "板式塔基的各类基础数量占总塔基数比例是多少?"
|
||
# input_str6 = "基础混凝土总量是多少"
|
||
# input_str7 = "计算一下本体工程机械费"
|
||
# input_str8 = "项目建设技术服务费合计"
|
||
input_str = st.text_input("🔍 输入问题", placeholder="例如:计算一下角钢塔的塔材装材费?")
|
||
|
||
# 执行按钮
|
||
if st.button("🚀 开始查询") and input_str.strip():
|
||
try:
|
||
results = intersection_of_three_lists(input_str)
|
||
if not results:
|
||
st.warning("⚠️ 无法从向量中获取候选项。")
|
||
else:
|
||
retriever = results[0]
|
||
st.markdown(f"➡️ **匹配向量检索结果**:`{retriever}`")
|
||
|
||
# 重写问题,提取关键词
|
||
keywords = problem_rewrite.invoke({
|
||
"query": input_str,
|
||
"retriever": retriever
|
||
})
|
||
|
||
st.markdown(f"🧠 **提取关键词**: `{keywords}`")
|
||
|
||
# 预映射图数据库结构
|
||
input_neo4j = pre_mapping(keywords, data)
|
||
|
||
st.markdown("📊 **图谱相关知识输出**:")
|
||
st.code(input_neo4j, language='json')
|
||
|
||
if isinstance(input_neo4j, str):
|
||
response = booway_cypher_chain.invoke(input_neo4j)
|
||
cypher_query = response.get("intermediate_steps")[0]
|
||
cypher_result = response.get("result")
|
||
|
||
# print(cypher_query)
|
||
# 可视化 Cypher 查询语句
|
||
with st.expander("🔐 Cypher 查询语句"):
|
||
st.code(cypher_query['query'], language='cypher')
|
||
|
||
# 问答环节
|
||
final_result = qa_chains.invoke({
|
||
"query": input_str,
|
||
"retriever_keywords": keywords,
|
||
"retriever_info": cypher_result
|
||
})
|
||
|
||
st.success("📊 **图谱最终检索输出:**")
|
||
st.write(final_result)
|
||
|
||
elif isinstance(input_neo4j, list):
|
||
ques = extract_query_prefix_list(input_neo4j)
|
||
temp = extract_concrete_info(input_neo4j)
|
||
ques_info = split_chinese_bracketed_phrases(temp[0])
|
||
|
||
st.markdown("🔎 **问题拆分结果:**")
|
||
st.write(ques)
|
||
|
||
retriever_info = []
|
||
cypher_info = []
|
||
for i in ques[:12]:
|
||
try:
|
||
response = booway_cypher_chain.invoke(i)
|
||
retriever_info.append(response.get("result"))
|
||
cypher_info.append(response.get("intermediate_steps")[0])
|
||
except Exception as e:
|
||
st.error(f"处理问题时出错:{e}")
|
||
retriever_info.append(None)
|
||
|
||
retriever_keywords = ques_info[0]
|
||
calculation = ques_info[-1]
|
||
|
||
with st.expander("🔐 Cypher 查询语句"):
|
||
for idx, cypher in enumerate(cypher_info):
|
||
st.code(cypher['query'], language='cypher')
|
||
|
||
st.markdown("🧮 **表达式输出:**")
|
||
st.write(calculation)
|
||
|
||
final_result = calculation_chains.invoke({
|
||
"query": input_neo4j,
|
||
"retriever_keywords": retriever_keywords,
|
||
"calculation": calculation,
|
||
"retriever_info": retriever_info
|
||
})
|
||
|
||
st.success("📊 **图谱最终检索输出:**")
|
||
st.write(final_result)
|
||
|
||
except Exception as e:
|
||
st.error(f"❌ 发生错误:{e}")
|
||
|
||
# streamlit run streamlit_ceshi.py --server.port 2336
|