From c152fb8714b04ab294550e01870a0c0ab89a3cb6 Mon Sep 17 00:00:00 2001 From: zoujiwen Date: Thu, 3 Apr 2025 17:23:53 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 上传文件 --- fast_api_main_4.2.py | 422 ++++++++++++++++++++++++++++++++++++++++ fast_api_main_4.3.py | 428 +++++++++++++++++++++++++++++++++++++++++ kg_management.py | 106 +++++----- main.py | 210 ++++++++++---------- streamlit_main.py | 446 +++++++++++++++++-------------------------- 5 files changed, 1188 insertions(+), 424 deletions(-) create mode 100644 fast_api_main_4.2.py create mode 100644 fast_api_main_4.3.py diff --git a/fast_api_main_4.2.py b/fast_api_main_4.2.py new file mode 100644 index 0000000..b43869b --- /dev/null +++ b/fast_api_main_4.2.py @@ -0,0 +1,422 @@ +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" + +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" +from utils import judge_define_suffix, match_suffix, retrieve_relevant_software +from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \ + ChatRecord +from utils import stop_word_processing +import spacy +import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf +from utils import get_keywords, get_keywords_v2, get_keywords_v3, get_keywords_v4 +from vector_load import interface_search + +# nlp_sm = zh_core_web_sm.load() +# nlp_md = zh_core_web_md.load() +# nlp_lg = zh_core_web_lg.load() +nlp_trf = zh_core_web_trf.load() + +polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"} + +from chains_ceshi import suffix_answers +from chains_ceshi import Vertical_classification +from chains_ceshi import intention_judge +from chains_ceshi import domain_judge +from chains_ceshi import judge_5W2H + +from chains_rewrite import software_name_rewrite +from chains_rewrite import query_function_rewrite +from chains_rewrite import operation_guidance_rewrite +from chains_rewrite import troubleshooting_rewrite +from chains_rewrite import access_rewrite + +from kg_management import retriever_txt_faiss1 +from kg_management import retriever_txt_faiss2 +from kg_management import retriever_txt_faiss3 +from kg_management import retriever_txt_faiss4 +from kg_management import retriever_txt_faiss5 +from kg_management import retriever_txt_faiss6 +from kg_management import retriever_txt_faiss7 +from kg_management import retriever_txt_faiss8 +from kg_management import retriever_txt_faiss9 + +from kg_management import input_index_csv_path +from kg_management import xizang_input_csv_path +from kg_management import cuceng_input_csv_path +from kg_management import jigai_input_csv_path +from kg_management import peiwang_input_path + +from kg_management import process_domain_category + +domain_mapping = { + '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path), + '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path), + '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path), + '博微配网工程计价通D3软件': (retriever_txt_faiss10, retriever_txt_faiss11, retriever_txt_faiss12, peiwang_input_path) +} + +chain_suffix_answers = suffix_answers() # 后缀名问题处理 +chain_vertical = Vertical_classification() # 垂直/开放分类 +chain_intention = intention_judge() # 意图分类 +chain_domain = domain_judge() # 领域分类 +chain_5W2H = judge_5W2H() # 5W2H分类 + +chains_name_rewrite = software_name_rewrite() # 问题改写:软件名改写 +chain_function_rewrite = query_function_rewrite() # 问题改写:软件功能查询 +chain_guidance_rewrite = operation_guidance_rewrite() # 问题改写:软件操作指导 +chain_troubleshooting_rewrite = troubleshooting_rewrite() # 问题改写:软件故障排查类 +chain_access_rewrite = access_rewrite() # 问题改写:软件下载与安装 + +from chains_rewrite import to_normal_rewrite +from chains_rewrite import retrieval_rewrite + +chain_normal_rewrite = to_normal_rewrite() +chain_retrieval_rewrite = retrieval_rewrite() + +from chains_rewrite import full_name_extension + +chain_full_name_extension = full_name_extension() + +from chains_ceshi import domain_judge_v3 + +chain_domain_v3 = domain_judge_v3() + +# 同义词替换 +from utils import normalize_text + +synonym_dict = { + "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1", + "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1",], + "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1", + "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"], + "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1", + "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"], + "博微配网工程计价通D3软件": ["博微配网D3", "D3配网计价", "配网计价D3", "配网D3", "D3软件", "博微D3", "D3", + "博微配网工程计价通", "博微配网工程计价通D3", + "博微配网d3", "d3配网计价", "配网计价d3", "配网d3", "d3软件", "博微d3", "d3"] +} + +synonym_dict_2 = { + "费率": ["费费率"], + "下载": ["获取", "安装", "下载下来", "装上"] +} + +from typing import Optional + +dst_data = { + "query": "", # 用户输入 + "vertical_category": "", # 垂直分类 + + # 意图分类槽位结构 + "intent_category_slot": { + "intent_category": "", # 意图分类 + "software_name_required": "", # 软件名称 + "demand_name_required": "", # 需求内容 + "temp_optional": "" # 占位 + }, + + "domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "", # 检索语义 + "dialogue_status": None, # 对话状态(单轮/多轮 1/2) + + "response": "", # 回复 + # 多轮对话状态 + "multi_dialogue_state": { + "status": None, # 0: rewrite 1: response + "multi_intent_category_slot": { + "multi_intent_category": "", # 意图分类 + "multi_software_name_required": "", # 软件名称 + "multi_demand_name_required": "", # 需求内容 + "multi_temp_optional": "" # 占位 + }, + "multi_domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "" # 检索语义 + }, + + "rewrite_result": "" # 改写结果 +} + +def get_software_name(input_json): + try: + # 尝试访问 "soft_info" 下的 "软件名称" + software_name = input_json["content"]["soft_info"]["software_name"] + return software_name + except KeyError: + # 如果 "soft_info" 或 "软件名称" 不存在,返回 None + return None + +def chat_record_conversion_demand(data): + + A, B, C = _chat_slice(data) + C = normalize_text(C, synonym_dict) + C = chain_software_judge.invoke(C) + + def _chat_aggregation(A, B, C): + full_dialog = f"user: {A}\nassistant: {B}\nuser: {C}" + return full_dialog + + full_dialog = _chat_aggregation(A, B, C) + + return chain_mutil_text_rewrite.invoke(full_dialog) + +def _chat_slice(data): + history_chat = data["content"].get("history_chat") + input_str = data["input_str"] + + A = history_chat[0].get("user", "") + B = history_chat[1].get("assistant", "") + C = input_str + return A, B, C + +def chat_record_conversion_software(data): + + A, B, C = _chat_slice(data) + A = normalize_text(A, synonym_dict) + A = chain_software_judge.invoke(A) + + def _chat_aggregation(A, B, C): + full_dialog = f"user: {A}\nassistant: {B}\nuser: {C}" + return full_dialog + + full_dialog = _chat_aggregation(A, B, C) + + return chain_mutil_text_rewrite.invoke(full_dialog) + +from chains_rewrite import software_judge +chain_software_judge = software_judge() +from chains_rewrite import mutil_text_rewrite +chain_mutil_text_rewrite = mutil_text_rewrite() + +from chains_ceshi import mutil12 + +chains_mutil12 = mutil12() + + +from fastapi import FastAPI, Request +from pydantic import BaseModel +from typing import List, Union +import uvicorn + +app = FastAPI() + +from typing import Dict, List, Optional, Any + + +@app.post("/analyze") +async def analyze_input(data: Dict[str, Any]): + # 检查 history_chat 是否为空 + if not data["content"].get("history_chat"): # 使用 get 方法来避免 KeyError + dst_data = { + "query": "", # 用户输入 + "vertical_category": "", # 垂直分类 + + # 意图分类槽位结构 + "intent_category_slot": { + "intent_category": "", # 意图分类 + "software_name_required": "", # 软件名称 + "demand_name_required": "", # 需求内容 + "temp_optional": "" # 占位 + }, + + "domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "", # 检索语义 + "dialogue_status": None, # 对话状态(单轮/多轮 1/2) + + "response": "", # 回复 + # 多轮对话状态 + "multi_dialogue_state": { + "status": None, # 0: rewrite 1: response + "multi_intent_category_slot": { + "multi_intent_category": "", # 意图分类 + "multi_software_name_required": "", # 软件名称 + "multi_demand_name_required": "", # 需求内容 + "multi_temp_optional": "" # 占位 + }, + "multi_domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "" # 检索语义 + }, + + "rewrite_result": "" # 改写结果 + } + + input_str = data["input_str"] + software_name = get_software_name(data) + + if judge_define_suffix(input_str) == True: + dst_data["query"] = input_str + dst_data["dialogue_status"] = 1 + dst_data["multi_dialogue_state"]["status"] = 1 + dst_data["vertical_category"] = "软件查询" + dst_data["intent_category_slot"]["intent_category"] = "查询功能" + dst_data["domain_category"] = "后缀名查询" + suffix_name = match_suffix(input_str) + dst_data["retrieve_keywords"] = suffix_name + dst_data["question_type"] = "" + suffix_to_software = retrieve_relevant_software(dst_data["retrieve_keywords"]) # 检索出后缀名指定的软件 + if isinstance(suffix_to_software, int): + dst_data["rewrite_result"] = "未查到相关知识" + elif isinstance(suffix_to_software, str): + dst_data["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?" + result = chain_suffix_answers.invoke({"query": input_str, "kg": suffix_to_software}) + dst_data["response"] = result + elif isinstance(suffix_to_software, list): + suffix_to_software_str = '\n'.join(suffix_to_software) + dst_data["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?" + result = chain_suffix_answers.invoke({"query": input_str, "kg": suffix_to_software_str}) + dst_data["response"] = result + else: + # 输入前预处理 + input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words) # 停用词处理 + input_str_syn = normalize_text(input_str_stoped, synonym_dict_2) # 同义词替换 + dst_data["query"] = input_str_syn + + # 垂直分类 + vertical_category = chain_vertical.invoke(input_str_syn) + print(vertical_category) + if vertical_category == '闲聊': + dst_data["response"] = "闲聊服务只提供给内测用户" + else: + dst_data["vertical_category"] = vertical_category + dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(input_str_syn) + dst_data["question_type"] = chain_5W2H.invoke(input_str_syn) + if software_name: + dst_data["dialogue_status"] = 1 + dst_data["domain_category"] = software_name + dst_data["intent_category_slot"]["software_name_required"] = software_name + # 内容扩写(软件全名) + input_str_name_rewrite = chain_full_name_extension.invoke( + {"query": input_str_syn, "soft_name": software_name}) + # print(input_str_name_rewrite) + _, demand = get_keywords_v4(input_str_name_rewrite) + dst_data["intent_category_slot"]["demand_name_required"] = demand + # print(demand) + retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3], + domain_mapping[dst_data["domain_category"]][3]) + index_keywords = interface_search(demand, *retrievers) + dst_data["retrieve_keywords"] = index_keywords + # print(index_keywords) + result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite, + "question_type": dst_data["question_type"], + "intention_type": dst_data["intent_category_slot"][ + "intent_category"], + "keywords": index_keywords}) + dst_data["rewrite_result"] = result_rewrite + else: + dst_data["domain_category"] = chain_domain_v3.invoke(input_str) + if dst_data["domain_category"] != '其他': + dst_data["dialogue_status"] = 1 + dst_data["intent_category_slot"]["software_name_required"] = software_name + # 内容扩写(软件全名) + input_str_name_rewrite = chain_full_name_extension.invoke( + {"query": input_str_stoped, "soft_name": software_name}) + print(input_str_name_rewrite) + _, demand = get_keywords_v4(input_str_name_rewrite) + if demand: + dst_data["intent_category_slot"]["demand_name_required"] = demand + # print(demand) + retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3], + domain_mapping[dst_data["domain_category"]][3]) + index_keywords = interface_search(demand, *retrievers) + dst_data["retrieve_keywords"] = index_keywords + # print(index_keywords) + result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite, + "question_type": dst_data["question_type"], + "intention_type": dst_data["intent_category_slot"][ + "intent_category"], + "keywords": index_keywords}) + dst_data["rewrite_result"] = result_rewrite + else: + dst_data["dialogue_status"] = 2 + dst_data["response"] = f"好的,具体软件是【{software_name}】, 请补充具体的需求" + else: + dst_data["dialogue_status"] = 2 + dst_data["response"] = f"好的,具体需求是【{input_str_syn}】, 但当前缺少具体软件名字,请补充" + return dst_data + else: + # 如果 history_chat 有数据,进行其他处理 + _, B, _ = _chat_slice(data) + if chains_mutil12.invoke(B) == '1': + text = chat_record_conversion_demand(data) + else: + text = chat_record_conversion_software(data) + + st_data = { + "query": "", # 用户输入 + "vertical_category": "", # 垂直分类 + + # 意图分类槽位结构 + "intent_category_slot": { + "intent_category": "", # 意图分类 + "software_name_required": "", # 软件名称 + "demand_name_required": "", # 需求内容 + "temp_optional": "" # 占位 + }, + + "domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "", # 检索语义 + "dialogue_status": None, # 对话状态(单轮/多轮 1/2) + + "response": "", # 回复 + # 多轮对话状态 + "multi_dialogue_state": { + "status": None, # 0: rewrite 1: response + "multi_intent_category_slot": { + "multi_intent_category": "", # 意图分类 + "multi_software_name_required": "", # 软件名称 + "multi_demand_name_required": "", # 需求内容 + "multi_temp_optional": "" # 占位 + }, + "multi_domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "" # 检索语义 + }, + + "rewrite_result": "" # 改写结果 + } + + dst_data["query"] = text + dst_data["dialogue_status"] = 1 + vertical_category = chain_vertical.invoke(dst_data["query"]) + if vertical_category == '闲聊': + dst_data["response"] = "闲聊服务只提供给内测用户" + else: + vertical_category = chain_vertical.invoke(dst_data["query"]) + dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(dst_data["query"]) + dst_data["question_type"] = chain_5W2H.invoke(dst_data["query"]) + dst_data["domain_category"] = software_name + dst_data["intent_category_slot"]["software_name_required"] = software_name + input_str_name_rewrite = chain_full_name_extension.invoke({"query": input_str_syn, "soft_name": software_name}) + _, demand = get_keywords_v4(input_str_name_rewrite) + dst_data["intent_category_slot"]["demand_name_required"] = demand + retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3], + domain_mapping[dst_data["domain_category"]][3]) + index_keywords = interface_search(demand, *retrievers) + dst_data["retrieve_keywords"] = index_keywords + result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite, + "question_type": dst_data["question_type"], + "intention_type": dst_data["intent_category_slot"][ + "intent_category"], + "keywords": index_keywords}) + dst_data["rewrite_result"] = result_rewrite + return dst_data + +# 可选:本地调试入口 +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=3333) + diff --git a/fast_api_main_4.3.py b/fast_api_main_4.3.py new file mode 100644 index 0000000..4e155e2 --- /dev/null +++ b/fast_api_main_4.3.py @@ -0,0 +1,428 @@ +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" + +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" +from utils import judge_define_suffix, match_suffix, retrieve_relevant_software +from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \ + ChatRecord +from utils import stop_word_processing +import spacy +import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf +from utils import get_keywords, get_keywords_v2, get_keywords_v3, get_keywords_v4 +from vector_load import interface_search + +# nlp_sm = zh_core_web_sm.load() +# nlp_md = zh_core_web_md.load() +# nlp_lg = zh_core_web_lg.load() +nlp_trf = zh_core_web_trf.load() + +polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"} + +from chains_ceshi import suffix_answers +from chains_ceshi import Vertical_classification +from chains_ceshi import intention_judge +from chains_ceshi import domain_judge +from chains_ceshi import judge_5W2H + +from chains_rewrite import software_name_rewrite +from chains_rewrite import query_function_rewrite +from chains_rewrite import operation_guidance_rewrite +from chains_rewrite import troubleshooting_rewrite +from chains_rewrite import access_rewrite + +from kg_management import retriever_txt_faiss1 +from kg_management import retriever_txt_faiss2 +from kg_management import retriever_txt_faiss3 +from kg_management import retriever_txt_faiss4 +from kg_management import retriever_txt_faiss5 +from kg_management import retriever_txt_faiss6 +from kg_management import retriever_txt_faiss7 +from kg_management import retriever_txt_faiss8 +from kg_management import retriever_txt_faiss9 +from kg_management import retriever_txt_faiss10 +from kg_management import retriever_txt_faiss11 +from kg_management import retriever_txt_faiss12 + +from kg_management import input_index_csv_path +from kg_management import xizang_input_csv_path +from kg_management import cuceng_input_csv_path +from kg_management import jigai_input_csv_path +from kg_management import peiwang_input_path + +from kg_management import process_domain_category + +domain_mapping = { + '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path), + '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path), + '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path), + '博微配网工程计价通D3软件': (retriever_txt_faiss10, retriever_txt_faiss11, retriever_txt_faiss12, peiwang_input_path) +} + +chain_suffix_answers = suffix_answers() # 后缀名问题处理 +chain_vertical = Vertical_classification() # 垂直/开放分类 +chain_intention = intention_judge() # 意图分类 +chain_domain = domain_judge() # 领域分类 +chain_5W2H = judge_5W2H() # 5W2H分类 + +chains_name_rewrite = software_name_rewrite() # 问题改写:软件名改写 +chain_function_rewrite = query_function_rewrite() # 问题改写:软件功能查询 +chain_guidance_rewrite = operation_guidance_rewrite() # 问题改写:软件操作指导 +chain_troubleshooting_rewrite = troubleshooting_rewrite() # 问题改写:软件故障排查类 +chain_access_rewrite = access_rewrite() # 问题改写:软件下载与安装 + +from chains_rewrite import to_normal_rewrite +from chains_rewrite import retrieval_rewrite + +chain_normal_rewrite = to_normal_rewrite() +chain_retrieval_rewrite = retrieval_rewrite() + +from chains_rewrite import full_name_extension + +chain_full_name_extension = full_name_extension() + +from chains_ceshi import domain_judge_v3 + +chain_domain_v3 = domain_judge_v3() + +# 同义词替换 +from utils import normalize_text + +synonym_dict = { + "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1", + "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1",], + "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1", + "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"], + "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1", + "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"], + "博微配网工程计价通D3软件": ["博微配网D3", "D3配网计价", "配网计价D3", "配网D3", "D3软件", "博微D3", "D3", + "博微配网工程计价通", "博微配网工程计价通D3", + "博微配网d3", "d3配网计价", "配网计价d3", "配网d3", "d3软件", "博微d3", "d3"] +} + +synonym_dict_2 = { + "费率": ["费费率"], + "下载": ["获取", "安装", "下载下来", "装上"] +} + +from typing import Optional + +dst_data = { + "query": "", # 用户输入 + "vertical_category": "", # 垂直分类 + + # 意图分类槽位结构 + "intent_category_slot": { + "intent_category": "", # 意图分类 + "software_name_required": "", # 软件名称 + "demand_name_required": "", # 需求内容 + "temp_optional": "" # 占位 + }, + + "domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "", # 检索语义 + "dialogue_status": None, # 对话状态(单轮/多轮 1/2) + + "response": "", # 回复 + # 多轮对话状态 + "multi_dialogue_state": { + "status": None, # 0: rewrite 1: response + "multi_intent_category_slot": { + "multi_intent_category": "", # 意图分类 + "multi_software_name_required": "", # 软件名称 + "multi_demand_name_required": "", # 需求内容 + "multi_temp_optional": "" # 占位 + }, + "multi_domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "" # 检索语义 + }, + + "rewrite_result": "" # 改写结果 +} + +def get_software_name(input_json): + try: + # 尝试访问 "soft_info" 下的 "软件名称" + software_name = input_json["content"]["soft_info"]["software_name"] + return software_name + except KeyError: + # 如果 "soft_info" 或 "软件名称" 不存在,返回 None + return None + +def chat_record_conversion_demand(data): + + A, B, C = _chat_slice(data) + C = normalize_text(C, synonym_dict) + C = chain_software_judge.invoke(C) + + def _chat_aggregation(A, B, C): + full_dialog = f"user: {A}\nassistant: {B}\nuser: {C}" + return full_dialog + + full_dialog = _chat_aggregation(A, B, C) + + return chain_mutil_text_rewrite.invoke(full_dialog) + +def _chat_slice(data): + history_chat = data["content"].get("history_chat") + input_str = data["input_str"] + + A = history_chat[0].get("user", "") + B = history_chat[1].get("assistant", "") + C = input_str + return A, B, C + +def chat_record_conversion_software(data): + + A, B, C = _chat_slice(data) + A = normalize_text(A, synonym_dict) + A = chain_software_judge.invoke(A) + + def _chat_aggregation(A, B, C): + full_dialog = f"user: {A}\nassistant: {B}\nuser: {C}" + return full_dialog + + full_dialog = _chat_aggregation(A, B, C) + + return chain_mutil_text_rewrite.invoke(full_dialog) + +from chains_rewrite import software_judge +chain_software_judge = software_judge() +from chains_rewrite import mutil_text_rewrite +chain_mutil_text_rewrite = mutil_text_rewrite() + +from chains_ceshi import mutil12 + +chains_mutil12 = mutil12() + + +from fastapi import FastAPI, Request +from pydantic import BaseModel +from typing import List, Union +import uvicorn + +app = FastAPI() + +from typing import Dict, List, Optional, Any + + +@app.post("/analyze") +async def analyze_input(data: Dict[str, Any]): + # 检查 history_chat 是否为空 + # if not data["content"].get("history_chat"): # 使用 get 方法来避免 KeyError + dst_data = { + "query": "", # 用户输入 + "vertical_category": "", # 垂直分类 + + # 意图分类槽位结构 + "intent_category_slot": { + "intent_category": "", # 意图分类 + "software_name_required": "", # 软件名称 + "demand_name_required": "", # 需求内容 + "temp_optional": "" # 占位 + }, + + "domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "", # 检索语义 + "dialogue_status": None, # 对话状态(单轮/多轮 1/2) + + "response": "", # 回复 + # 多轮对话状态 + "multi_dialogue_state": { + "status": None, # 0: rewrite 1: response + "multi_intent_category_slot": { + "multi_intent_category": "", # 意图分类 + "multi_software_name_required": "", # 软件名称 + "multi_demand_name_required": "", # 需求内容 + "multi_temp_optional": "" # 占位 + }, + "multi_domain_category": "", # 领域分类 + "question_type": "", # 5W2H分类 + "retrieve_keywords": "" # 检索语义 + }, + + "rewrite_result": "" # 改写结果 + } + + input_str = data["input_str"] + software_name = get_software_name(data) + + if judge_define_suffix(input_str) == True: + dst_data["query"] = input_str + dst_data["dialogue_status"] = 1 + dst_data["multi_dialogue_state"]["status"] = 1 + dst_data["vertical_category"] = "软件查询" + dst_data["intent_category_slot"]["intent_category"] = "查询功能" + dst_data["domain_category"] = "后缀名查询" + suffix_name = match_suffix(input_str) + dst_data["retrieve_keywords"] = suffix_name + dst_data["question_type"] = "" + suffix_to_software = retrieve_relevant_software(dst_data["retrieve_keywords"]) # 检索出后缀名指定的软件 + if isinstance(suffix_to_software, int): + dst_data["rewrite_result"] = "未查到相关知识" + elif isinstance(suffix_to_software, str): + dst_data["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?" + result = chain_suffix_answers.invoke({"query": input_str, "kg": suffix_to_software}) + dst_data["response"] = result + elif isinstance(suffix_to_software, list): + suffix_to_software_str = '\n'.join(suffix_to_software) + dst_data["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?" + result = chain_suffix_answers.invoke({"query": input_str, "kg": suffix_to_software_str}) + dst_data["response"] = result + else: + # 输入前预处理 + input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words) # 停用词处理 + input_str_syn = normalize_text(input_str_stoped, synonym_dict_2) # 同义词替换 + dst_data["query"] = input_str_syn + + # 垂直分类 + input_str_name_rewrite = chain_full_name_extension.invoke( + {"query": input_str_syn, "soft_name": software_name}) + vertical_category = chain_vertical.invoke(input_str_name_rewrite) + # print(vertical_category) + if vertical_category == '闲聊': + dst_data["response"] = "闲聊服务只提供给内测用户" + else: + dst_data["vertical_category"] = vertical_category + dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(input_str_syn) + dst_data["question_type"] = chain_5W2H.invoke(input_str_syn) + if software_name: + dst_data["dialogue_status"] = 1 + dst_data["domain_category"] = software_name + dst_data["intent_category_slot"]["software_name_required"] = software_name + # 内容扩写(软件全名) + # input_str_name_rewrite = chain_full_name_extension.invoke( + # {"query": input_str_syn, "soft_name": software_name}) + _, demand = get_keywords_v4(input_str_name_rewrite) + dst_data["intent_category_slot"]["demand_name_required"] = demand + # print(demand) + retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3], + domain_mapping[dst_data["domain_category"]][3]) + index_keywords = interface_search(demand, *retrievers) + dst_data["retrieve_keywords"] = index_keywords + # print(index_keywords) + result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite, + "question_type": dst_data["question_type"], + "intention_type": dst_data["intent_category_slot"][ + "intent_category"], + "keywords": index_keywords}) + dst_data["rewrite_result"] = result_rewrite + else: + dst_data["domain_category"] = chain_domain_v3.invoke(input_str) + if dst_data["domain_category"] != '其他': + dst_data["dialogue_status"] = 1 + dst_data["intent_category_slot"]["software_name_required"] = software_name + # 内容扩写(软件全名) + input_str_name_rewrite = chain_full_name_extension.invoke( + {"query": input_str_stoped, "soft_name": software_name}) + print(input_str_name_rewrite) + _, demand = get_keywords_v4(input_str_name_rewrite) + if demand: + dst_data["intent_category_slot"]["demand_name_required"] = demand + # print(demand) + retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3], + domain_mapping[dst_data["domain_category"]][3]) + index_keywords = interface_search(demand, *retrievers) + dst_data["retrieve_keywords"] = index_keywords + # print(index_keywords) + result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite, + "question_type": dst_data["question_type"], + "intention_type": dst_data["intent_category_slot"][ + "intent_category"], + "keywords": index_keywords}) + dst_data["rewrite_result"] = result_rewrite + else: + dst_data["dialogue_status"] = 2 + dst_data["response"] = f"好的,具体软件是【{software_name}】, 请补充具体的需求" + else: + dst_data["dialogue_status"] = 2 + dst_data["response"] = f"好的,具体需求是【{input_str_syn}】, 但当前缺少具体软件名字,请补充" + return dst_data + # else: + # # 如果 history_chat 有数据,进行其他处理 + # _, B, _ = _chat_slice(data) + # if chains_mutil12.invoke(B) == '1': + # text = chat_record_conversion_demand(data) + # else: + # text = chat_record_conversion_software(data) + + # software_name = chain_domain_v3.invoke(text) + + # dst_data = { + # "query": "", # 用户输入 + # "vertical_category": "", # 垂直分类 + + # # 意图分类槽位结构 + # "intent_category_slot": { + # "intent_category": "", # 意图分类 + # "software_name_required": "", # 软件名称 + # "demand_name_required": "", # 需求内容 + # "temp_optional": "" # 占位 + # }, + + # "domain_category": "", # 领域分类 + # "question_type": "", # 5W2H分类 + # "retrieve_keywords": "", # 检索语义 + # "dialogue_status": None, # 对话状态(单轮/多轮 1/2) + + # "response": "", # 回复 + # # 多轮对话状态 + # "multi_dialogue_state": { + # "status": None, # 0: rewrite 1: response + # "multi_intent_category_slot": { + # "multi_intent_category": "", # 意图分类 + # "multi_software_name_required": "", # 软件名称 + # "multi_demand_name_required": "", # 需求内容 + # "multi_temp_optional": "" # 占位 + # }, + # "multi_domain_category": "", # 领域分类 + # "question_type": "", # 5W2H分类 + # "retrieve_keywords": "" # 检索语义 + # }, + + # "rewrite_result": "" # 改写结果 + # } + + # dst_data["query"] = text + # dst_data["dialogue_status"] = 1 + # vertical_category = chain_vertical.invoke(dst_data["query"]) + # if vertical_category == '闲聊': + # dst_data["response"] = "闲聊服务只提供给内测用户" + # else: + # vertical_category = chain_vertical.invoke(dst_data["query"]) + # dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(dst_data["query"]) + # dst_data["question_type"] = chain_5W2H.invoke(dst_data["query"]) + # dst_data["domain_category"] = software_name + # dst_data["intent_category_slot"]["software_name_required"] = software_name + # input_str_name_rewrite = chain_full_name_extension.invoke({"query": dst_data["query"], "soft_name": software_name}) + # _, demand = get_keywords_v4(input_str_name_rewrite) + # dst_data["intent_category_slot"]["demand_name_required"] = demand + # retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3], + # domain_mapping[dst_data["domain_category"]][3]) + # index_keywords = interface_search(demand, *retrievers) + # dst_data["retrieve_keywords"] = index_keywords + # result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite, + # "question_type": dst_data["question_type"], + # "intention_type": dst_data["intent_category_slot"][ + # "intent_category"], + # "keywords": index_keywords}) + # dst_data["rewrite_result"] = result_rewrite + # return dst_data + +# 可选:本地调试入口 +if __name__ == "__main__": + uvicorn.run(app, host="0.0.0.0", port=3334) + diff --git a/kg_management.py b/kg_management.py index 3a7893b..2afd573 100644 --- a/kg_management.py +++ b/kg_management.py @@ -1,51 +1,55 @@ -""" -=================================== -@Auther:WenZ -@Company: BooWay -@project:booway_dm -=================================== -""" -from vector_load import Mixed_retrieval, interface_search - -xizang_input_path = '../data/temp/Z1_keywords_software_usage.txt' -retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3 = Mixed_retrieval(xizang_input_path) - -cuceng_input_path = '../data/temp/C1_keywords_software_usage.txt' -retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6 = Mixed_retrieval(cuceng_input_path) - -jigai_input_path = '../data/temp/T1_keywords_software_usage.txt' -retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9 = Mixed_retrieval(jigai_input_path) - -from vector_load import Building_search_dictionary, Official_website_kg_search - -input_index_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/data_index.csv" -xizang_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/西藏造价FAQ数据集.csv" -cuceng_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/新型储能电站建设计价通C1.csv" -jigai_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/技改检修计价通T1.csv" - - -def process_domain_category(nlu_info, retrievers, input_csv_path, index_csv_path): - """ - 处理不同领域类别的检索逻辑。 - - :param nlu_info: 包含领域类别和检索关键词的对象 - :param retrievers: 用于检索的 FAISS 索引列表 - :param input_csv_path: 领域特定的 CSV 输入路径 - :param index_csv_path: 索引 CSV 路径 - :return: 处理后的 QA RAG 结果列表 - """ - index_keywords = interface_search(nlu_info.retrieve_keywords, *retrievers) - qa_rag = [] - - for keyword in index_keywords: - # todo: bug修改: 避免output_id为None情况 - output_path, output_id = Building_search_dictionary(input_csv_path, index_csv_path, keyword) - if output_path is not None or output_id is not None: - qa_rag.append(f"检索知识:{output_path}") - qa_rag.append(Official_website_kg_search(output_id)) - else: - continue - - return qa_rag - - +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" +from vector_load import Mixed_retrieval, interface_search + +xizang_input_path = '../data/temp/Z1_keywords_software_usage.txt' +retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3 = Mixed_retrieval(xizang_input_path) + +cuceng_input_path = '../data/temp/C1_keywords_software_usage.txt' +retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6 = Mixed_retrieval(cuceng_input_path) + +jigai_input_path = '../data/temp/T1_keywords_software_usage.txt' +retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9 = Mixed_retrieval(jigai_input_path) + +peiwang_input_path = '../data/temp/D3_keywords_software_usage.txt' +retriever_txt_faiss10, retriever_txt_faiss11, retriever_txt_faiss12 = Mixed_retrieval(peiwang_input_path) + +from vector_load import Building_search_dictionary, Official_website_kg_search + +input_index_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/data_index.csv" +xizang_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/西藏造价FAQ数据集.csv" +cuceng_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/新型储能电站建设计价通C1.csv" +jigai_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/技改检修计价通T1.csv" +peiwang_input_csv_path = "/data/Z_project/rahulnyk/DM_lab/data/temp/博微配网工程计价通D3软件.csv" + + +def process_domain_category(nlu_info, retrievers, input_csv_path, index_csv_path): + """ + 处理不同领域类别的检索逻辑。 + + :param nlu_info: 包含领域类别和检索关键词的对象 + :param retrievers: 用于检索的 FAISS 索引列表 + :param input_csv_path: 领域特定的 CSV 输入路径 + :param index_csv_path: 索引 CSV 路径 + :return: 处理后的 QA RAG 结果列表 + """ + index_keywords = interface_search(nlu_info.retrieve_keywords, *retrievers) + qa_rag = [] + + for keyword in index_keywords: + # todo: bug修改: 避免output_id为None情况 + output_path, output_id = Building_search_dictionary(input_csv_path, index_csv_path, keyword) + if output_path is not None or output_id is not None: + qa_rag.append(f"检索知识:{output_path}") + qa_rag.append(Official_website_kg_search(output_id)) + else: + continue + + return qa_rag + + diff --git a/main.py b/main.py index ffc3d23..9f18729 100644 --- a/main.py +++ b/main.py @@ -1,105 +1,105 @@ -""" -=================================== -@Auther:WenZ -@Company: BooWay -@project:booway_dm -=================================== -""" -from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, ChatRecord -from chains_ceshi import Vertical_classification, small_talk, intention_judge, domain_judge, judge_5W2H, extract_keywords, answer_questions, domain_judge, domain_judge_v2 - -from kg_management import retriever_txt_faiss1 -from kg_management import retriever_txt_faiss2 -from kg_management import retriever_txt_faiss3 -from kg_management import retriever_txt_faiss4 -from kg_management import retriever_txt_faiss5 -from kg_management import retriever_txt_faiss6 -from kg_management import retriever_txt_faiss7 -from kg_management import retriever_txt_faiss8 -from kg_management import retriever_txt_faiss9 - -from kg_management import input_index_csv_path -from kg_management import xizang_input_csv_path -from kg_management import cuceng_input_csv_path -from kg_management import jigai_input_csv_path - -from kg_management import process_domain_category - -chain_vertical = Vertical_classification() # 垂直/开放分类 -chain_intention = intention_judge() # 意图分类 -chain_domain = domain_judge() # 领域分类 -chain_5w2h = judge_5W2H() # 5W2H分类 -chain_keywords = extract_keywords() # 关键语义提取 -chains_qa = answer_questions() # 检索回答 - -chain_domain_v2 = domain_judge_v2() - -domain_mapping = { - '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path), - '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path), - '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path) -} - -while True: - input_str = input("用户:") - - if input_str.lower() in ["退出", "bye", "再见"]: - print("booway软件助手:再见") - break - - vertical_category = chain_vertical.invoke(input_str) - - if vertical_category == '闲聊': - print("booway软件助手:闲聊服务只提供给内测用户") - continue - - elif vertical_category == '软件咨询': - # 1. DST:NLUInfo初始化 - nlu_info = NLUInfo(vertical_category="软件咨询") - - # 2. DST: 更新NLUInfo - intention_info = chain_intention.invoke(input_str) - nlu_info.intent_category = intention_info - - domain_info = chain_domain.invoke(input_str) - nlu_info.domain_category = domain_info - - # 多轮对话逻辑 - while nlu_info.domain_category == '未知': - print("booway软件助手: 请问是什么软件?") - software_name = input("用户: ").strip() - # domain_info = chain_domain_v2.invoke(software_name) - print(software_name) - nlu_info.domain_category = software_name - if nlu_info.domain_category not in domain_mapping: - print("booway软件助手:请输入一个有效的软件名称") - - # 多轮对话触发条件:领域信息缺失 - # print(f"当前的 domain_info: {domain_info}") - # print(f"当前的 nlu_info.domain_category: {nlu_info.domain_category}") - # print(f"当前的 domain_mapping.keys(): {list(domain_mapping.keys())}") - - _5w2h_info = chain_5w2h.invoke(input_str) - nlu_info.question_type = _5w2h_info - - keywords_info = chain_keywords.invoke({"software_name": domain_info, - "_5w2h_type": _5w2h_info, - "query": input_str}) - nlu_info.retrieve_keywords = keywords_info - - # 3. 检索分析 - retrievers, input_csv_path = domain_mapping[nlu_info.domain_category][:3], \ - domain_mapping[nlu_info.domain_category][3] - kg = process_domain_category(nlu_info, retrievers, input_csv_path, input_index_csv_path) - # print(kg) - if kg != []: - extracted_index = [item[5:] for item in kg if item.startswith("检索知识")] - print(f"系统:检索到的知识为{extracted_index},分析如下:\n") - print(chains_qa.invoke({"query": input_str, "kg": '\n'.join(kg)})) - else: - print("无相关检索知识,请重新提出问题") - - else: - print("booway软件助手:抱歉,我不明白你的问题,请尝试重新表述") - continue - +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" +from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, ChatRecord +from chains_ceshi import Vertical_classification, small_talk, intention_judge, domain_judge, judge_5W2H, extract_keywords, answer_questions, domain_judge, domain_judge_v2 + +from kg_management import retriever_txt_faiss1 +from kg_management import retriever_txt_faiss2 +from kg_management import retriever_txt_faiss3 +from kg_management import retriever_txt_faiss4 +from kg_management import retriever_txt_faiss5 +from kg_management import retriever_txt_faiss6 +from kg_management import retriever_txt_faiss7 +from kg_management import retriever_txt_faiss8 +from kg_management import retriever_txt_faiss9 + +from kg_management import input_index_csv_path +from kg_management import xizang_input_csv_path +from kg_management import cuceng_input_csv_path +from kg_management import jigai_input_csv_path + +from kg_management import process_domain_category + +chain_vertical = Vertical_classification() # 垂直/开放分类 +chain_intention = intention_judge() # 意图分类 +chain_domain = domain_judge() # 领域分类 +chain_5w2h = judge_5W2H() # 5W2H分类 +chain_keywords = extract_keywords() # 关键语义提取 +chains_qa = answer_questions() # 检索回答 + +chain_domain_v2 = domain_judge_v2() + +domain_mapping = { + '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path), + '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path), + '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path) +} + +while True: + input_str = input("用户:") + + if input_str.lower() in ["退出", "bye", "再见"]: + print("booway软件助手:再见") + break + + vertical_category = chain_vertical.invoke(input_str) + + if vertical_category == '闲聊': + print("booway软件助手:闲聊服务只提供给内测用户") + continue + + elif vertical_category == '软件咨询': + # 1. DST:NLUInfo初始化 + nlu_info = NLUInfo(vertical_category="软件咨询") + + # 2. DST: 更新NLUInfo + intention_info = chain_intention.invoke(input_str) + nlu_info.intent_category = intention_info + + domain_info = chain_domain.invoke(input_str) + nlu_info.domain_category = domain_info + + # 多轮对话逻辑 + while nlu_info.domain_category == '未知': + print("booway软件助手: 请问是什么软件?") + software_name = input("用户: ").strip() + # domain_info = chain_domain_v2.invoke(software_name) + print(software_name) + nlu_info.domain_category = software_name + if nlu_info.domain_category not in domain_mapping: + print("booway软件助手:请输入一个有效的软件名称") + + # 多轮对话触发条件:领域信息缺失 + # print(f"当前的 domain_info: {domain_info}") + # print(f"当前的 nlu_info.domain_category: {nlu_info.domain_category}") + # print(f"当前的 domain_mapping.keys(): {list(domain_mapping.keys())}") + + _5w2h_info = chain_5w2h.invoke(input_str) + nlu_info.question_type = _5w2h_info + + keywords_info = chain_keywords.invoke({"software_name": domain_info, + "_5w2h_type": _5w2h_info, + "query": input_str}) + nlu_info.retrieve_keywords = keywords_info + + # 3. 检索分析 + retrievers, input_csv_path = domain_mapping[nlu_info.domain_category][:3], \ + domain_mapping[nlu_info.domain_category][3] + kg = process_domain_category(nlu_info, retrievers, input_csv_path, input_index_csv_path) + # print(kg) + if kg != []: + extracted_index = [item[5:] for item in kg if item.startswith("检索知识")] + print(f"系统:检索到的知识为{extracted_index},分析如下:\n") + print(chains_qa.invoke({"query": input_str, "kg": '\n'.join(kg)})) + else: + print("无相关检索知识,请重新提出问题") + + else: + print("booway软件助手:抱歉,我不明白你的问题,请尝试重新表述") + continue + diff --git a/streamlit_main.py b/streamlit_main.py index bbb55fb..24f7548 100644 --- a/streamlit_main.py +++ b/streamlit_main.py @@ -1,268 +1,178 @@ -""" -=================================== -@Auther:WenZ -@Company: BooWay -@project:booway_dm -=================================== -""" -import streamlit as st -from utils import judge_define_suffix, match_suffix, retrieve_relevant_software -from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \ - ChatRecord -from utils import stop_word_processing -import spacy -import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf -from utils import get_keywords, get_keywords_v2, get_keywords_v3 -from vector_load import interface_search - -# nlp_sm = zh_core_web_sm.load() -# nlp_md = zh_core_web_md.load() -# nlp_lg = zh_core_web_lg.load() -nlp_trf = zh_core_web_trf.load() - -polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"} - -from chains_ceshi import suffix_answers -from chains_ceshi import Vertical_classification -from chains_ceshi import intention_judge -from chains_ceshi import domain_judge -from chains_ceshi import judge_5W2H - -from chains_rewrite import software_name_rewrite -from chains_rewrite import query_function_rewrite -from chains_rewrite import operation_guidance_rewrite -from chains_rewrite import troubleshooting_rewrite -from chains_rewrite import access_rewrite - -from kg_management import retriever_txt_faiss1 -from kg_management import retriever_txt_faiss2 -from kg_management import retriever_txt_faiss3 -from kg_management import retriever_txt_faiss4 -from kg_management import retriever_txt_faiss5 -from kg_management import retriever_txt_faiss6 -from kg_management import retriever_txt_faiss7 -from kg_management import retriever_txt_faiss8 -from kg_management import retriever_txt_faiss9 - -from kg_management import input_index_csv_path -from kg_management import xizang_input_csv_path -from kg_management import cuceng_input_csv_path -from kg_management import jigai_input_csv_path - -from kg_management import process_domain_category - -domain_mapping = { - '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path), - '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path), - '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path) -} - -chain_suffix_answers = suffix_answers() # 后缀名问题处理 -chain_vertical = Vertical_classification() # 垂直/开放分类 -chain_intention = intention_judge() # 意图分类 -chain_domain = domain_judge() # 领域分类 -chain_5W2H = judge_5W2H() # 5W2H分类 - -chains_name_rewrite = software_name_rewrite() # 问题改写:软件名改写 -chain_function_rewrite = query_function_rewrite() # 问题改写:软件功能查询 -chain_guidance_rewrite = operation_guidance_rewrite() # 问题改写:软件操作指导 -chain_troubleshooting_rewrite = troubleshooting_rewrite() # 问题改写:软件故障排查类 -chain_access_rewrite = access_rewrite() # 问题改写:软件下载与安装 - -from chains_rewrite import to_normal_rewrite -from chains_rewrite import retrieval_rewrite - -chain_normal_rewrite = to_normal_rewrite() -chain_retrieval_rewrite = retrieval_rewrite() - -# 同义词替换 -from utils import normalize_text - -synonym_dict = { - "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1", "西藏", - "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1", ], - "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1", "储能", - "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"], - "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1", "技改", - "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"], - "费率": ["费费率"], - "下载": ["获取", "安装", "下载下来", "装上"] -} - -import random - -reponse_prompts = [ - "🤔 Booway软件助手:请问您指的是哪个软件?", - "🤔 Booway软件助手:请提供软件名称,以便更好地帮助您。", - "🤔 Booway软件助手:请问您使用的是什么软件?", - "🤔 Booway软件助手:请告诉我您要查询的软件名称。", - "🤔 Booway软件助手:请问是哪款软件?" -] - -import streamlit as st -import random - -st.set_page_config(page_title="Booway 助手", layout="wide") - -# st.title("🤖 Booway 助手") -# 助手简介 -st.markdown(""" -# 🤖 Booway 软件助手 - -欢迎使用 **booway 软件助手**,这是一个用于协助用户进行电力造价软件相关问题咨询的智能系统。 - -**目前可咨询软件为:** -- 西藏造价软件Z1 -- 新型储能计价通C1 -- 技改检修计价通T1 -- 后缀名文件咨询 - -**使用方法:** -直接在下方输入你的问题,例如: -- 你好,想问下储能的C1那个软件。初设的基本预备费费率想调整一下,但是没有找到能调整的地方 -- 如何把西藏老定额工程升级成西藏Z1的新定额工程 -- 储能软件勾选了卸车,总价不变呢 -- bjgx用什么软件打开的? -- 设备运杂费率怎么设置 (多轮测试) -- 你好,初设的基本预备费费率想调整一下,但是没有找到能调整的地方(多轮测试) - -**注意:多轮对话** -- ~~目前多轮对话中 当机器人询问用户什么软件,则必须是以上软件名字~~ - -如果你输入的是闲聊内容,系统将提示仅内测用户可用。 -""") - - -input_str = st.text_input("🦉 用户:", "") - -if input_str: - if judge_define_suffix(input_str): - nlu_info = NLUInfo(vertical_category="软件咨询") - nlu_info.intent_category = "查询功能" - nlu_info.domain_category = "后缀名查询" - - suffix_name = match_suffix(input_str) - nlu_info.retrieve_keywords = suffix_name - suffix_to_software = retrieve_relevant_software(suffix_name) - - if isinstance(suffix_to_software, int): - st.info("Booway 助手:未查到相关知识") - - elif isinstance(suffix_to_software, str): - query_rewrite = f"{suffix_name}是什么文件?用什么软件打开?" - nlu_info.rewrite = query_rewrite - query_kg = suffix_to_software - result = chain_suffix_answers.invoke({"query": input_str, "kg": query_kg}) - st.subheader("识别出的NLU信息") - st.json({"垂直分类": nlu_info.vertical_category, - "意图分类": nlu_info.intent_category, - "领域分类": nlu_info.domain_category, - "问题分类": nlu_info.question_type, - "检索语义": nlu_info.retrieve_keywords, - "改写结果": nlu_info.rewrite, - "检索回答": result}) - - elif isinstance(suffix_to_software, list): - suffix_to_software_str = '\n'.join(suffix_to_software) - query_rewrite = f"{suffix_name}是什么文件?用什么软件打开?" - nlu_info.rewrite = query_rewrite - query_kg = suffix_to_software_str - result = chain_suffix_answers.invoke({"query": input_str, "kg": query_kg}) - st.subheader("识别出的NLU信息") - st.json({"垂直分类": nlu_info.vertical_category, - "意图分类": nlu_info.intent_category, - "领域分类": nlu_info.domain_category, - "问题分类": nlu_info.question_type, - "检索语义": nlu_info.retrieve_keywords, - "改写结果": nlu_info.rewrite, - "检索回答": result}) - else: - # 多轮对话处理 - # 第一步:预处理输入 - input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words) - input_str_syn = normalize_text(input_str_stoped, synonym_dict) - - # 第二步:调用分类链判断领域 - vertical_category = chain_domain.invoke(input_str_syn) - - # 第三步:若为未知领域,尝试引导补充软件名 - if vertical_category == "未知": - # 初始化状态变量 - if "mt_input_done" not in st.session_state: - st.session_state.mt_input_done = False - if "mt_input_value" not in st.session_state: - st.session_state.mt_input_value = "" - if "mt_prompt" not in st.session_state: - st.session_state.mt_prompt = random.choice(reponse_prompts) - - # 尚未完成补充输入 - if not st.session_state.mt_input_done: - mt_conversation = st.text_input( - f"{st.session_state.mt_prompt}", key="mt_input" - ) - - if mt_conversation: - st.session_state.mt_input_value = mt_conversation - st.session_state.mt_input_done = True - st.rerun() - else: - st.stop() # 等待输入 - - # 完成补充输入,重新获取 vertical_category - mt_input_str = normalize_text(st.session_state.mt_input_value, synonym_dict) - vertical_category = chain_domain.invoke(mt_input_str) - - # 第四步:NLU构建 - nlu_info = NLUInfo(vertical_category="软件咨询") - nlu_info.domain_category = vertical_category - - # 第五步:改写问题并提取关键词 - input_str_name_rewrite = chains_name_rewrite.invoke({ - "query": input_str, - "software_name": nlu_info.domain_category - }) - input_str_rewrite = chain_normal_rewrite.invoke(input_str_name_rewrite) - - temp_retriever = get_keywords_v2(input_str_rewrite) - nlu_info.question_type = chain_5W2H.invoke(input_str_rewrite) - nlu_info.intent_category = chain_intention.invoke(input_str) - - # 第六步:调用检索器并重写查询 - retrievers, input_csv_path = ( - domain_mapping[nlu_info.domain_category][:3], - domain_mapping[nlu_info.domain_category][3] - ) - - index_keywords = interface_search(temp_retriever, *retrievers) - st.info(f"提取关键词:{index_keywords}") - - nlu_info.rewrite = chain_retrieval_rewrite.invoke({ - "query": input_str_rewrite, - "question_type": nlu_info.question_type, - "intention_type": nlu_info.intent_category, - "keywords": index_keywords - }) - - nlu_info.retrieve_keywords = get_keywords_v3(nlu_info.rewrite) - - # 第七步:展示识别结果 - st.subheader("识别出的NLU信息") - st.json({"垂直分类":nlu_info.vertical_category, - "意图分类":nlu_info.intent_category, - "领域分类":nlu_info.domain_category, - "问题分类":nlu_info.question_type, - "检索语义":nlu_info.retrieve_keywords, - "改写结果":nlu_info.rewrite}) - - for key in ["mt_input_done", "mt_input_value", "mt_prompt", "mt_input"]: - if key in st.session_state: - del st.session_state[key] - - - - - # streamlit run streamlit_main.py --server.port 2335 - - - - +""" +=================================== +@Auther:WenZ +@Company: BooWay +@project:booway_dm +=================================== +""" +import streamlit as st +from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, ChatRecord +from chains_ceshi import Vertical_classification, small_talk, intention_judge, domain_judge, judge_5W2H, extract_keywords, answer_questions + +from kg_management import retriever_txt_faiss1 +from kg_management import retriever_txt_faiss2 +from kg_management import retriever_txt_faiss3 +from kg_management import retriever_txt_faiss4 +from kg_management import retriever_txt_faiss5 +from kg_management import retriever_txt_faiss6 +from kg_management import retriever_txt_faiss7 +from kg_management import retriever_txt_faiss8 +from kg_management import retriever_txt_faiss9 + +from kg_management import input_index_csv_path +from kg_management import xizang_input_csv_path +from kg_management import cuceng_input_csv_path +from kg_management import jigai_input_csv_path + +from kg_management import process_domain_category + +chain_vertical = Vertical_classification() # 垂直/开放分类 +chain_intention = intention_judge() # 意图分类 +chain_domain = domain_judge() # 领域分类 +chain_5w2h = judge_5W2H() # 5W2H分类 +chain_keywords = extract_keywords() # 关键语义提取 +chains_qa = answer_questions() # 检索回答 + +# import streamlit as st +# +# chain_vertical = Vertical_classification() # 垂直/开放分类 +# chain_intention = intention_judge() # 意图分类 +# chain_domain = domain_judge() # 领域分类 +# chain_5w2h = judge_5W2H() # 5W2H分类 +# chain_keywords = extract_keywords() # 关键语义提取 +# chains_qa = answer_questions() # 检索回答 + +domain_mapping = { + '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path), + '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path), + '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path) +} + +import streamlit as st + +# 假设这些是你已有的模块 +# from your_module import chain_vertical, chain_intention, chain_domain, chain_5w2h, chain_keywords, chains_qa +# from your_module import process_domain_category, NLUInfo, domain_mapping, input_index_csv_path + +# 页面配置 +st.set_page_config(page_title="booway 软件助手", layout="wide") + +# 助手简介 +st.markdown(""" +# 🤖 booway 软件助手 + +欢迎使用 **booway 软件助手**,这是一个用于协助用户进行电力造价软件相关问题咨询的智能系统。 + +**目前可咨询软件为:** +- 西藏造价软件Z1 +- 新型储能计价通C1 +- 技改检修计价通T1 + +**使用方法:** +直接在下方输入你的问题,例如: +- “你好,想问下储能的C1那个软件。初设的基本预备费费率想调整一下,但是没有找到能调整的地方” +- “你好,初设的基本预备费费率想调整一下,但是没有找到能调整的地方”(多轮测试) +- “如何把西藏老定额工程升级成西藏Z1的新定额工程” +- “储能软件勾选了卸车,总价不变呢” +- “请问技改检修软件里其他费中的设计费的取费基数和费率,你们设置时肯定要依据吧” + +**注意:多轮对话** +- 目前多轮对话中 当机器人询问用户什么软件,则必须是以上软件名字 + +**注意:垂直分类不对*** +- 技改拆除的安全文明施工费费率是多少 分类: 闲聊 +- 技改软件拆除的安全文明施工费费率是多少 分类:软件咨询(+ 软件) + - 如:技改项目从老版本升级到新版本,是不是定额也跟着跟新了 + - 改为:技改软件项目从老版本升级到新版本,是不是定额也跟着跟新了 + +如果你输入的是闲聊内容,系统将提示仅内测用户可用。 +""") + +# 用户输入 +user_input = st.text_input("👤 用户:", "") + +if user_input: + if user_input.lower() in ["退出", "bye", "再见"]: + st.success("booway软件助手:再见 👋") + else: + vertical_category = chain_vertical.invoke(user_input) + + if vertical_category == '闲聊': + st.warning("booway软件助手:闲聊服务只提供给内测用户") + elif vertical_category == '软件咨询': + # 初始化 NLUInfo + nlu_info = NLUInfo(vertical_category="软件咨询") + + # 意图识别 & 领域识别 + nlu_info.intent_category = chain_intention.invoke(user_input) + domain_info = chain_domain.invoke(user_input) + nlu_info.domain_category = domain_info + + # 多轮询问软件名称 + if domain_info == '未知': + import random + prompts = [ + "🤔 Booway软件助手:请问您指的是哪个软件?", + "🤔 Booway软件助手:请提供软件名称,以便更好地帮助您。", + "🤔 Booway软件助手:请问您使用的是什么软件?", + "🤔 Booway软件助手:请告诉我您要查询的软件名称。", + "🤔 Booway软件助手:请问是哪款软件?" + ] + + # 随机选择一个提示 + random_prompt = random.choice(prompts) + + # 生成输入框 + software_name = st.text_input(random_prompt, "") + # software_name = st.text_input("🤔 booway软件助手:请问是什么软件?", "") + if software_name and software_name in domain_mapping: + domain_info = software_name + nlu_info.domain_category = software_name + elif software_name: + st.error("booway软件助手:请输入一个有效的软件名称") + + if nlu_info.domain_category in domain_mapping: + # 问题类型 + 关键词提取 + nlu_info.question_type = chain_5w2h.invoke(user_input) + nlu_info.retrieve_keywords = chain_keywords.invoke({ + "software_name": nlu_info.domain_category, + "_5w2h_type": nlu_info.question_type, + "query": user_input + }) + + # 检索分析 + retrievers, input_csv_path = (domain_mapping[nlu_info.domain_category][:3], + domain_mapping[nlu_info.domain_category][3]) + kg = process_domain_category(nlu_info, retrievers, input_csv_path, input_index_csv_path) + + if kg: + # 检索内容 + extracted_index = [item[5:] for item in kg if item.startswith("检索知识")] + st.subheader("📚 检索内容") + st.write(extracted_index) + + # 回答内容 + response = chains_qa.invoke({"query": user_input, "kg": '\n'.join(kg)}) + st.subheader("💬 回答内容") + st.write(response) + + # NLU信息 + st.subheader("🧠 NLU_info") + st.json({ + "vertical_category": nlu_info.vertical_category, + "intent_category": nlu_info.intent_category, + "domain_category": nlu_info.domain_category, + "question_type": nlu_info.question_type, + "retrieve_keywords": nlu_info.retrieve_keywords, + }) + else: + st.error("booway软件助手:无相关检索知识,请重新提出问题") + else: + st.warning("booway软件助手:抱歉,我不明白你的问题,请尝试重新表述") + + +# streamlit run streamlit_main.py --server.port 2335 + + + +