"""
===================================
@Author: WenZ
@Company: BooWay
@project: booway_dm
===================================

Dialogue-analysis HTTP service.

Exposes ``POST /analyze``: the request is classified (vertical / intent /
domain / 5W2H), keywords are looked up through the retrievers of the matching
software, and a filled dialogue-state dict is returned.
"""
from typing import Any, Dict, List, Optional, Union

import spacy
import uvicorn
import zh_core_web_lg
import zh_core_web_md
import zh_core_web_sm
import zh_core_web_trf
from fastapi import FastAPI, Request
from pydantic import BaseModel

from chains_ceshi import (
    Vertical_classification,
    domain_judge,
    domain_judge_v3,
    intention_judge,
    judge_5W2H,
    mutil12,
    suffix_answers,
)
from chains_rewrite import (
    access_rewrite,
    full_name_extension,
    mutil_text_rewrite,
    operation_guidance_rewrite,
    query_function_rewrite,
    retrieval_rewrite,
    software_judge,
    software_name_rewrite,
    to_normal_rewrite,
    troubleshooting_rewrite,
)
from dialogue_management import (
    ChatRecord,
    DialogInfo,
    DialogType,
    NLUInfo,
    QuestionInfo,
    QuestionType,
    ScenInfo,
    TalkInfo,
)
from kg_management import (
    cuceng_input_csv_path,
    input_index_csv_path,
    jigai_input_csv_path,
    peiwang_input_path,
    process_domain_category,
    retriever_txt_faiss1,
    retriever_txt_faiss2,
    retriever_txt_faiss3,
    retriever_txt_faiss4,
    retriever_txt_faiss5,
    retriever_txt_faiss6,
    retriever_txt_faiss7,
    retriever_txt_faiss8,
    retriever_txt_faiss9,
    retriever_txt_faiss10,
    retriever_txt_faiss11,
    retriever_txt_faiss12,
    xizang_input_csv_path,
)
from utils import (
    get_keywords,
    get_keywords_v2,
    get_keywords_v3,
    get_keywords_v4,
    judge_define_suffix,
    match_suffix,
    normalize_text,
    retrieve_relevant_software,
    stop_word_processing,
)
from vector_load import interface_search

# Only the transformer pipeline is loaded; the smaller ones stay disabled.
# nlp_sm = zh_core_web_sm.load()
# nlp_md = zh_core_web_md.load()
# nlp_lg = zh_core_web_lg.load()
nlp_trf = zh_core_web_trf.load()

# Passed to stop_word_processing together with the spaCy pipeline.
polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"}

# Software name -> (retriever, retriever, retriever, input csv path).
# The first three entries are handed to interface_search.
domain_mapping = {
    '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path),
    '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path),
    '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path),
    '博微配网工程计价通D3软件': (retriever_txt_faiss10, retriever_txt_faiss11, retriever_txt_faiss12,
                       peiwang_input_path),
}

chain_suffix_answers = suffix_answers()  # file-suffix question handling
chain_vertical = Vertical_classification()  # vertical / open-domain classification
chain_intention = intention_judge()  # intent classification
chain_domain = domain_judge()  # domain classification
chain_5W2H = judge_5W2H()  # 5W2H classification
chains_name_rewrite = software_name_rewrite()  # query rewrite: software name
chain_function_rewrite = query_function_rewrite()  # query rewrite: software function lookup
chain_guidance_rewrite = operation_guidance_rewrite()  # query rewrite: operation guidance
chain_troubleshooting_rewrite = troubleshooting_rewrite()  # query rewrite: troubleshooting
chain_access_rewrite = access_rewrite()  # query rewrite: download and installation
chain_normal_rewrite = to_normal_rewrite()
chain_retrieval_rewrite = retrieval_rewrite()
chain_full_name_extension = full_name_extension()
chain_domain_v3 = domain_judge_v3()
chain_software_judge = software_judge()
chain_mutil_text_rewrite = mutil_text_rewrite()
chains_mutil12 = mutil12()

