多轮demo
This commit is contained in:
@@ -0,0 +1,268 @@
|
||||
"""
|
||||
===================================
|
||||
@Author:WenZ
|
||||
@Company: BooWay
|
||||
@project:booway_dm
|
||||
===================================
|
||||
"""
|
||||
import streamlit as st
|
||||
from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
|
||||
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \
|
||||
ChatRecord
|
||||
from utils import stop_word_processing
|
||||
import spacy
|
||||
import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf
|
||||
from utils import get_keywords, get_keywords_v2, get_keywords_v3
|
||||
from vector_load import interface_search
|
||||
|
||||
# Only the transformer-based Chinese spaCy pipeline is loaded. The sm / md / lg
# packages are imported above, but their loaders are intentionally left disabled.
nlp_trf = zh_core_web_trf.load()

# Courtesy words; passed to stop_word_processing together with the pipeline.
polite_words = {
    "你好",
    "您好",
    "请",
    "请问",
    "谢谢",
    "不客气",
    "麻烦",
    "打扰",
    "拜托",
    "辛苦",
    "劳驾",
}
|
||||
|
||||
from chains_ceshi import suffix_answers
|
||||
from chains_ceshi import Vertical_classification
|
||||
from chains_ceshi import intention_judge
|
||||
from chains_ceshi import domain_judge
|
||||
from chains_ceshi import judge_5W2H
|
||||
|
||||
from chains_rewrite import software_name_rewrite
|
||||
from chains_rewrite import query_function_rewrite
|
||||
from chains_rewrite import operation_guidance_rewrite
|
||||
from chains_rewrite import troubleshooting_rewrite
|
||||
from chains_rewrite import access_rewrite
|
||||
|
||||
from kg_management import retriever_txt_faiss1
|
||||
from kg_management import retriever_txt_faiss2
|
||||
from kg_management import retriever_txt_faiss3
|
||||
from kg_management import retriever_txt_faiss4
|
||||
from kg_management import retriever_txt_faiss5
|
||||
from kg_management import retriever_txt_faiss6
|
||||
from kg_management import retriever_txt_faiss7
|
||||
from kg_management import retriever_txt_faiss8
|
||||
from kg_management import retriever_txt_faiss9
|
||||
|
||||
from kg_management import input_index_csv_path
|
||||
from kg_management import xizang_input_csv_path
|
||||
from kg_management import cuceng_input_csv_path
|
||||
from kg_management import jigai_input_csv_path
|
||||
|
||||
from kg_management import process_domain_category
|
||||
|
||||
# Per-product retrieval resources, keyed by the canonical software name.
# Each value is a 4-tuple: three FAISS text retrievers followed by the path of
# that product's input CSV.
domain_mapping = {
    "西藏造价软件Z1": (
        retriever_txt_faiss1,
        retriever_txt_faiss2,
        retriever_txt_faiss3,
        xizang_input_csv_path,
    ),
    "新型储能计价通C1": (
        retriever_txt_faiss4,
        retriever_txt_faiss5,
        retriever_txt_faiss6,
        cuceng_input_csv_path,
    ),
    "技改检修计价通T1": (
        retriever_txt_faiss7,
        retriever_txt_faiss8,
        retriever_txt_faiss9,
        jigai_input_csv_path,
    ),
}
|
||||
|
||||
# --- Classification / answering chains -------------------------------------
# Answers file-suffix (extension) questions.
chain_suffix_answers = suffix_answers()
# Vertical vs. open-domain classification.
chain_vertical = Vertical_classification()
# Intent classification.
chain_intention = intention_judge()
# Domain (software product) classification.
chain_domain = domain_judge()
# 5W2H question-type classification.
chain_5W2H = judge_5W2H()

