Files
DM_rewrite_3.31/streamlit_main.py
2025-04-03 17:23:53 +08:00

179 lines
8.0 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
===================================
@AutherWenZ
@Company: BooWay
@projectbooway_dm
===================================
"""
import streamlit as st
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, ChatRecord
from chains_ceshi import Vertical_classification, small_talk, intention_judge, domain_judge, judge_5W2H, extract_keywords, answer_questions
from kg_management import retriever_txt_faiss1
from kg_management import retriever_txt_faiss2
from kg_management import retriever_txt_faiss3
from kg_management import retriever_txt_faiss4
from kg_management import retriever_txt_faiss5
from kg_management import retriever_txt_faiss6
from kg_management import retriever_txt_faiss7
from kg_management import retriever_txt_faiss8
from kg_management import retriever_txt_faiss9
from kg_management import input_index_csv_path
from kg_management import xizang_input_csv_path
from kg_management import cuceng_input_csv_path
from kg_management import jigai_input_csv_path
from kg_management import process_domain_category
chain_vertical = Vertical_classification() # 垂直/开放分类
chain_intention = intention_judge() # 意图分类
chain_domain = domain_judge() # 领域分类
chain_5w2h = judge_5W2H() # 5W2H分类
chain_keywords = extract_keywords() # 关键语义提取
chains_qa = answer_questions() # 检索回答
# import streamlit as st
#
# chain_vertical = Vertical_classification() # 垂直/开放分类
# chain_intention = intention_judge() # 意图分类
# chain_domain = domain_judge() # 领域分类
# chain_5w2h = judge_5W2H() # 5W2H分类
# chain_keywords = extract_keywords() # 关键语义提取
# chains_qa = answer_questions() # 检索回答
domain_mapping = {
'西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path),
'新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path),
'技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path)
}
import streamlit as st
# 假设这些是你已有的模块
# from your_module import chain_vertical, chain_intention, chain_domain, chain_5w2h, chain_keywords, chains_qa
# from your_module import process_domain_category, NLUInfo, domain_mapping, input_index_csv_path
# 页面配置
st.set_page_config(page_title="booway 软件助手", layout="wide")
# 助手简介
st.markdown("""
# 🤖 booway 软件助手
欢迎使用 **booway 软件助手**,这是一个用于协助用户进行电力造价软件相关问题咨询的智能系统。
**目前可咨询软件为:**
- 西藏造价软件Z1
- 新型储能计价通C1
- 技改检修计价通T1
**使用方法:**
直接在下方输入你的问题,例如:
- “你好,想问下储能的C1那个软件。初设的基本预备费费率想调整一下,但是没有找到能调整的地方”
- “你好,初设的基本预备费费率想调整一下,但是没有找到能调整的地方”(多轮测试)
- “如何把西藏老定额工程升级成西藏Z1的新定额工程”
- “储能软件勾选了卸车,总价不变呢”
- “请问技改检修软件里其他费中的设计费的取费基数和费率,你们设置时肯定要依据吧”
**注意:多轮对话**
- 目前多轮对话中 当机器人询问用户什么软件,则必须是以上软件名字
**注意:垂直分类不对***
- 技改拆除的安全文明施工费费率是多少 分类: 闲聊
- 技改软件拆除的安全文明施工费费率是多少 分类:软件咨询(+ 软件)
- 如:技改项目从老版本升级到新版本,是不是定额也跟着跟新了
- 改为:技改软件项目从老版本升级到新版本,是不是定额也跟着跟新了
如果你输入的是闲聊内容,系统将提示仅内测用户可用。
""")
# 用户输入
user_input = st.text_input("👤 用户:", "")
if user_input:
if user_input.lower() in ["退出", "bye", "再见"]:
st.success("booway软件助手:再见 👋")
else:
vertical_category = chain_vertical.invoke(user_input)
if vertical_category == '闲聊':
st.warning("booway软件助手:闲聊服务只提供给内测用户")
elif vertical_category == '软件咨询':
# 初始化 NLUInfo
nlu_info = NLUInfo(vertical_category="软件咨询")
# 意图识别 & 领域识别
nlu_info.intent_category = chain_intention.invoke(user_input)
domain_info = chain_domain.invoke(user_input)
nlu_info.domain_category = domain_info
# 多轮询问软件名称
if domain_info == '未知':
import random
prompts = [
"🤔 Booway软件助手:请问您指的是哪个软件?",
"🤔 Booway软件助手:请提供软件名称,以便更好地帮助您。",
"🤔 Booway软件助手:请问您使用的是什么软件?",
"🤔 Booway软件助手:请告诉我您要查询的软件名称。",
"🤔 Booway软件助手:请问是哪款软件?"
]
# 随机选择一个提示
random_prompt = random.choice(prompts)
# 生成输入框
software_name = st.text_input(random_prompt, "")
# software_name = st.text_input("🤔 booway软件助手:请问是什么软件?", "")
if software_name and software_name in domain_mapping:
domain_info = software_name
nlu_info.domain_category = software_name
elif software_name:
st.error("booway软件助手:请输入一个有效的软件名称")
if nlu_info.domain_category in domain_mapping:
# 问题类型 + 关键词提取
nlu_info.question_type = chain_5w2h.invoke(user_input)
nlu_info.retrieve_keywords = chain_keywords.invoke({
"software_name": nlu_info.domain_category,
"_5w2h_type": nlu_info.question_type,
"query": user_input
})
# 检索分析
retrievers, input_csv_path = (domain_mapping[nlu_info.domain_category][:3],
domain_mapping[nlu_info.domain_category][3])
kg = process_domain_category(nlu_info, retrievers, input_csv_path, input_index_csv_path)
if kg:
# 检索内容
extracted_index = [item[5:] for item in kg if item.startswith("检索知识")]
st.subheader("📚 检索内容")
st.write(extracted_index)
# 回答内容
response = chains_qa.invoke({"query": user_input, "kg": '\n'.join(kg)})
st.subheader("💬 回答内容")
st.write(response)
# NLU信息
st.subheader("🧠 NLU_info")
st.json({
"vertical_category": nlu_info.vertical_category,
"intent_category": nlu_info.intent_category,
"domain_category": nlu_info.domain_category,
"question_type": nlu_info.question_type,
"retrieve_keywords": nlu_info.retrieve_keywords,
})
else:
st.error("booway软件助手:无相关检索知识,请重新提出问题")
else:
st.warning("booway软件助手:抱歉,我不明白你的问题,请尝试重新表述")
# streamlit run streamlit_main.py --server.port 2335