# Synonym replacement tables (canonical term -> aliases), consumed by normalize_text.
synonym_dict = {
    "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1",
                 "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1",],
    "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1",
                  "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"],
    "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1",
                  "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"],
    "博微配网工程计价通D3软件": ["博微配网D3", "D3配网计价", "配网计价D3", "配网D3", "D3软件", "博微D3", "D3",
                      "博微配网工程计价通", "博微配网工程计价通D3",
                      "博微配网d3", "d3配网计价", "配网计价d3", "配网d3", "d3软件", "博微d3", "d3"]
}
synonym_dict_2 = {
    "费率": ["费费率"],
    "下载": ["获取", "安装", "下载下来", "装上"]
}


def _new_dst_data() -> Dict[str, Any]:
    """Return a fresh, empty dialogue-state record.

    A new dict is built on every call so that requests never share the
    nested slot dictionaries.
    """
    return {
        "query": "",  # user input
        "vertical_category": "",  # vertical classification
        # intent-classification slot structure
        "intent_category_slot": {
            "intent_category": "",  # intent classification
            "software_name_required": "",  # software name
            "demand_name_required": "",  # demand content
            "temp_optional": "",  # placeholder
        },
        "domain_category": "",  # domain classification
        "question_type": "",  # 5W2H classification
        "retrieve_keywords": "",  # retrieval keywords
        "dialogue_status": None,  # dialogue status (single-turn / multi-turn: 1 / 2)
        "response": "",  # reply
        # multi-turn dialogue state
        "multi_dialogue_state": {
            "status": None,  # 0: rewrite 1: response
            "multi_intent_category_slot": {
                "multi_intent_category": "",  # intent classification
                "multi_software_name_required": "",  # software name
                "multi_demand_name_required": "",  # demand content
                "multi_temp_optional": "",  # placeholder
            },
            "multi_domain_category": "",  # domain classification
            "question_type": "",  # 5W2H classification
            "retrieve_keywords": "",  # retrieval keywords
        },
        "rewrite_result": "",  # rewrite result
    }


# Module-level template kept for backward compatibility; analyze_input works
# on its own fresh copy and never mutates this one.
dst_data = _new_dst_data()


def get_software_name(input_json):
    """Return ``input_json["content"]["soft_info"]["software_name"]``.

    Returns ``None`` when any level of that path is missing, or when an
    intermediate value is not subscriptable (e.g. ``"content": None``).
    """
    try:
        return input_json["content"]["soft_info"]["software_name"]
    except (KeyError, TypeError):
        return None


def _chat_aggregation(first_user, assistant_reply, second_user):
    """Join the two user turns and the assistant turn into one transcript string."""
    return f"user: {first_user}\nassistant: {assistant_reply}\nuser: {second_user}"


def _chat_slice(data):
    """Split a request into (previous user turn, previous assistant turn, current input).

    The previous turns are read from ``data["content"]["history_chat"]``
    (index 0 -> ``"user"``, index 1 -> ``"assistant"``). A missing or too
    short history yields empty strings for the absent turns instead of
    raising.
    """
    history_chat = data["content"].get("history_chat") or []
    previous_user = history_chat[0].get("user", "") if len(history_chat) > 0 else ""
    previous_assistant = history_chat[1].get("assistant", "") if len(history_chat) > 1 else ""
    return previous_user, previous_assistant, data["input_str"]


def chat_record_conversion_demand(data):
    """Rewrite a two-turn dialogue into one query, normalising the CURRENT input.

    The current input is passed through ``normalize_text`` (with
    ``synonym_dict``) and ``chain_software_judge`` before the transcript is
    handed to ``chain_mutil_text_rewrite``.
    """
    previous_user, previous_assistant, current_input = _chat_slice(data)
    current_input = normalize_text(current_input, synonym_dict)
    current_input = chain_software_judge.invoke(current_input)
    full_dialog = _chat_aggregation(previous_user, previous_assistant, current_input)
    return chain_mutil_text_rewrite.invoke(full_dialog)


def chat_record_conversion_software(data):
    """Rewrite a two-turn dialogue into one query, normalising the PREVIOUS user turn.

    Same pipeline as ``chat_record_conversion_demand`` but the normalisation
    and ``chain_software_judge`` are applied to the first user turn.
    """
    previous_user, previous_assistant, current_input = _chat_slice(data)
    previous_user = normalize_text(previous_user, synonym_dict)
    previous_user = chain_software_judge.invoke(previous_user)
    full_dialog = _chat_aggregation(previous_user, previous_assistant, current_input)
    return chain_mutil_text_rewrite.invoke(full_dialog)