# --- Query-rewrite chains ---------------------------------------------------
# Rewrite: software name.
chains_name_rewrite = software_name_rewrite()
# Rewrite: software feature lookup.
chain_function_rewrite = query_function_rewrite()
# Rewrite: software operation guidance.
chain_guidance_rewrite = operation_guidance_rewrite()
# Rewrite: software troubleshooting.
chain_troubleshooting_rewrite = troubleshooting_rewrite()
# Rewrite: software download and installation.
chain_access_rewrite = access_rewrite()
|
||||
|
||||
from chains_rewrite import to_normal_rewrite
|
||||
from chains_rewrite import retrieval_rewrite
|
||||
|
||||
# Rewrite chain applied to the name-rewritten query before keyword extraction.
chain_normal_rewrite = to_normal_rewrite()
# Rewrite chain that produces the final retrieval-oriented query.
chain_retrieval_rewrite = retrieval_rewrite()
|
||||
|
||||
# 同义词替换
|
||||
from utils import normalize_text
|
||||
|
||||
# Synonym table passed to normalize_text: each key is a canonical term and the
# list holds its aliases (upper-case variants first, then lower-case variants).
# The three product keys match the keys of domain_mapping.
# NOTE(review): alias order is kept exactly as before in case normalize_text is
# order-sensitive; only the repeated entries were removed.
synonym_dict = {
    "西藏造价软件Z1": [
        "西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1", "西藏",
        "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1",
    ],
    "新型储能计价通C1": [
        "储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1", "储能",
        "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "c1",
    ],
    "技改检修计价通T1": [
        "技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1", "技改",
        "技改检修t1", "t1检修计价", "技改计价通t1", "t1技改检修", "t1",
    ],
    "费率": ["费费率"],
    "下载": ["获取", "安装", "下载下来", "装上"],
}
|
||||
|
||||
import random
|
||||
|
||||
# Follow-up questions shown when the software product cannot be determined from
# the user's message; one is picked at random per follow-up turn.
reponse_prompts = [
    "🤔 Booway软件助手:请问您指的是哪个软件?",
    "🤔 Booway软件助手:请提供软件名称,以便更好地帮助您。",
    "🤔 Booway软件助手:请问您使用的是什么软件?",
    "🤔 Booway软件助手:请告诉我您要查询的软件名称。",
    "🤔 Booway软件助手:请问是哪款软件?",
]
|
||||
|
||||
import streamlit as st
|
||||
import random
|
||||
|
||||
# Page-level configuration must be issued before any other Streamlit output.
st.set_page_config(page_title="Booway 助手", layout="wide")

# Assistant introduction: supported products, sample questions and a note on
# the multi-turn behaviour, rendered as Markdown at the top of the page.
_INTRO_MARKDOWN = """
# 🤖 Booway 软件助手

欢迎使用 **booway 软件助手**,这是一个用于协助用户进行电力造价软件相关问题咨询的智能系统。

**目前可咨询软件为:**
- 西藏造价软件Z1
- 新型储能计价通C1
- 技改检修计价通T1
- 后缀名文件咨询

**使用方法:**
直接在下方输入你的问题,例如:
- 你好,想问下储能的C1那个软件。初设的基本预备费费率想调整一下,但是没有找到能调整的地方
- 如何把西藏老定额工程升级成西藏Z1的新定额工程
- 储能软件勾选了卸车,总价不变呢
- bjgx用什么软件打开的?
- 设备运杂费率怎么设置 (多轮测试)
- 你好,初设的基本预备费费率想调整一下,但是没有找到能调整的地方(多轮测试)

**注意:多轮对话**
- ~~目前多轮对话中 当机器人询问用户什么软件,则必须是以上软件名字~~

如果你输入的是闲聊内容,系统将提示仅内测用户可用。
"""
st.markdown(_INTRO_MARKDOWN)
|
||||
|
||||
|
||||
# Session-state keys that carry the follow-up ("which software?") turn.
_MT_STATE_KEYS = ("mt_input_done", "mt_input_value", "mt_prompt", "mt_input")


def _reset_multi_turn_state() -> None:
    """Remove every follow-up-turn key from ``st.session_state`` if present."""
    for key in _MT_STATE_KEYS:
        if key in st.session_state:
            del st.session_state[key]


def _show_nlu_info(nlu_info, extra=None) -> None:
    """Render the recognised NLU fields as JSON under a sub-header.

    Args:
        nlu_info: The ``NLUInfo`` object whose fields are displayed.
        extra: Optional mapping of additional entries appended after the
            standard fields (used for the generated answer).
    """
    payload = {
        "垂直分类": nlu_info.vertical_category,
        "意图分类": nlu_info.intent_category,
        "领域分类": nlu_info.domain_category,
        "问题分类": nlu_info.question_type,
        "检索语义": nlu_info.retrieve_keywords,
        "改写结果": nlu_info.rewrite,
    }
    if extra:
        payload.update(extra)
    st.subheader("识别出的NLU信息")
    st.json(payload)


