Files
langchain_KG/kg_lab_6.13/streamlit_ceshi.py
T
zoujiwen 517691c2d6 上传文件至 kg_lab_6.13
6.18 更新数据配置路径统一,和前端demo
2025-06-18 16:02:47 +08:00

123 lines
4.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import streamlit as st
import json
from chains_lab import Problem_rewrite, booway_cypher_chain
from chains_lab import question_answer, question_answer_calculation
from vector_lab import intersection_of_three_lists
from utils import find_target_item, find_target_items, pre_mapping, pre_mapping2
from utils import extract_concrete_info, extract_query_prefix_list, split_chinese_bracketed_phrases
from extraction_info import info_data_json
# 初始化 chains
problem_rewrite = Problem_rewrite()
qa_chains = question_answer()
calculation_chains = question_answer_calculation()
# 加载数据
with open(info_data_json, 'r', encoding='utf-8') as file:
data = json.load(file)
# Streamlit 页面设置
st.set_page_config(page_title="图谱问答系统", layout="wide")
st.title("💬 图谱问答系统")
st.markdown("请在下方输入你的查询问题:")
# 输入框
# input_str1 = "杆塔总基数是多少?"
# input_str2 = "单回路长度是多少?"
# input_str3 = "计算一下角钢塔的塔材装材费"
# input_str4 = "计算一下土石方总量"
# input_str5 = "板式塔基的各类基础数量占总塔基数比例是多少?"
# input_str6 = "基础混凝土总量是多少"
# input_str7 = "计算一下本体工程机械费"
# input_str8 = "项目建设技术服务费合计"
input_str = st.text_input("🔍 输入问题", placeholder="例如:计算一下角钢塔的塔材装材费?")
# 执行按钮
if st.button("🚀 开始查询") and input_str.strip():
try:
results = intersection_of_three_lists(input_str)
if not results:
st.warning("⚠️ 无法从向量中获取候选项。")
else:
retriever = results[0]
st.markdown(f"➡️ **匹配向量检索结果**:`{retriever}`")
# 重写问题,提取关键词
keywords = problem_rewrite.invoke({
"query": input_str,
"retriever": retriever
})
st.markdown(f"🧠 **提取关键词** `{keywords}`")
# 预映射图数据库结构
input_neo4j = pre_mapping(keywords, data)
st.markdown("📊 **图谱相关知识输出**")
st.code(input_neo4j, language='json')
if isinstance(input_neo4j, str):
response = booway_cypher_chain.invoke(input_neo4j)
cypher_query = response.get("intermediate_steps")[0]
cypher_result = response.get("result")
# print(cypher_query)
# 可视化 Cypher 查询语句
with st.expander("🔐 Cypher 查询语句"):
st.code(cypher_query['query'], language='cypher')
# 问答环节
final_result = qa_chains.invoke({
"query": input_str,
"retriever_keywords": keywords,
"retriever_info": cypher_result
})
st.success("📊 **图谱最终检索输出:**")
st.write(final_result)
elif isinstance(input_neo4j, list):
ques = extract_query_prefix_list(input_neo4j)
temp = extract_concrete_info(input_neo4j)
ques_info = split_chinese_bracketed_phrases(temp[0])
st.markdown("🔎 **问题拆分结果:**")
st.write(ques)
retriever_info = []
cypher_info = []
for i in ques[:12]:
try:
response = booway_cypher_chain.invoke(i)
retriever_info.append(response.get("result"))
cypher_info.append(response.get("intermediate_steps")[0])
except Exception as e:
st.error(f"处理问题时出错:{e}")
retriever_info.append(None)
retriever_keywords = ques_info[0]
calculation = ques_info[-1]
with st.expander("🔐 Cypher 查询语句"):
for idx, cypher in enumerate(cypher_info):
st.code(cypher['query'], language='cypher')
st.markdown("🧮 **表达式输出:**")
st.write(calculation)
final_result = calculation_chains.invoke({
"query": input_neo4j,
"retriever_keywords": retriever_keywords,
"calculation": calculation,
"retriever_info": retriever_info
})
st.success("📊 **图谱最终检索输出:**")
st.write(final_result)
except Exception as e:
st.error(f"❌ 发生错误:{e}")
# streamlit run streamlit_ceshi.py --server.port 2336