"""
===================================
@Author:WenZ
@Company: BooWay
@project:booway_dm
===================================

Streamlit multi-turn demo for the Booway software assistant.

The script reads one user question, then takes one of two paths:

* file-suffix questions are answered straight from the suffix knowledge
  returned by ``retrieve_relevant_software``;
* every other question goes through domain classification (asking the user
  for the software name when the domain is unknown), query rewriting and
  keyword retrieval, and the resulting NLU information is displayed.

Run with:
    streamlit run streamlit_main.py --server.port 2335
"""
# Provenance: recovered from the patch "多轮demo" (commit 55d2fbe), which
# creates streamlit_main.py.
import random

import spacy
import streamlit as st
import zh_core_web_sm
import zh_core_web_md
import zh_core_web_lg
import zh_core_web_trf

from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
from utils import stop_word_processing
from utils import get_keywords, get_keywords_v2, get_keywords_v3
from utils import normalize_text
from dialogue_management import (
    QuestionInfo,
    DialogType,
    DialogInfo,
    QuestionType,
    NLUInfo,
    ScenInfo,
    TalkInfo,
    ChatRecord,
)
from vector_load import interface_search

from chains_ceshi import suffix_answers
from chains_ceshi import Vertical_classification
from chains_ceshi import intention_judge
from chains_ceshi import domain_judge
from chains_ceshi import judge_5W2H

from chains_rewrite import software_name_rewrite
from chains_rewrite import query_function_rewrite
from chains_rewrite import operation_guidance_rewrite
from chains_rewrite import troubleshooting_rewrite
from chains_rewrite import access_rewrite
from chains_rewrite import to_normal_rewrite
from chains_rewrite import retrieval_rewrite

from kg_management import retriever_txt_faiss1
from kg_management import retriever_txt_faiss2
from kg_management import retriever_txt_faiss3
from kg_management import retriever_txt_faiss4
from kg_management import retriever_txt_faiss5
from kg_management import retriever_txt_faiss6
from kg_management import retriever_txt_faiss7
from kg_management import retriever_txt_faiss8
from kg_management import retriever_txt_faiss9

from kg_management import input_index_csv_path
from kg_management import xizang_input_csv_path
from kg_management import cuceng_input_csv_path
from kg_management import jigai_input_csv_path

from kg_management import process_domain_category

# Streamlit documents set_page_config as having to be the first Streamlit
# command of a run, so it is issued before the cached model loader below
# (st.cache_resource may render a spinner on a cache miss).
st.set_page_config(page_title="Booway 助手", layout="wide")


@st.cache_resource
def _load_nlp():
    """Load the spaCy transformer pipeline once per server process.

    Streamlit re-executes this script on every interaction; without the
    cache the pipeline would be loaded again on each rerun.
    """
    return zh_core_web_trf.load()


# Only the transformer pipeline is used; the sm/md/lg packages are imported
# above but not loaded.
nlp_trf = _load_nlp()

# Courtesy words handed to stop_word_processing together with the pipeline.
polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"}

# Domain label -> (three retrievers, input CSV path).
domain_mapping = {
    '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path),
    '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path),
    '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path)
}

chain_suffix_answers = suffix_answers()  # answers file-suffix questions
# NOTE(review): chain_vertical is instantiated but never invoked in this script.
chain_vertical = Vertical_classification()  # vertical / open-domain classification
chain_intention = intention_judge()  # intent classification
chain_domain = domain_judge()  # domain classification
chain_5W2H = judge_5W2H()  # 5W2H question-type classification

chains_name_rewrite = software_name_rewrite()  # rewrite: software name
chain_function_rewrite = query_function_rewrite()  # rewrite: feature lookup
chain_guidance_rewrite = operation_guidance_rewrite()  # rewrite: operation guidance
chain_troubleshooting_rewrite = troubleshooting_rewrite()  # rewrite: troubleshooting
chain_access_rewrite = access_rewrite()  # rewrite: download and installation

chain_normal_rewrite = to_normal_rewrite()
chain_retrieval_rewrite = retrieval_rewrite()

# Synonym table passed to normalize_text (canonical term -> variants).
synonym_dict = {
    "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1", "西藏",
                 "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1", ],
    "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1", "储能",
                  "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"],
    "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1", "技改",
                  "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"],
    "费率": ["费费率"],
    "下载": ["获取", "安装", "下载下来", "装上"]
}

# Follow-up questions shown when the software cannot be determined.
reponse_prompts = [
    "🤔 Booway软件助手:请问您指的是哪个软件?",
    "🤔 Booway软件助手:请提供软件名称,以便更好地帮助您。",
    "🤔 Booway软件助手:请问您使用的是什么软件?",
    "🤔 Booway软件助手:请告诉我您要查询的软件名称。",
    "🤔 Booway软件助手:请问是哪款软件?"
]

# Label for which the follow-up question is asked.
_UNKNOWN_DOMAIN = "未知"
# Session-state keys owned by the multi-turn follow-up.
_MT_STATE_KEYS = ("mt_input_done", "mt_input_value", "mt_prompt", "mt_input")


def _show_nlu(nlu_info, extra=None):
    """Render the NLU fields as JSON, optionally followed by extra entries."""
    payload = {
        "垂直分类": nlu_info.vertical_category,
        "意图分类": nlu_info.intent_category,
        "领域分类": nlu_info.domain_category,
        "问题分类": nlu_info.question_type,
        "检索语义": nlu_info.retrieve_keywords,
        "改写结果": nlu_info.rewrite,
    }
    if extra:
        payload.update(extra)
    st.subheader("识别出的NLU信息")
    st.json(payload)