def _answer_suffix_question(question: str) -> None:
    """Handle a question about a file suffix (extension).

    Looks up the software related to the matched suffix and, when knowledge is
    found, asks ``chain_suffix_answers`` for an answer and displays it together
    with the NLU fields.
    """
    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.intent_category = "查询功能"
    nlu_info.domain_category = "后缀名查询"

    suffix_name = match_suffix(question)
    nlu_info.retrieve_keywords = suffix_name
    knowledge = retrieve_relevant_software(suffix_name)

    # An int result is treated as "nothing found".
    if isinstance(knowledge, int):
        st.info("Booway 助手:未查到相关知识")
        return

    # A list of knowledge entries is flattened to one newline-separated string;
    # a plain string is used as is. Any other type is ignored, as before.
    if isinstance(knowledge, list):
        knowledge = "\n".join(knowledge)
    elif not isinstance(knowledge, str):
        return

    nlu_info.rewrite = f"{suffix_name}是什么文件?用什么软件打开?"
    answer = chain_suffix_answers.invoke({"query": question, "kg": knowledge})
    _show_nlu_info(nlu_info, {"检索回答": answer})


def _resolve_domain(question: str):
    """Classify which software product ``question`` is about.

    When the classifier answers "未知", a follow-up text input asks the user
    for the software name. The script is stopped (``st.stop``) until a reply
    is given and then rerun (``st.rerun``); after the rerun the stored reply is
    classified instead.

    Returns:
        The label produced by ``chain_domain`` (it may still be "未知").
    """
    cleaned = stop_word_processing(question, nlp_trf, polite_words)
    domain_category = chain_domain.invoke(normalize_text(cleaned, synonym_dict))
    if domain_category != "未知":
        return domain_category

    # Initialise the follow-up state once; the prompt is kept stable across
    # reruns so the label of the input widget does not change under the user.
    if "mt_input_done" not in st.session_state:
        st.session_state.mt_input_done = False
    if "mt_input_value" not in st.session_state:
        st.session_state.mt_input_value = ""
    if "mt_prompt" not in st.session_state:
        st.session_state.mt_prompt = random.choice(reponse_prompts)

    if not st.session_state.mt_input_done:
        reply = st.text_input(f"{st.session_state.mt_prompt}", key="mt_input")
        if not reply:
            st.stop()  # wait for the user's reply
        st.session_state.mt_input_value = reply
        st.session_state.mt_input_done = True
        st.rerun()

    follow_up = normalize_text(st.session_state.mt_input_value, synonym_dict)
    return chain_domain.invoke(follow_up)


def _answer_software_question(question: str) -> None:
    """Run the multi-turn NLU pipeline for a software question and display it."""
    domain_category = _resolve_domain(question)

    # The classifier may still return "未知" (or another label without
    # retrieval resources) after the follow-up turn. Indexing domain_mapping
    # with it would raise KeyError, so stop with a message instead.
    if domain_category not in domain_mapping:
        _reset_multi_turn_state()
        st.warning("Booway 助手:暂未识别到您咨询的软件,请在问题中注明软件名称后重新提问。")
        st.stop()

    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.domain_category = domain_category

    # Rewrite the question (software name first, then the normal rewrite) and
    # derive question type / intent.
    named_query = chains_name_rewrite.invoke({
        "query": question,
        "software_name": nlu_info.domain_category,
    })
    rewritten_query = chain_normal_rewrite.invoke(named_query)

    seed_keywords = get_keywords_v2(rewritten_query)
    nlu_info.question_type = chain_5W2H.invoke(rewritten_query)
    nlu_info.intent_category = chain_intention.invoke(question)

    # The first three entries of the mapping tuple are the retrievers; the
    # fourth (the CSV path) is not needed here.
    retrievers = domain_mapping[nlu_info.domain_category][:3]
    index_keywords = interface_search(seed_keywords, *retrievers)
    st.info(f"提取关键词:{index_keywords}")

    nlu_info.rewrite = chain_retrieval_rewrite.invoke({
        "query": rewritten_query,
        "question_type": nlu_info.question_type,
        "intention_type": nlu_info.intent_category,
        "keywords": index_keywords,
    })
    nlu_info.retrieve_keywords = get_keywords_v3(nlu_info.rewrite)

    _show_nlu_info(nlu_info)


input_str = st.text_input("🦉 用户:", "")

if input_str:
    if judge_define_suffix(input_str):
        _answer_suffix_question(input_str)
    else:
        _answer_software_question(input_str)

    # The question has been fully handled: clear the follow-up state so the
    # next question starts a fresh dialogue.
    _reset_multi_turn_state()
|
||||
|
||||
|
||||
|
||||
|
||||
# streamlit run streamlit_main.py --server.port 2335
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user