Files
DM_rewrite_3.31/streamlit_main.py
T
2025-03-31 16:00:57 +08:00

269 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
===================================
@Author: WenZ
@Company: BooWay
@Project: booway_dm
===================================
"""
import streamlit as st
from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \
ChatRecord
from utils import stop_word_processing
import spacy
import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf
from utils import get_keywords, get_keywords_v2, get_keywords_v3
from vector_load import interface_search
# Only the `trf` Chinese spaCy pipeline is loaded at start-up; the sm/md/lg
# packages are imported above but are not loaded here.
nlp_trf = zh_core_web_trf.load()

# Courtesy words handed to stop_word_processing() together with nlp_trf.
polite_words = {
    "你好",
    "您好",
    "请",
    "请问",
    "谢谢",
    "不客气",
    "麻烦",
    "打扰",
    "拜托",
    "辛苦",
    "劳驾",
}
from chains_ceshi import suffix_answers
from chains_ceshi import Vertical_classification
from chains_ceshi import intention_judge
from chains_ceshi import domain_judge
from chains_ceshi import judge_5W2H
from chains_rewrite import software_name_rewrite
from chains_rewrite import query_function_rewrite
from chains_rewrite import operation_guidance_rewrite
from chains_rewrite import troubleshooting_rewrite
from chains_rewrite import access_rewrite
from kg_management import retriever_txt_faiss1
from kg_management import retriever_txt_faiss2
from kg_management import retriever_txt_faiss3
from kg_management import retriever_txt_faiss4
from kg_management import retriever_txt_faiss5
from kg_management import retriever_txt_faiss6
from kg_management import retriever_txt_faiss7
from kg_management import retriever_txt_faiss8
from kg_management import retriever_txt_faiss9
from kg_management import input_index_csv_path
from kg_management import xizang_input_csv_path
from kg_management import cuceng_input_csv_path
from kg_management import jigai_input_csv_path
from kg_management import process_domain_category
# Retrieval resources per software product, keyed by the domain-category
# string that the request flow looks up at query time.
# Each value is a 4-tuple: three FAISS retrievers followed by a CSV path.
domain_mapping = {
    "西藏造价软件Z1": (
        retriever_txt_faiss1,
        retriever_txt_faiss2,
        retriever_txt_faiss3,
        xizang_input_csv_path,
    ),
    "新型储能计价通C1": (
        retriever_txt_faiss4,
        retriever_txt_faiss5,
        retriever_txt_faiss6,
        cuceng_input_csv_path,
    ),
    "技改检修计价通T1": (
        retriever_txt_faiss7,
        retriever_txt_faiss8,
        retriever_txt_faiss9,
        jigai_input_csv_path,
    ),
}
# --- Classification / answering chains -------------------------------------
# Answers questions about file suffixes (extensions).
chain_suffix_answers = suffix_answers()
# Vertical vs. open-domain classification.
chain_vertical = Vertical_classification()
# Intent classification.
chain_intention = intention_judge()
# Domain (software product) classification.
chain_domain = domain_judge()
# 5W2H question-type classification.
chain_5W2H = judge_5W2H()

# --- Query-rewrite chains ---------------------------------------------------
# Rewrites the query with the resolved software name.
chains_name_rewrite = software_name_rewrite()
# Rewrite for software feature queries.
chain_function_rewrite = query_function_rewrite()
# Rewrite for software operation guidance.
chain_guidance_rewrite = operation_guidance_rewrite()
# Rewrite for software troubleshooting.
chain_troubleshooting_rewrite = troubleshooting_rewrite()
# Rewrite for software download and installation.
chain_access_rewrite = access_rewrite()

from chains_rewrite import to_normal_rewrite
from chains_rewrite import retrieval_rewrite

# Normalising rewrite and retrieval-oriented rewrite used by the main flow.
chain_normal_rewrite = to_normal_rewrite()
chain_retrieval_rewrite = retrieval_rewrite()
# 同义词替换
from utils import normalize_text
# Maps each canonical term to the list of synonyms passed to normalize_text().
# Entries (including repeated ones) are kept exactly as authored.
synonym_dict = {
    "西藏造价软件Z1": [
        # upper-case variants
        "西藏造价Z1",
        "Z1软件",
        "西藏Z1",
        "西藏工程造价Z1",
        "西藏造价通Z1",
        "Z1",
        "西藏",
        # lower-case variants
        "西藏造价z1",
        "z1软件",
        "西藏z1",
        "西藏工程造价z1",
        "西藏造价通z1",
        "z1",
    ],
    "新型储能计价通C1": [
        # upper-case variants
        "储能计价通C1",
        "C1储能计价",
        "新型储能C1",
        "储能计价C1",
        "新型储能计价通",
        "C1",
        "储能",
        # lower-case variants
        "储能计价通c1",
        "c1储能计价",
        "新型储能c1",
        "储能计价c1",
        "新型储能计价通",
        "c1",
    ],
    "技改检修计价通T1": [
        # upper-case variants
        "技改检修T1",
        "T1检修计价",
        "技改计价通T1",
        "技改检修计价通",
        "T1技改检修",
        "T1",
        "技改",
        # lower-case variants
        "技改检修t1",
        "t1检修计价",
        "技改计价通t1",
        "技改检修计价通",
        "t1技改检修",
        "t1",
    ],
    "费率": ["费费率"],
    "下载": ["获取", "安装", "下载下来", "装上"],
}
import random
# Follow-up questions shown when the software cannot be determined from the
# user's first message; one is picked at random per conversation.
_PROMPT_PREFIX = "🤔 Booway软件助手:"
reponse_prompts = [
    _PROMPT_PREFIX + question
    for question in (
        "请问您指的是哪个软件?",
        "请提供软件名称,以便更好地帮助您。",
        "请问您使用的是什么软件?",
        "请告诉我您要查询的软件名称。",
        "请问是哪款软件?",
    )
]
import streamlit as st
import random
# Assistant introduction rendered at the top of the page (Markdown).
_INTRO_MARKDOWN = """
# 🤖 Booway 软件助手
欢迎使用 **booway 软件助手**,这是一个用于协助用户进行电力造价软件相关问题咨询的智能系统。
**目前可咨询软件为:**
- 西藏造价软件Z1
- 新型储能计价通C1
- 技改检修计价通T1
- 后缀名文件咨询
**使用方法:**
直接在下方输入你的问题,例如:
- 你好,想问下储能的C1那个软件。初设的基本预备费费率想调整一下,但是没有找到能调整的地方
- 如何把西藏老定额工程升级成西藏Z1的新定额工程
- 储能软件勾选了卸车,总价不变呢
- bjgx用什么软件打开的?
- 设备运杂费率怎么设置 (多轮测试)
- 你好,初设的基本预备费费率想调整一下,但是没有找到能调整的地方(多轮测试)
**注意:多轮对话**
- ~~目前多轮对话中 当机器人询问用户什么软件,则必须是以上软件名字~~
如果你输入的是闲聊内容,系统将提示仅内测用户可用。
"""

# Page configuration is issued before any other Streamlit element is drawn.
st.set_page_config(page_title="Booway 助手", layout="wide")
st.markdown(_INTRO_MARKDOWN)
# Session-state keys that belong to the follow-up "which software?" prompt.
_MULTI_TURN_KEYS = ("mt_input_done", "mt_input_value", "mt_prompt", "mt_input")


def _clear_multi_turn_state():
    """Remove every follow-up-prompt key from the Streamlit session state."""
    for key in _MULTI_TURN_KEYS:
        if key in st.session_state:
            del st.session_state[key]


def _show_nlu_info(nlu_info, extra=None):
    """Render the collected NLU fields as JSON.

    Args:
        nlu_info: The NLUInfo instance whose fields are displayed.
        extra: Optional mapping of additional entries appended after the
            standard fields (used for the retrieved answer).
    """
    payload = {
        "垂直分类": nlu_info.vertical_category,
        "意图分类": nlu_info.intent_category,
        "领域分类": nlu_info.domain_category,
        "问题分类": nlu_info.question_type,
        "检索语义": nlu_info.retrieve_keywords,
        "改写结果": nlu_info.rewrite,
    }
    if extra:
        payload.update(extra)
    st.subheader("识别出的NLU信息")
    st.json(payload)


def _handle_suffix_query(input_str):
    """Answer a question about a file suffix and display the NLU result.

    The lookup result is used as knowledge text when it is a string, or as
    newline-joined text when it is a list. Any other result (the lookup
    returns an int when nothing is found) is reported as "no knowledge found"
    instead of silently rendering nothing.
    """
    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.intent_category = "查询功能"
    nlu_info.domain_category = "后缀名查询"
    suffix_name = match_suffix(input_str)
    nlu_info.retrieve_keywords = suffix_name

    suffix_to_software = retrieve_relevant_software(suffix_name)
    if isinstance(suffix_to_software, str):
        query_kg = suffix_to_software
    elif isinstance(suffix_to_software, list):
        query_kg = "\n".join(suffix_to_software)
    else:
        st.info("Booway 助手:未查到相关知识")
        return

    nlu_info.rewrite = f"{suffix_name}是什么文件?用什么软件打开?"
    result = chain_suffix_answers.invoke({"query": input_str, "kg": query_kg})
    _show_nlu_info(nlu_info, {"检索回答": result})


def _resolve_domain(input_str):
    """Return the domain label for *input_str*, asking a follow-up if needed.

    The input is pre-processed (stop/polite-word handling, synonym
    normalisation) and classified. When the classifier answers "未知", a second
    text input asks the user for the software name; the script run is halted
    until that input is provided, and the label is then derived from the
    follow-up answer alone.
    """
    # Step 1: pre-process the input.
    cleaned = stop_word_processing(input_str, nlp_trf, polite_words)
    normalized = normalize_text(cleaned, synonym_dict)

    # Step 2: classify the domain.
    domain_category = chain_domain.invoke(normalized)
    if domain_category != "未知":
        return domain_category

    # Step 3: unknown domain -> guide the user to supply the software name.
    if "mt_input_done" not in st.session_state:
        st.session_state.mt_input_done = False
    if "mt_input_value" not in st.session_state:
        st.session_state.mt_input_value = ""
    if "mt_prompt" not in st.session_state:
        # Chosen once and stored so the prompt text is stable across reruns.
        st.session_state.mt_prompt = random.choice(reponse_prompts)

    if not st.session_state.mt_input_done:
        mt_conversation = st.text_input(
            f"{st.session_state.mt_prompt}", key="mt_input"
        )
        if not mt_conversation:
            st.stop()  # Halt this run and wait for the follow-up answer.
        st.session_state.mt_input_value = mt_conversation
        st.session_state.mt_input_done = True
        st.rerun()  # Halt this run; the next run takes the branch below.

    # Follow-up answer available: classify again using that answer.
    mt_input_str = normalize_text(st.session_state.mt_input_value, synonym_dict)
    return chain_domain.invoke(mt_input_str)


def _handle_software_query(input_str):
    """Run the multi-turn software-consultation flow for *input_str*."""
    domain_category = _resolve_domain(input_str)

    # The classifier may still answer "未知" (or another label without
    # retrieval resources). Indexing domain_mapping directly would raise
    # KeyError and leave the follow-up state behind, so bail out cleanly.
    resources = domain_mapping.get(domain_category)
    if resources is None:
        _clear_multi_turn_state()
        st.warning("Booway 助手:未能识别受支持的软件,请重新提问并注明软件名称。")
        return
    retrievers = resources[:3]

    # Step 4: build the NLU record.
    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.domain_category = domain_category

    # Step 5: rewrite the question and extract keywords.
    input_str_name_rewrite = chains_name_rewrite.invoke({
        "query": input_str,
        "software_name": nlu_info.domain_category,
    })
    input_str_rewrite = chain_normal_rewrite.invoke(input_str_name_rewrite)
    temp_retriever = get_keywords_v2(input_str_rewrite)
    nlu_info.question_type = chain_5W2H.invoke(input_str_rewrite)
    nlu_info.intent_category = chain_intention.invoke(input_str)

    # Step 6: query the retrievers and rewrite the query for retrieval.
    index_keywords = interface_search(temp_retriever, *retrievers)
    st.info(f"提取关键词:{index_keywords}")
    nlu_info.rewrite = chain_retrieval_rewrite.invoke({
        "query": input_str_rewrite,
        "question_type": nlu_info.question_type,
        "intention_type": nlu_info.intent_category,
        "keywords": index_keywords,
    })
    nlu_info.retrieve_keywords = get_keywords_v3(nlu_info.rewrite)

    # Step 7: show the result and reset the follow-up state for the next
    # question.
    _show_nlu_info(nlu_info)
    _clear_multi_turn_state()


# Entry point: read the user's question and route it to the matching flow.
input_str = st.text_input("🦉 用户:", "")
if input_str:
    if judge_define_suffix(input_str):
        _handle_suffix_query(input_str)
    else:
        _handle_software_query(input_str)
# streamlit run streamlit_main.py --server.port 2335