Files
DM_rewrite_3.31/fast_api_main_4.3.py
2025-04-03 17:23:53 +08:00

429 lines
20 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
===================================
@Author: WenZ
@Company: BooWay
@project: booway_dm
===================================
"""
"""
===================================
@Author: WenZ
@Company: BooWay
@project: booway_dm
===================================
"""
from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \
ChatRecord
from utils import stop_word_processing
import spacy
import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf
from utils import get_keywords, get_keywords_v2, get_keywords_v3, get_keywords_v4
from vector_load import interface_search
# Alternative spaCy Chinese pipelines (small / medium / large) are kept disabled:
# nlp_sm = zh_core_web_sm.load()
# nlp_md = zh_core_web_md.load()
# nlp_lg = zh_core_web_lg.load()
# Transformer-based spaCy Chinese pipeline, loaded once at import time and passed
# to stop_word_processing() by analyze_input.
nlp_trf = zh_core_web_trf.load()
# Courtesy words ("hello", "please", "may I ask", "thanks", "sorry to bother", ...)
# handed to stop_word_processing() together with the user input — presumably
# stripped from the query there; TODO confirm in utils.
polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"}
from chains_ceshi import suffix_answers
from chains_ceshi import Vertical_classification
from chains_ceshi import intention_judge
from chains_ceshi import domain_judge
from chains_ceshi import judge_5W2H
from chains_rewrite import software_name_rewrite
from chains_rewrite import query_function_rewrite
from chains_rewrite import operation_guidance_rewrite
from chains_rewrite import troubleshooting_rewrite
from chains_rewrite import access_rewrite
from kg_management import retriever_txt_faiss1
from kg_management import retriever_txt_faiss2
from kg_management import retriever_txt_faiss3
from kg_management import retriever_txt_faiss4
from kg_management import retriever_txt_faiss5
from kg_management import retriever_txt_faiss6
from kg_management import retriever_txt_faiss7
from kg_management import retriever_txt_faiss8
from kg_management import retriever_txt_faiss9
from kg_management import retriever_txt_faiss10
from kg_management import retriever_txt_faiss11
from kg_management import retriever_txt_faiss12
from kg_management import input_index_csv_path
from kg_management import xizang_input_csv_path
from kg_management import cuceng_input_csv_path
from kg_management import jigai_input_csv_path
from kg_management import peiwang_input_path
from kg_management import process_domain_category
# Retrieval resources per supported software, keyed by the canonical software name.
# Each value is (retriever, retriever, retriever, input CSV path).
# analyze_input indexes this dict with dst_data["domain_category"] and passes the
# first three entries to interface_search(); the CSV path (element [3]) is not
# used by analyze_input.
domain_mapping = {
    '西藏造价软件Z1': (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path),
    '新型储能计价通C1': (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path),
    '技改检修计价通T1': (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path),
    '博微配网工程计价通D3软件': (retriever_txt_faiss10, retriever_txt_faiss11, retriever_txt_faiss12, peiwang_input_path)
}
# LLM chains are constructed once at import time and reused by every request.
chain_suffix_answers = suffix_answers()  # answers file-suffix (extension) questions
chain_vertical = Vertical_classification()  # vertical vs. open-domain classification
chain_intention = intention_judge()  # intent classification
chain_domain = domain_judge()  # domain classification
chain_5W2H = judge_5W2H()  # 5W2H question-type classification
chains_name_rewrite = software_name_rewrite()  # question rewrite: software name
chain_function_rewrite = query_function_rewrite()  # question rewrite: software function lookup
chain_guidance_rewrite = operation_guidance_rewrite()  # question rewrite: operation guidance
chain_troubleshooting_rewrite = troubleshooting_rewrite()  # question rewrite: troubleshooting
chain_access_rewrite = access_rewrite()  # question rewrite: download & installation
# NOTE(review): chain_domain and the five *_rewrite chains above are not referenced
# in this file as shown; confirm no other module imports them before removing.
from chains_rewrite import to_normal_rewrite
from chains_rewrite import retrieval_rewrite
chain_normal_rewrite = to_normal_rewrite()  # not referenced in this file as shown
chain_retrieval_rewrite = retrieval_rewrite()  # rewrites the query using retrieved keywords (analyze_input)
from chains_rewrite import full_name_extension
# Expands the query with the software's full name (analyze_input).
chain_full_name_extension = full_name_extension()
from chains_ceshi import domain_judge_v3
# Detects the software (domain) from the input when the request names none (analyze_input).
chain_domain_v3 = domain_judge_v3()
# Synonym-replacement tables consumed by normalize_text().
from utils import normalize_text
# Canonical software name -> aliases a user may type (upper- and lower-case variants).
# Used by chat_record_conversion_demand / chat_record_conversion_software.
# NOTE(review): "新型储能计价通" and "技改检修计价通" each appear twice in their alias
# lists, and several aliases are substrings of the canonical name itself
# (e.g. "Z1" inside "西藏造价软件Z1"). Whether that causes repeated or nested
# replacements depends on how normalize_text() applies them — confirm in utils.
synonym_dict = {
    "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1",
                       "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1",],
    "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1",
                         "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"],
    "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1",
                         "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"],
    "博微配网工程计价通D3软件": ["博微配网D3", "D3配网计价", "配网计价D3", "配网D3", "D3软件", "博微D3", "D3",
                                 "博微配网工程计价通", "博微配网工程计价通D3",
                                 "博微配网d3", "d3配网计价", "配网计价d3", "配网d3", "d3软件", "博微d3", "d3"]
}
# Term -> variant spellings; applied by analyze_input to every non-suffix query.
# Presumably each variant is replaced by its key (e.g. "获取"/"安装" -> "下载"); TODO confirm in utils.
synonym_dict_2 = {
    "费率": ["费费率"],
    "下载": ["获取", "安装", "下载下来", "装上"]
}
from typing import Optional
# Module-level template of the dialogue-state-tracking (DST) result.
# NOTE(review): analyze_input builds its own per-request dict, so this global is
# shadowed there and not read in this file as shown; it may be imported elsewhere —
# confirm before removing.
dst_data = {
    "query": "",  # user input
    "vertical_category": "",  # vertical / open-domain classification
    # intent-classification slot structure
    "intent_category_slot": {
        "intent_category": "",  # intent label
        "software_name_required": "",  # software name
        "demand_name_required": "",  # requested content (demand)
        "temp_optional": ""  # placeholder
    },
    "domain_category": "",  # domain (software) classification
    "question_type": "",  # 5W2H classification
    "retrieve_keywords": "",  # retrieval keywords
    "dialogue_status": None,  # dialogue status: 1 = single-turn, 2 = multi-turn
    "response": "",  # reply text
    # multi-turn dialogue state
    "multi_dialogue_state": {
        "status": None,  # 0: rewrite 1: response
        "multi_intent_category_slot": {
            "multi_intent_category": "",  # intent label
            "multi_software_name_required": "",  # software name
            "multi_demand_name_required": "",  # requested content (demand)
            "multi_temp_optional": ""  # placeholder
        },
        "multi_domain_category": "",  # domain classification
        "question_type": "",  # 5W2H classification
        "retrieve_keywords": ""  # retrieval keywords
    },
    "rewrite_result": ""  # rewritten query
}
def get_software_name(input_json):
    """Return ``input_json["content"]["soft_info"]["software_name"]``, or ``None``.

    ``None`` is returned when any key along that path is missing, or when an
    intermediate value cannot be indexed by a string key (for example
    ``"soft_info": null`` in the JSON request body), instead of raising.
    """
    try:
        return input_json["content"]["soft_info"]["software_name"]
    except (KeyError, TypeError):
        # KeyError: a key is absent.
        # TypeError: an intermediate value is None (or otherwise not indexable by str).
        return None
def chat_record_conversion_demand(data):
    """Merge the chat history and the new input into one rewritten query.

    The *current* input is synonym-normalized with ``synonym_dict`` and passed
    through the software-judge chain. The first user turn, the first assistant
    turn and that processed input are joined into a ``user/assistant/user``
    transcript. The transcript is sent to the multi-turn rewrite chain, whose
    output is returned.
    """
    history_user, history_assistant, current_input = _chat_slice(data)
    judged_input = chain_software_judge.invoke(normalize_text(current_input, synonym_dict))
    transcript = "\n".join([
        f"user: {history_user}",
        f"assistant: {history_assistant}",
        f"user: {judged_input}",
    ])
    return chain_mutil_text_rewrite.invoke(transcript)
def _chat_slice(data):
history_chat = data["content"].get("history_chat")
input_str = data["input_str"]
A = history_chat[0].get("user", "")
B = history_chat[1].get("assistant", "")
C = input_str
return A, B, C
def chat_record_conversion_software(data):
    """Merge the chat history and the new input into one rewritten query.

    Here it is the *first user turn* that is synonym-normalized with
    ``synonym_dict`` and passed through the software-judge chain. The processed
    turn, the first assistant turn and the current input are joined into a
    ``user/assistant/user`` transcript. The transcript is sent to the
    multi-turn rewrite chain, whose output is returned.
    """
    history_user, history_assistant, current_input = _chat_slice(data)
    judged_user = chain_software_judge.invoke(normalize_text(history_user, synonym_dict))
    transcript = "\n".join([
        f"user: {judged_user}",
        f"assistant: {history_assistant}",
        f"user: {current_input}",
    ])
    return chain_mutil_text_rewrite.invoke(transcript)
# Chains used by chat_record_conversion_demand / chat_record_conversion_software.
# They are created after those functions are defined, which works because the
# names are only looked up when the functions are called.
from chains_rewrite import software_judge
chain_software_judge = software_judge()
from chains_rewrite import mutil_text_rewrite
# Rewrites a user/assistant/user transcript into a single query ("mutil" presumably = "multi").
chain_mutil_text_rewrite = mutil_text_rewrite()
from chains_ceshi import mutil12
# Only referenced by the commented-out multi-turn branch that follows analyze_input,
# where a result of '1' selects chat_record_conversion_demand.
chains_mutil12 = mutil12()
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import List, Union
import uvicorn
# FastAPI application; the /analyze route is registered on it below.
app = FastAPI()
from typing import Dict, List, Optional, Any
def _new_dst_data() -> Dict[str, Any]:
    """Build a fresh dialogue-state-tracking (DST) result for one request.

    A new dict is created on every call so no state is shared between requests.
    """
    return {
        "query": "",  # user input (after pre-processing)
        "vertical_category": "",  # vertical / open-domain classification
        # intent-classification slot structure
        "intent_category_slot": {
            "intent_category": "",  # intent label
            "software_name_required": "",  # software name
            "demand_name_required": "",  # requested content (demand)
            "temp_optional": "",  # placeholder
        },
        "domain_category": "",  # domain (software) classification
        "question_type": "",  # 5W2H classification
        "retrieve_keywords": "",  # retrieval keywords
        "dialogue_status": None,  # dialogue status: 1 = single-turn, 2 = multi-turn
        "response": "",  # reply text
        # multi-turn dialogue state
        "multi_dialogue_state": {
            "status": None,  # 0: rewrite, 1: response
            "multi_intent_category_slot": {
                "multi_intent_category": "",  # intent label
                "multi_software_name_required": "",  # software name
                "multi_demand_name_required": "",  # requested content (demand)
                "multi_temp_optional": "",  # placeholder
            },
            "multi_domain_category": "",  # domain classification
            "question_type": "",  # 5W2H classification
            "retrieve_keywords": "",  # retrieval keywords
        },
        "rewrite_result": "",  # rewritten query
    }


def _answer_suffix_query(dst_data: Dict[str, Any], input_str: str) -> None:
    """Fill *dst_data* in place for a question about a file suffix (extension).

    The suffix is extracted with ``match_suffix`` and looked up with
    ``retrieve_relevant_software``:

    * An ``int`` result is recorded as "no relevant knowledge found".
    * A ``str`` or ``list`` result is given to ``chain_suffix_answers`` as
      knowledge; a list is newline-joined first.
    * Any other result type leaves ``rewrite_result`` and ``response`` empty.
    """
    dst_data["query"] = input_str
    dst_data["dialogue_status"] = 1
    dst_data["multi_dialogue_state"]["status"] = 1
    dst_data["vertical_category"] = "软件查询"
    dst_data["intent_category_slot"]["intent_category"] = "查询功能"
    dst_data["domain_category"] = "后缀名查询"
    suffix_name = match_suffix(input_str)
    dst_data["retrieve_keywords"] = suffix_name
    dst_data["question_type"] = ""
    suffix_to_software = retrieve_relevant_software(suffix_name)  # software linked to this suffix
    if isinstance(suffix_to_software, int):
        dst_data["rewrite_result"] = "未查到相关知识"
        return
    if isinstance(suffix_to_software, str):
        knowledge = suffix_to_software
    elif isinstance(suffix_to_software, list):
        knowledge = '\n'.join(suffix_to_software)
    else:
        return
    dst_data["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?"
    dst_data["response"] = chain_suffix_answers.invoke({"query": input_str, "kg": knowledge})


def _retrieve_and_rewrite(dst_data: Dict[str, Any], query: str, demand: Any) -> None:
    """Search the current domain's retrievers for *demand*, then rewrite *query*.

    The retrievers are the first three entries of
    ``domain_mapping[dst_data["domain_category"]]``.
    ``retrieve_keywords`` and ``rewrite_result`` are filled in place.
    """
    retrievers = domain_mapping[dst_data["domain_category"]][:3]
    index_keywords = interface_search(demand, *retrievers)
    dst_data["retrieve_keywords"] = index_keywords
    dst_data["rewrite_result"] = chain_retrieval_rewrite.invoke({
        "query": query,
        "question_type": dst_data["question_type"],
        "intention_type": dst_data["intent_category_slot"]["intent_category"],
        "keywords": index_keywords,
    })


@app.post("/analyze")
async def analyze_input(data: Dict[str, Any]):
    """Classify a user question and rewrite it for knowledge retrieval.

    Request keys read here are ``input_str`` (required) and, optionally,
    ``content.soft_info.software_name``.

    Flow:
      * File-suffix questions are answered directly (``_answer_suffix_query``).
      * Otherwise the input is stop-word filtered and synonym-normalized, then
        classified (vertical category, intent, 5W2H). Chit-chat ('闲聊') gets a
        fixed reply.
      * The software (domain) comes from the request, or is detected with
        ``chain_domain_v3``. With a known software and an extracted demand, the
        query is searched and rewritten (``dialogue_status`` 1). Otherwise the
        user is asked for the missing piece (``dialogue_status`` 2).

    Returns:
        The DST dict created by ``_new_dst_data`` and filled in here.

    Raises:
        HTTPException: 400 if the request names a software that is not a key of
            ``domain_mapping``.
    """
    # NOTE(review): every step below is a blocking call (spaCy, LLM chains,
    # vector search) inside an ``async def``, so one request stalls the event
    # loop for all others. A plain ``def`` would let FastAPI run it in its
    # threadpool — confirm the models are safe to call concurrently first.
    dst_data = _new_dst_data()
    input_str = data["input_str"]
    software_name = get_software_name(data)

    if judge_define_suffix(input_str) == True:  # exact-True comparison kept deliberately
        _answer_suffix_query(dst_data, input_str)
        return dst_data

    # Input pre-processing.
    input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words)  # stop-word removal
    input_str_syn = normalize_text(input_str_stoped, synonym_dict_2)  # synonym replacement
    dst_data["query"] = input_str_syn

    # Vertical classification runs on the input expanded with the software's full name.
    input_str_name_rewrite = chain_full_name_extension.invoke(
        {"query": input_str_syn, "soft_name": software_name})
    vertical_category = chain_vertical.invoke(input_str_name_rewrite)
    if vertical_category == '闲聊':
        dst_data["response"] = "闲聊服务只提供给内测用户"
        return dst_data

    dst_data["vertical_category"] = vertical_category
    dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(input_str_syn)
    dst_data["question_type"] = chain_5W2H.invoke(input_str_syn)

    if software_name:
        # Software supplied by the client: it must be one we have retrievers for.
        if software_name not in domain_mapping:
            from fastapi import HTTPException  # local import keeps this fix self-contained
            raise HTTPException(status_code=400,
                                detail=f"unsupported software_name: {software_name!r}")
        dst_data["dialogue_status"] = 1
        dst_data["domain_category"] = software_name
        dst_data["intent_category_slot"]["software_name_required"] = software_name
        _, demand = get_keywords_v4(input_str_name_rewrite)
        dst_data["intent_category_slot"]["demand_name_required"] = demand
        _retrieve_and_rewrite(dst_data, input_str_name_rewrite, demand)
        return dst_data

    # No software supplied: detect it from the raw input.
    detected_software = chain_domain_v3.invoke(input_str)
    dst_data["domain_category"] = detected_software
    # Any label without retrievers (including the classifier's '其他' / "other")
    # means the software is still unknown, so ask the user for it.
    if detected_software not in domain_mapping:
        dst_data["dialogue_status"] = 2
        dst_data["response"] = f"好的,具体需求是【{input_str_syn}】, 但当前缺少具体软件名字,请补充"
        return dst_data

    dst_data["dialogue_status"] = 1
    # The request named no software, so record (and expand with) the detected one.
    dst_data["intent_category_slot"]["software_name_required"] = detected_software
    input_str_name_rewrite = chain_full_name_extension.invoke(
        {"query": input_str_stoped, "soft_name": detected_software})
    _, demand = get_keywords_v4(input_str_name_rewrite)
    if demand:
        dst_data["intent_category_slot"]["demand_name_required"] = demand
        _retrieve_and_rewrite(dst_data, input_str_name_rewrite, demand)
    else:
        dst_data["dialogue_status"] = 2
        dst_data["response"] = f"好的,具体软件是【{detected_software}】, 请补充具体的需求"
    return dst_data
# else:
# # 如果 history_chat 有数据,进行其他处理
# _, B, _ = _chat_slice(data)
# if chains_mutil12.invoke(B) == '1':
# text = chat_record_conversion_demand(data)
# else:
# text = chat_record_conversion_software(data)
# software_name = chain_domain_v3.invoke(text)
# dst_data = {
# "query": "", # 用户输入
# "vertical_category": "", # 垂直分类
# # 意图分类槽位结构
# "intent_category_slot": {
# "intent_category": "", # 意图分类
# "software_name_required": "", # 软件名称
# "demand_name_required": "", # 需求内容
# "temp_optional": "" # 占位
# },
# "domain_category": "", # 领域分类
# "question_type": "", # 5W2H分类
# "retrieve_keywords": "", # 检索语义
# "dialogue_status": None, # 对话状态(单轮/多轮 1/2)
# "response": "", # 回复
# # 多轮对话状态
# "multi_dialogue_state": {
# "status": None, # 0: rewrite 1: response
# "multi_intent_category_slot": {
# "multi_intent_category": "", # 意图分类
# "multi_software_name_required": "", # 软件名称
# "multi_demand_name_required": "", # 需求内容
# "multi_temp_optional": "" # 占位
# },
# "multi_domain_category": "", # 领域分类
# "question_type": "", # 5W2H分类
# "retrieve_keywords": "" # 检索语义
# },
# "rewrite_result": "" # 改写结果
# }
# dst_data["query"] = text
# dst_data["dialogue_status"] = 1
# vertical_category = chain_vertical.invoke(dst_data["query"])
# if vertical_category == '闲聊':
# dst_data["response"] = "闲聊服务只提供给内测用户"
# else:
# vertical_category = chain_vertical.invoke(dst_data["query"])
# dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(dst_data["query"])
# dst_data["question_type"] = chain_5W2H.invoke(dst_data["query"])
# dst_data["domain_category"] = software_name
# dst_data["intent_category_slot"]["software_name_required"] = software_name
# input_str_name_rewrite = chain_full_name_extension.invoke({"query": dst_data["query"], "soft_name": software_name})
# _, demand = get_keywords_v4(input_str_name_rewrite)
# dst_data["intent_category_slot"]["demand_name_required"] = demand
# retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3],
# domain_mapping[dst_data["domain_category"]][3])
# index_keywords = interface_search(demand, *retrievers)
# dst_data["retrieve_keywords"] = index_keywords
# result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite,
# "question_type": dst_data["question_type"],
# "intention_type": dst_data["intent_category_slot"][
# "intent_category"],
# "keywords": index_keywords})
# dst_data["rewrite_result"] = result_rewrite
# return dst_data
# Optional entry point for local debugging: serve the app directly with uvicorn.
if __name__ == "__main__":
    # Listens on every interface (0.0.0.0), port 3334.
    server_options = {"host": "0.0.0.0", "port": 3334}
    uvicorn.run(app, **server_options)