def _fill_suffix_answer(state: Dict[str, Any], input_str: str) -> None:
    """Fill *state* for a file-suffix question (what is this file / what opens it)."""
    state["query"] = input_str
    state["dialogue_status"] = 1
    state["multi_dialogue_state"]["status"] = 1
    state["vertical_category"] = "软件查询"
    state["intent_category_slot"]["intent_category"] = "查询功能"
    state["domain_category"] = "后缀名查询"
    suffix_name = match_suffix(input_str)
    state["retrieve_keywords"] = suffix_name
    state["question_type"] = ""

    # Look up the software associated with the matched suffix.
    suffix_to_software = retrieve_relevant_software(suffix_name)
    if isinstance(suffix_to_software, int):
        # An int result is handled as "nothing found"; no answer is generated.
        state["rewrite_result"] = "未查到相关知识"
        return
    if isinstance(suffix_to_software, str):
        knowledge = suffix_to_software
    elif isinstance(suffix_to_software, list):
        knowledge = '\n'.join(suffix_to_software)
    else:
        # Any other result type leaves rewrite_result / response empty.
        return
    state["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?"
    state["response"] = chain_suffix_answers.invoke({"query": input_str, "kg": knowledge})


def _retrieve_and_rewrite(state: Dict[str, Any], expanded_query: str, demand) -> None:
    """Store *demand*, its index keywords and the retrieval rewrite in *state*.

    ``state["domain_category"]`` must be a key of ``domain_mapping``;
    otherwise ``KeyError`` propagates.
    """
    state["intent_category_slot"]["demand_name_required"] = demand
    retrievers = domain_mapping[state["domain_category"]][:3]
    index_keywords = interface_search(demand, *retrievers)
    state["retrieve_keywords"] = index_keywords
    state["rewrite_result"] = chain_retrieval_rewrite.invoke({
        "query": expanded_query,
        "question_type": state["question_type"],
        "intention_type": state["intent_category_slot"]["intent_category"],
        "keywords": index_keywords,
    })


app = FastAPI()


@app.post("/analyze")
async def analyze_input(data: Dict[str, Any]):
    """Analyse one user utterance and return the filled dialogue state.

    Reads ``data["input_str"]`` (required) and, optionally,
    ``data["content"]["soft_info"]["software_name"]``.

    Flow:
      1. File-suffix questions are answered directly.
      2. Otherwise the input is stop-word filtered and synonym-normalised,
         then classified; a '闲聊' (chit-chat) result returns a fixed reply.
      3. With a client-supplied software name the demand keywords are
         extracted and the retrieval rewrite is produced.
      4. Without one, the software is inferred by ``chain_domain_v3``; if it
         cannot be resolved, or no demand is found, ``dialogue_status`` is
         set to 2 and ``response`` asks the user for the missing piece.

    Returns:
        The dialogue-state dict (see ``_new_dst_data`` for its fields).

    Raises:
        KeyError: if ``input_str`` is missing, or a client-supplied software
            name is not a key of ``domain_mapping``.
    """
    # TODO: multi-turn handling is disabled; history_chat is currently ignored
    # (see the commented-out branch at the end of this function).
    dst_data = _new_dst_data()
    input_str = data["input_str"]
    software_name = get_software_name(data)

    # The helper's return type is not visible here, so the explicit comparison
    # is kept rather than switching to plain truthiness.
    if judge_define_suffix(input_str) == True:
        _fill_suffix_answer(dst_data, input_str)
        return dst_data

    # Pre-processing of the raw input.
    input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words)  # stop-word handling
    input_str_syn = normalize_text(input_str_stoped, synonym_dict_2)  # synonym replacement
    dst_data["query"] = input_str_syn

    # Vertical classification runs on the query expanded with the full software name.
    input_str_name_rewrite = chain_full_name_extension.invoke(
        {"query": input_str_syn, "soft_name": software_name})
    vertical_category = chain_vertical.invoke(input_str_name_rewrite)
    if vertical_category == '闲聊':
        dst_data["response"] = "闲聊服务只提供给内测用户"
        return dst_data

    dst_data["vertical_category"] = vertical_category
    dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(input_str_syn)
    dst_data["question_type"] = chain_5W2H.invoke(input_str_syn)

    if software_name:
        # Software name supplied by the client.
        # NOTE(review): a name that is not a key of domain_mapping raises
        # KeyError below — confirm whether callers always send canonical names.
        dst_data["dialogue_status"] = 1
        dst_data["domain_category"] = software_name
        dst_data["intent_category_slot"]["software_name_required"] = software_name
        _, demand = get_keywords_v4(input_str_name_rewrite)
        _retrieve_and_rewrite(dst_data, input_str_name_rewrite, demand)
        return dst_data

    # No software name from the client: infer it from the raw input.
    detected_software = chain_domain_v3.invoke(input_str)
    dst_data["domain_category"] = detected_software
    if detected_software == '其他' or detected_software not in domain_mapping:
        # Unknown software: ask the user for it instead of failing on the
        # domain_mapping lookup.
        dst_data["dialogue_status"] = 2
        dst_data["response"] = f"好的,具体需求是【{input_str_syn}】, 但当前缺少具体软件名字,请补充"
        return dst_data

    dst_data["dialogue_status"] = 1
    # Use the detected software here: `software_name` is empty in this branch,
    # which previously produced a None slot and a "【None】" reply.
    dst_data["intent_category_slot"]["software_name_required"] = detected_software
    # Expand the query with the full software name.
    input_str_name_rewrite = chain_full_name_extension.invoke(
        {"query": input_str_stoped, "soft_name": detected_software})
    print(input_str_name_rewrite)  # debug output retained from the original
    _, demand = get_keywords_v4(input_str_name_rewrite)
    if demand:
        _retrieve_and_rewrite(dst_data, input_str_name_rewrite, demand)
    else:
        # Software known but no demand extracted: ask the user for it.
        dst_data["dialogue_status"] = 2
        dst_data["response"] = f"好的,具体软件是【{detected_software}】, 请补充具体的需求"
    return dst_data

    # Disabled multi-turn branch, kept for reference. It was meant to run when
    # data["content"].get("history_chat") is non-empty (.get avoids KeyError):
    #
    # _, B, _ = _chat_slice(data)
    # if chains_mutil12.invoke(B) == '1':
    #     text = chat_record_conversion_demand(data)
    # else:
    #     text = chat_record_conversion_software(data)
    # software_name = chain_domain_v3.invoke(text)
    # dst_data = _new_dst_data()
    # dst_data["query"] = text
    # dst_data["dialogue_status"] = 1
    # vertical_category = chain_vertical.invoke(dst_data["query"])
    # if vertical_category == '闲聊':
    #     dst_data["response"] = "闲聊服务只提供给内测用户"
    # else:
    #     vertical_category = chain_vertical.invoke(dst_data["query"])
    #     dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(dst_data["query"])
    #     dst_data["question_type"] = chain_5W2H.invoke(dst_data["query"])
    #     dst_data["domain_category"] = software_name
    #     dst_data["intent_category_slot"]["software_name_required"] = software_name
    #     input_str_name_rewrite = chain_full_name_extension.invoke(
    #         {"query": dst_data["query"], "soft_name": software_name})
    #     _, demand = get_keywords_v4(input_str_name_rewrite)
    #     dst_data["intent_category_slot"]["demand_name_required"] = demand
    #     retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3],
    #                                   domain_mapping[dst_data["domain_category"]][3])
    #     index_keywords = interface_search(demand, *retrievers)
    #     dst_data["retrieve_keywords"] = index_keywords
    #     result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite,
    #                                                      "question_type": dst_data["question_type"],
    #                                                      "intention_type": dst_data["intent_category_slot"][
    #                                                          "intent_category"],
    #                                                      "keywords": index_keywords})
    #     dst_data["rewrite_result"] = result_rewrite
    # return dst_data


# Optional: local debugging entry point.
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=3334)