Files
DM_rewrite_3.31/main.py
T
2025-04-03 17:23:53 +08:00

106 lines
4.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
===================================
@Author: WenZ
@Company: BooWay
@Project: booway_dm
===================================
"""
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, ChatRecord
from chains_ceshi import Vertical_classification, small_talk, intention_judge, domain_judge, judge_5W2H, extract_keywords, answer_questions, domain_judge, domain_judge_v2
from kg_management import retriever_txt_faiss1
from kg_management import retriever_txt_faiss2
from kg_management import retriever_txt_faiss3
from kg_management import retriever_txt_faiss4
from kg_management import retriever_txt_faiss5
from kg_management import retriever_txt_faiss6
from kg_management import retriever_txt_faiss7
from kg_management import retriever_txt_faiss8
from kg_management import retriever_txt_faiss9
from kg_management import input_index_csv_path
from kg_management import xizang_input_csv_path
from kg_management import cuceng_input_csv_path
from kg_management import jigai_input_csv_path
from kg_management import process_domain_category
# ---------------------------------------------------------------------------
# Chain objects used by the dialogue loop; each is built once at import time
# and later called through its ``.invoke(...)`` method.
# ---------------------------------------------------------------------------
chain_vertical = Vertical_classification()  # vertical (software consulting) vs. open chat
chain_intention = intention_judge()  # intent classification
chain_domain = domain_judge()  # software-domain classification
chain_5w2h = judge_5W2H()  # 5W2H question-type classification
chain_keywords = extract_keywords()  # key-semantic extraction
chains_qa = answer_questions()  # retrieval-grounded answering
chain_domain_v2 = domain_judge_v2()  # alternative domain classifier; TODO confirm it is still needed

# Retrieval resources per supported software product, keyed by the exact
# product name. Each value is a 4-tuple: three retrievers (presumably
# FAISS-backed, judging by their names) followed by that product's input
# CSV path. Consumers split it as value[:3] / value[3].
domain_mapping = {
    '西藏造价软件Z1': (
        retriever_txt_faiss1,
        retriever_txt_faiss2,
        retriever_txt_faiss3,
        xizang_input_csv_path,
    ),
    '新型储能计价通C1': (
        retriever_txt_faiss4,
        retriever_txt_faiss5,
        retriever_txt_faiss6,
        cuceng_input_csv_path,
    ),
    '技改检修计价通T1': (
        retriever_txt_faiss7,
        retriever_txt_faiss8,
        retriever_txt_faiss9,
        jigai_input_csv_path,
    ),
}
# Replies that end the session (compared after .lower(), so "BYE" also works).
_EXIT_WORDS = ("退出", "bye", "再见")


def _read_user_line(prompt):
    """Prompt the user and return the reply stripped of surrounding whitespace.

    Returns None when stdin is closed (EOF) or the user presses Ctrl-C, so the
    caller can end the session cleanly instead of crashing with a traceback.
    """
    try:
        return input(prompt).strip()
    except (EOFError, KeyboardInterrupt):
        return None


def _is_exit(text):
    """Return True if *text* is one of the session-ending words."""
    return text.lower() in _EXIT_WORDS


def _resolve_domain(nlu_info):
    """Make sure ``nlu_info.domain_category`` is a key of ``domain_mapping``.

    Multi-turn trigger: when the domain classifier returned '未知' (or any
    other name with no retrieval resources), keep asking the user which
    software the question is about until a known name is given.

    Returns:
        True once the domain is resolvable; False if the user ends the
        session (exit word, EOF or Ctrl-C) instead of answering.
    """
    if nlu_info.domain_category in domain_mapping:
        return True
    while True:
        print("booway软件助手: 请问是什么软件?")
        software_name = _read_user_line("用户: ")
        if software_name is None or _is_exit(software_name):
            return False
        if software_name in domain_mapping:
            nlu_info.domain_category = software_name
            return True
        # Unknown name: warn and ask again. The old code left the loop here
        # and then crashed with KeyError on the domain_mapping lookup.
        print("booway软件助手:请输入一个有效的软件名称")


def _answer_software_question(input_str):
    """Run the NLU -> retrieval -> QA pipeline for one software question.

    Prints either the retrieved knowledge and the generated answer, or a
    "nothing retrieved" notice.

    Returns:
        False if the user ended the session while being asked for the
        software name, True otherwise.
    """
    # 1. DST: initialise NLUInfo.
    nlu_info = NLUInfo(vertical_category="软件咨询")
    # 2. DST: update NLUInfo with intent and domain.
    nlu_info.intent_category = chain_intention.invoke(input_str)
    nlu_info.domain_category = chain_domain.invoke(input_str)
    # Multi-turn dialogue: triggered when the domain is missing or unknown.
    if not _resolve_domain(nlu_info):
        return False
    _5w2h_info = chain_5w2h.invoke(input_str)
    nlu_info.question_type = _5w2h_info
    # Pass the *resolved* software name. The raw classifier output may still
    # be '未知' if the user had to clarify it.
    nlu_info.retrieve_keywords = chain_keywords.invoke({"software_name": nlu_info.domain_category,
                                                        "_5w2h_type": _5w2h_info,
                                                        "query": input_str})
    # 3. Retrieval: each domain_mapping value is (3 retrievers..., csv_path).
    resources = domain_mapping[nlu_info.domain_category]
    retrievers, input_csv_path = resources[:3], resources[3]
    kg = process_domain_category(nlu_info, retrievers, input_csv_path, input_index_csv_path)
    if not kg:
        print("无相关检索知识,请重新提出问题")
        return True
    # item[5:] drops a 5-character prefix. Presumably that is "检索知识" plus
    # a one-character separator; TODO confirm against process_domain_category.
    extracted_index = [item[5:] for item in kg if item.startswith("检索知识")]
    print(f"系统:检索到的知识为{extracted_index},分析如下:\n")
    print(chains_qa.invoke({"query": input_str, "kg": '\n'.join(kg)}))
    return True


# Interactive loop: classify each user line and route it to the matching
# handler. NOTE(review): like the original, this runs at import time; consider
# an `if __name__ == "__main__":` guard if this module is ever imported.
while True:
    input_str = _read_user_line("用户:")
    if input_str is None or _is_exit(input_str):
        print("booway软件助手:再见")
        break
    if not input_str:
        # Blank line: nothing to classify, so prompt again.
        continue
    vertical_category = chain_vertical.invoke(input_str)
    if vertical_category == '闲聊':
        print("booway软件助手:闲聊服务只提供给内测用户")
    elif vertical_category == '软件咨询':
        if not _answer_software_question(input_str):
            print("booway软件助手:再见")
            break
    else:
        print("booway软件助手:抱歉,我不明白你的问题,请尝试重新表述")