def _reset_multi_turn_state():
    """Drop every session key used by the follow-up question."""
    for key in _MT_STATE_KEYS:
        if key in st.session_state:
            del st.session_state[key]


def _handle_suffix_query(query):
    """Answer a file-suffix question from the suffix knowledge."""
    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.intent_category = "查询功能"
    nlu_info.domain_category = "后缀名查询"

    suffix_name = match_suffix(query)
    nlu_info.retrieve_keywords = suffix_name
    suffix_to_software = retrieve_relevant_software(suffix_name)

    # An int result is reported to the user as "no knowledge found".
    if isinstance(suffix_to_software, int):
        st.info("Booway 助手:未查到相关知识")
        return
    if isinstance(suffix_to_software, str):
        query_kg = suffix_to_software
    elif isinstance(suffix_to_software, list):
        query_kg = '\n'.join(suffix_to_software)
    else:
        # Any other result type produces no output, as before.
        return

    nlu_info.rewrite = f"{suffix_name}是什么文件?用什么软件打开?"
    result = chain_suffix_answers.invoke({"query": query, "kg": query_kg})
    _show_nlu(nlu_info, {"检索回答": result})


def _ask_for_software_name():
    """Ask which software is meant and classify the reply.

    Halts the run (st.stop) until the user has answered, and triggers a
    rerun once the answer is stored. Returns the domain label that
    chain_domain assigns to the normalized answer.
    """
    if "mt_input_done" not in st.session_state:
        st.session_state.mt_input_done = False
    if "mt_input_value" not in st.session_state:
        st.session_state.mt_input_value = ""
    if "mt_prompt" not in st.session_state:
        # Chosen once and stored, so the question stays the same across reruns.
        st.session_state.mt_prompt = random.choice(reponse_prompts)

    if not st.session_state.mt_input_done:
        reply = st.text_input(f"{st.session_state.mt_prompt}", key="mt_input")
        if not reply:
            st.stop()  # wait for the user's answer
        st.session_state.mt_input_value = reply
        st.session_state.mt_input_done = True
        st.rerun()

    mt_input_str = normalize_text(st.session_state.mt_input_value, synonym_dict)
    return chain_domain.invoke(mt_input_str)


def _handle_software_query(query):
    """Run the multi-turn NLU pipeline for a software question."""
    # Step 1: preprocess the input.
    input_str_stoped = stop_word_processing(query, nlp_trf, polite_words)
    input_str_syn = normalize_text(input_str_stoped, synonym_dict)

    # Step 2: classify the domain.
    domain = chain_domain.invoke(input_str_syn)

    # Step 3: unknown domain -> ask the user for the software name.
    if domain == _UNKNOWN_DOMAIN:
        domain = _ask_for_software_name()

    # The label may still be unknown or unmapped. Looking it up would raise
    # KeyError and leave the stale answer in session state, so reset the
    # follow-up and let the user try again instead.
    if domain not in domain_mapping:
        _reset_multi_turn_state()
        st.warning("Booway 助手:未能识别您所指的软件,请重新输入软件名称。")
        st.button("重新输入")  # a click triggers a rerun, which asks again
        st.stop()

    # Step 4: build the NLU record.
    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.domain_category = domain

    # Step 5: rewrite the question and extract keywords.
    input_str_name_rewrite = chains_name_rewrite.invoke({
        "query": query,
        "software_name": nlu_info.domain_category
    })
    input_str_rewrite = chain_normal_rewrite.invoke(input_str_name_rewrite)

    temp_retriever = get_keywords_v2(input_str_rewrite)
    nlu_info.question_type = chain_5W2H.invoke(input_str_rewrite)
    nlu_info.intent_category = chain_intention.invoke(query)

    # Step 6: search the domain's retrievers and rewrite the query.
    retrievers = domain_mapping[nlu_info.domain_category][:3]

    index_keywords = interface_search(temp_retriever, *retrievers)
    st.info(f"提取关键词:{index_keywords}")

    nlu_info.rewrite = chain_retrieval_rewrite.invoke({
        "query": input_str_rewrite,
        "question_type": nlu_info.question_type,
        "intention_type": nlu_info.intent_category,
        "keywords": index_keywords
    })

    nlu_info.retrieve_keywords = get_keywords_v3(nlu_info.rewrite)

    # Step 7: show the result and end this multi-turn exchange.
    _show_nlu(nlu_info)
    _reset_multi_turn_state()


# Assistant introduction
st.markdown("""
# 🤖 Booway 软件助手

欢迎使用 **booway 软件助手**,这是一个用于协助用户进行电力造价软件相关问题咨询的智能系统。

**目前可咨询软件为:**
- 西藏造价软件Z1
- 新型储能计价通C1
- 技改检修计价通T1
- 后缀名文件咨询

**使用方法:**
直接在下方输入你的问题,例如:
- 你好,想问下储能的C1那个软件。初设的基本预备费费率想调整一下,但是没有找到能调整的地方
- 如何把西藏老定额工程升级成西藏Z1的新定额工程
- 储能软件勾选了卸车,总价不变呢
- bjgx用什么软件打开的?
- 设备运杂费率怎么设置 (多轮测试)
- 你好,初设的基本预备费费率想调整一下,但是没有找到能调整的地方(多轮测试)

**注意:多轮对话**
- ~~目前多轮对话中 当机器人询问用户什么软件,则必须是以上软件名字~~

如果你输入的是闲聊内容,系统将提示仅内测用户可用。
""")


input_str = st.text_input("🦉 用户:", "")

if input_str:
    if judge_define_suffix(input_str):
        _handle_suffix_query(input_str)
    else:
        _handle_software_query(input_str)