c152fb8714
上传文件
429 lines
20 KiB
Python
429 lines
20 KiB
Python
"""
|
||
===================================
|
||
@Author: WenZ
|
||
@Company: BooWay
|
||
@project:booway_dm
|
||
===================================
|
||
"""
|
||
|
||
"""
|
||
===================================
|
||
@Author: WenZ
|
||
@Company: BooWay
|
||
@project:booway_dm
|
||
===================================
|
||
"""
|
||
from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
|
||
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \
|
||
ChatRecord
|
||
from utils import stop_word_processing
|
||
import spacy
|
||
import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf
|
||
from utils import get_keywords, get_keywords_v2, get_keywords_v3, get_keywords_v4
|
||
from vector_load import interface_search
|
||
|
||
# Alternative spaCy Chinese pipelines (sm / md / lg), currently disabled:
# nlp_sm = zh_core_web_sm.load()
# nlp_md = zh_core_web_md.load()
# nlp_lg = zh_core_web_lg.load()
# Transformer ("trf") Chinese pipeline, loaded once at import time and passed to
# stop_word_processing() during query preprocessing in analyze_input.
nlp_trf = zh_core_web_trf.load()

# Politeness words ("hello", "please", "thanks", "sorry to bother", ...) handed to
# stop_word_processing() together with nlp_trf; presumably removed from the query
# as stop words -- confirm in utils.stop_word_processing.
polite_words = {"你好", "您好", "请", "请问", "谢谢", "不客气", "麻烦", "打扰", "拜托", "辛苦", "劳驾"}
|
||
|
||
from chains_ceshi import suffix_answers
|
||
from chains_ceshi import Vertical_classification
|
||
from chains_ceshi import intention_judge
|
||
from chains_ceshi import domain_judge
|
||
from chains_ceshi import judge_5W2H
|
||
|
||
from chains_rewrite import software_name_rewrite
|
||
from chains_rewrite import query_function_rewrite
|
||
from chains_rewrite import operation_guidance_rewrite
|
||
from chains_rewrite import troubleshooting_rewrite
|
||
from chains_rewrite import access_rewrite
|
||
|
||
from kg_management import retriever_txt_faiss1
|
||
from kg_management import retriever_txt_faiss2
|
||
from kg_management import retriever_txt_faiss3
|
||
from kg_management import retriever_txt_faiss4
|
||
from kg_management import retriever_txt_faiss5
|
||
from kg_management import retriever_txt_faiss6
|
||
from kg_management import retriever_txt_faiss7
|
||
from kg_management import retriever_txt_faiss8
|
||
from kg_management import retriever_txt_faiss9
|
||
from kg_management import retriever_txt_faiss10
|
||
from kg_management import retriever_txt_faiss11
|
||
from kg_management import retriever_txt_faiss12
|
||
|
||
from kg_management import input_index_csv_path
|
||
from kg_management import xizang_input_csv_path
|
||
from kg_management import cuceng_input_csv_path
|
||
from kg_management import jigai_input_csv_path
|
||
from kg_management import peiwang_input_path
|
||
|
||
from kg_management import process_domain_category
|
||
|
||
# Software (domain) name -> (retriever, retriever, retriever, input CSV path).
# analyze_input takes the first three entries (the retrievers named *_faiss*)
# and hands them to interface_search for the requested software.
domain_mapping = dict([
    ('西藏造价软件Z1',
     (retriever_txt_faiss1, retriever_txt_faiss2, retriever_txt_faiss3, xizang_input_csv_path)),
    ('新型储能计价通C1',
     (retriever_txt_faiss4, retriever_txt_faiss5, retriever_txt_faiss6, cuceng_input_csv_path)),
    ('技改检修计价通T1',
     (retriever_txt_faiss7, retriever_txt_faiss8, retriever_txt_faiss9, jigai_input_csv_path)),
    ('博微配网工程计价通D3软件',
     (retriever_txt_faiss10, retriever_txt_faiss11, retriever_txt_faiss12, peiwang_input_path)),
])
|
||
|
||
# Chain objects built once at import time and shared by every request.
chain_suffix_answers = suffix_answers()  # answers file-suffix questions ("what opens a .xxx file?")
chain_vertical = Vertical_classification()  # vertical vs. open classification ('闲聊' = chit-chat)
chain_intention = intention_judge()  # intent classification
chain_domain = domain_judge()  # domain classification -- NOTE(review): not referenced elsewhere in this file
chain_5W2H = judge_5W2H()  # 5W2H question-type classification

# Query-rewrite chains per intent.
# NOTE(review): none of the five below is referenced elsewhere in this file.
chains_name_rewrite = software_name_rewrite()  # rewrite: software name
chain_function_rewrite = query_function_rewrite()  # rewrite: software feature lookup
chain_guidance_rewrite = operation_guidance_rewrite()  # rewrite: software operation guidance
chain_troubleshooting_rewrite = troubleshooting_rewrite()  # rewrite: software troubleshooting
chain_access_rewrite = access_rewrite()  # rewrite: software download & installation
|
||
|
||
from chains_rewrite import to_normal_rewrite
|
||
from chains_rewrite import retrieval_rewrite
|
||
|
||
# NOTE(review): chain_normal_rewrite is not referenced elsewhere in this file.
chain_normal_rewrite = to_normal_rewrite()
# Rewrites the user query for retrieval; analyze_input invokes it with the query,
# the 5W2H question type, the intent category and the retrieved index keywords.
chain_retrieval_rewrite = retrieval_rewrite()
|
||
|
||
from chains_rewrite import full_name_extension
|
||
|
||
# Content expansion: invoked with {"query": ..., "soft_name": ...} to expand the
# query using the software's full name before classification / keyword extraction.
chain_full_name_extension = full_name_extension()
|
||
|
||
from chains_ceshi import domain_judge_v3
|
||
|
||
# Detects the software (domain) from the raw input when the request names none;
# analyze_input treats the result '其他' ("other") as "no software identified".
chain_domain_v3 = domain_judge_v3()
|
||
|
||
# 同义词替换
|
||
from utils import normalize_text
|
||
|
||
# Canonical software name -> aliases (upper- and lower-case variants), passed to
# normalize_text() on transcript turns in chat_record_conversion_demand/_software.
# NOTE(review): "新型储能计价通" and "技改检修计价通" each appear twice in their lists,
# and short aliases such as "C1"/"z1" also occur inside the canonical names --
# confirm normalize_text handles duplicates and overlapping matches as intended.
synonym_dict = {
    "西藏造价软件Z1": ["西藏造价Z1", "Z1软件", "西藏Z1", "西藏工程造价Z1", "西藏造价通Z1", "Z1",
                "西藏造价z1", "z1软件", "西藏z1", "西藏工程造价z1", "西藏造价通z1", "z1",],
    "新型储能计价通C1": ["储能计价通C1", "C1储能计价", "新型储能C1", "储能计价C1", "新型储能计价通", "C1",
                  "储能计价通c1", "c1储能计价", "新型储能c1", "储能计价c1", "新型储能计价通", "c1"],
    "技改检修计价通T1": ["技改检修T1", "T1检修计价", "技改计价通T1", "技改检修计价通", "T1技改检修", "T1",
                  "技改检修t1", "t1检修计价", "技改计价通t1", "技改检修计价通", "t1技改检修", "t1"],
    "博微配网工程计价通D3软件": ["博微配网D3", "D3配网计价", "配网计价D3", "配网D3", "D3软件", "博微D3", "D3",
                      "博微配网工程计价通", "博微配网工程计价通D3",
                      "博微配网d3", "d3配网计价", "配网计价d3", "配网d3", "d3软件", "博微d3", "d3"]
}
|
||
|
||
# Domain-term synonyms applied by normalize_text() to every /analyze query after
# stop-word removal; each key is the canonical term and its list holds variants
# that are presumably rewritten to it (confirm in utils.normalize_text).
synonym_dict_2 = dict([
    ("费率", ["费费率"]),
    ("下载", ["获取", "安装", "下载下来", "装上"]),
])
|
||
|
||
from typing import Optional
|
||
|
||
# Reference layout of the dialogue-state-tracking (DST) result returned by
# /analyze. NOTE(review): this module-level object appears unused in this file --
# the handler builds its own fresh dict per request.
dst_data = dict(
    query="",  # user input
    vertical_category="",  # vertical classification
    # intent-classification slots
    intent_category_slot=dict(
        intent_category="",  # intent category
        software_name_required="",  # software name
        demand_name_required="",  # requested content (demand)
        temp_optional="",  # placeholder
    ),
    domain_category="",  # domain classification
    question_type="",  # 5W2H classification
    retrieve_keywords="",  # retrieval keywords
    dialogue_status=None,  # dialogue status (single-turn / multi-turn: 1 / 2)
    response="",  # reply
    # multi-turn dialogue state
    multi_dialogue_state=dict(
        status=None,  # 0: rewrite, 1: response
        multi_intent_category_slot=dict(
            multi_intent_category="",  # intent category
            multi_software_name_required="",  # software name
            multi_demand_name_required="",  # requested content (demand)
            multi_temp_optional="",  # placeholder
        ),
        multi_domain_category="",  # domain classification
        question_type="",  # 5W2H classification
        retrieve_keywords="",  # retrieval keywords
    ),
    rewrite_result="",  # rewrite result
)
|
||
|
||
def get_software_name(input_json):
    """Return ``input_json["content"]["soft_info"]["software_name"]``, or None if unavailable.

    Args:
        input_json: The decoded request body.

    Returns:
        The software name supplied by the client, or None when any level of the
        path is missing or is not a mapping (e.g. ``"soft_info": null`` in the
        JSON request).
    """
    try:
        return input_json["content"]["soft_info"]["software_name"]
    except (KeyError, TypeError):
        # KeyError: a key along the path is missing.
        # TypeError: an intermediate value is None / not subscriptable by a str key
        # (e.g. null or a list); previously this propagated out of the handler.
        return None
|
||
|
||
def chat_record_conversion_demand(data):
    """Rewrite a (user, assistant, user) exchange into a single query.

    The *current* user input is synonym-normalized (``synonym_dict``) and run
    through ``chain_software_judge``; the previous user turn and assistant reply
    are used as-is. The three turns are formatted as a ``user:/assistant:/user:``
    transcript and the result of ``chain_mutil_text_rewrite`` is returned.

    In this file it is only referenced from the commented-out multi-turn branch
    of analyze_input.
    """
    prev_user, prev_assistant, current = _chat_slice(data)
    normalized = normalize_text(current, synonym_dict)
    current = chain_software_judge.invoke(normalized)

    transcript = f"user: {prev_user}\nassistant: {prev_assistant}\nuser: {current}"
    return chain_mutil_text_rewrite.invoke(transcript)
|
||
|
||
def _chat_slice(data):
|
||
history_chat = data["content"].get("history_chat")
|
||
input_str = data["input_str"]
|
||
|
||
A = history_chat[0].get("user", "")
|
||
B = history_chat[1].get("assistant", "")
|
||
C = input_str
|
||
return A, B, C
|
||
|
||
def chat_record_conversion_software(data):
    """Rewrite a (user, assistant, user) exchange into a single query.

    Unlike chat_record_conversion_demand, here the *previous* user turn is
    synonym-normalized (``synonym_dict``) and run through ``chain_software_judge``;
    the assistant reply and the current input are used as-is. The three turns are
    formatted as a ``user:/assistant:/user:`` transcript and the result of
    ``chain_mutil_text_rewrite`` is returned.

    In this file it is only referenced from the commented-out multi-turn branch
    of analyze_input.
    """
    prev_user, prev_assistant, current = _chat_slice(data)
    normalized = normalize_text(prev_user, synonym_dict)
    prev_user = chain_software_judge.invoke(normalized)

    transcript = f"user: {prev_user}\nassistant: {prev_assistant}\nuser: {current}"
    return chain_mutil_text_rewrite.invoke(transcript)
|
||
|
||
from chains_rewrite import software_judge
|
||
# Applied to a synonym-normalized transcript turn in chat_record_conversion_demand /
# chat_record_conversion_software (defined above; the name resolves at call time).
chain_software_judge = software_judge()
|
||
from chains_rewrite import mutil_text_rewrite
|
||
# Invoked with the formatted "user: ...\nassistant: ...\nuser: ..." transcript by
# chat_record_conversion_demand / chat_record_conversion_software; its output is
# returned by those functions.
chain_mutil_text_rewrite = mutil_text_rewrite()
|
||
|
||
from chains_ceshi import mutil12
|
||
|
||
# NOTE(review): only referenced by the commented-out multi-turn branch of
# analyze_input, where an output of '1' on the assistant reply selects
# chat_record_conversion_demand (otherwise chat_record_conversion_software).
chains_mutil12 = mutil12()
|
||
|
||
|
||
from fastapi import FastAPI, Request
|
||
from pydantic import BaseModel
|
||
from typing import List, Union
|
||
import uvicorn
|
||
|
||
# FastAPI application; the POST /analyze route is registered on it below.
app = FastAPI()
|
||
|
||
from typing import Dict, List, Optional, Any
|
||
|
||
|
||
def _new_dst_data() -> Dict[str, Any]:
    """Build a fresh, empty dialogue-state-tracking (DST) result for one request.

    A new dict is created on every call so nested slots are never shared
    between requests.
    """
    return {
        "query": "",  # user input (after preprocessing)
        "vertical_category": "",  # vertical classification
        # intent-classification slots
        "intent_category_slot": {
            "intent_category": "",  # intent category
            "software_name_required": "",  # software name
            "demand_name_required": "",  # requested content (demand)
            "temp_optional": "",  # placeholder
        },
        "domain_category": "",  # domain (software) classification
        "question_type": "",  # 5W2H classification
        "retrieve_keywords": "",  # retrieval keywords
        "dialogue_status": None,  # dialogue status (single-turn / multi-turn: 1 / 2)
        "response": "",  # reply
        # multi-turn dialogue state
        "multi_dialogue_state": {
            "status": None,  # 0: rewrite, 1: response
            "multi_intent_category_slot": {
                "multi_intent_category": "",  # intent category
                "multi_software_name_required": "",  # software name
                "multi_demand_name_required": "",  # requested content (demand)
                "multi_temp_optional": "",  # placeholder
            },
            "multi_domain_category": "",  # domain classification
            "question_type": "",  # 5W2H classification
            "retrieve_keywords": "",  # retrieval keywords
        },
        "rewrite_result": "",  # rewrite result
    }


def _apply_retrieval_rewrite(dst_data: Dict[str, Any], query, demand) -> None:
    """Search the current domain's indexes for *demand* and rewrite *query* for retrieval.

    Reads ``dst_data["domain_category"]``, ``dst_data["question_type"]`` and the
    intent category; writes ``dst_data["retrieve_keywords"]`` and
    ``dst_data["rewrite_result"]``.

    Raises:
        KeyError: if ``dst_data["domain_category"]`` is not a key of domain_mapping.
    """
    # domain_mapping values are (retriever, retriever, retriever, csv_path); only
    # the three retrievers are needed here.
    retrievers = domain_mapping[dst_data["domain_category"]][:3]
    index_keywords = interface_search(demand, *retrievers)
    dst_data["retrieve_keywords"] = index_keywords
    dst_data["rewrite_result"] = chain_retrieval_rewrite.invoke({
        "query": query,
        "question_type": dst_data["question_type"],
        "intention_type": dst_data["intent_category_slot"]["intent_category"],
        "keywords": index_keywords,
    })


@app.post("/analyze")
async def analyze_input(data: Dict[str, Any]):
    """Classify a single-turn user query and prepare it for retrieval.

    Reads ``data["input_str"]`` (required) and, optionally,
    ``data["content"]["soft_info"]["software_name"]``.

    Returns the DST dict from ``_new_dst_data`` filled by one of these paths:
      * file-suffix question -> answered directly via chain_suffix_answers;
      * chit-chat ('闲聊') -> fixed message only;
      * software known (from the request, or detected by chain_domain_v3) ->
        demand extraction, index-keyword retrieval and query rewrite;
      * software detected but no demand, or no software at all ->
        ``dialogue_status = 2`` with a reply asking the user for the missing part.
    """
    # The multi-turn (history_chat) branch is currently disabled; see the
    # commented-out code that follows this function.
    # NOTE(review): every .invoke()/model call below is synchronous; inside an
    # ``async def`` route they block the event loop for the whole request. A plain
    # ``def`` route would run in FastAPI's threadpool -- confirm before switching.
    dst_data = _new_dst_data()
    input_str = data["input_str"]
    software_name = get_software_name(data)

    # --- File-suffix questions ---
    if judge_define_suffix(input_str) == True:
        dst_data["query"] = input_str
        dst_data["dialogue_status"] = 1
        dst_data["multi_dialogue_state"]["status"] = 1
        dst_data["vertical_category"] = "软件查询"
        dst_data["intent_category_slot"]["intent_category"] = "查询功能"
        dst_data["domain_category"] = "后缀名查询"
        suffix_name = match_suffix(input_str)
        dst_data["retrieve_keywords"] = suffix_name
        dst_data["question_type"] = ""
        # Software that opens this suffix: an int means nothing was found,
        # a str is a single hit, a list holds several hits.
        suffix_to_software = retrieve_relevant_software(dst_data["retrieve_keywords"])
        if isinstance(suffix_to_software, int):
            dst_data["rewrite_result"] = "未查到相关知识"
        elif isinstance(suffix_to_software, (str, list)):
            kg = (suffix_to_software if isinstance(suffix_to_software, str)
                  else '\n'.join(suffix_to_software))
            dst_data["rewrite_result"] = f"{suffix_name}是什么文件?用什么软件打开?"
            dst_data["response"] = chain_suffix_answers.invoke({"query": input_str, "kg": kg})
        return dst_data

    # --- Preprocessing ---
    input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words)  # stop-word removal
    input_str_syn = normalize_text(input_str_stoped, synonym_dict_2)  # synonym replacement
    dst_data["query"] = input_str_syn

    # --- Vertical classification, on the query expanded with the software's full name ---
    input_str_name_rewrite = chain_full_name_extension.invoke(
        {"query": input_str_syn, "soft_name": software_name})
    vertical_category = chain_vertical.invoke(input_str_name_rewrite)
    if vertical_category == '闲聊':
        dst_data["response"] = "闲聊服务只提供给内测用户"
        return dst_data

    dst_data["vertical_category"] = vertical_category
    dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(input_str_syn)
    dst_data["question_type"] = chain_5W2H.invoke(input_str_syn)

    # --- Software supplied by the client ---
    if software_name:
        # NOTE(review): a software_name that is not a key of domain_mapping raises
        # KeyError in _apply_retrieval_rewrite.
        dst_data["dialogue_status"] = 1
        dst_data["domain_category"] = software_name
        dst_data["intent_category_slot"]["software_name_required"] = software_name
        _, demand = get_keywords_v4(input_str_name_rewrite)
        dst_data["intent_category_slot"]["demand_name_required"] = demand
        _apply_retrieval_rewrite(dst_data, input_str_name_rewrite, demand)
        return dst_data

    # --- No software in the request: detect it from the raw input ---
    detected_software = chain_domain_v3.invoke(input_str)
    dst_data["domain_category"] = detected_software
    if detected_software == '其他':
        dst_data["dialogue_status"] = 2
        dst_data["response"] = f"好的,具体需求是【{input_str_syn}】, 但当前缺少具体软件名字,请补充"
        return dst_data

    dst_data["dialogue_status"] = 1
    # Use the *detected* software here: software_name is falsy on this path, so
    # storing/forwarding it recorded None and replied with "【None】".
    dst_data["intent_category_slot"]["software_name_required"] = detected_software
    # Content expansion (software full name).
    input_str_name_rewrite = chain_full_name_extension.invoke(
        {"query": input_str_stoped, "soft_name": detected_software})
    _, demand = get_keywords_v4(input_str_name_rewrite)
    if demand:
        dst_data["intent_category_slot"]["demand_name_required"] = demand
        _apply_retrieval_rewrite(dst_data, input_str_name_rewrite, demand)
    else:
        dst_data["dialogue_status"] = 2
        dst_data["response"] = f"好的,具体软件是【{detected_software}】, 请补充具体的需求"
    return dst_data
|
||
# else:
|
||
# # 如果 history_chat 有数据,进行其他处理
|
||
# _, B, _ = _chat_slice(data)
|
||
# if chains_mutil12.invoke(B) == '1':
|
||
# text = chat_record_conversion_demand(data)
|
||
# else:
|
||
# text = chat_record_conversion_software(data)
|
||
|
||
# software_name = chain_domain_v3.invoke(text)
|
||
|
||
# dst_data = {
|
||
# "query": "", # 用户输入
|
||
# "vertical_category": "", # 垂直分类
|
||
|
||
# # 意图分类槽位结构
|
||
# "intent_category_slot": {
|
||
# "intent_category": "", # 意图分类
|
||
# "software_name_required": "", # 软件名称
|
||
# "demand_name_required": "", # 需求内容
|
||
# "temp_optional": "" # 占位
|
||
# },
|
||
|
||
# "domain_category": "", # 领域分类
|
||
# "question_type": "", # 5W2H分类
|
||
# "retrieve_keywords": "", # 检索语义
|
||
# "dialogue_status": None, # 对话状态(单轮/多轮 1/2)
|
||
|
||
# "response": "", # 回复
|
||
# # 多轮对话状态
|
||
# "multi_dialogue_state": {
|
||
# "status": None, # 0: rewrite 1: response
|
||
# "multi_intent_category_slot": {
|
||
# "multi_intent_category": "", # 意图分类
|
||
# "multi_software_name_required": "", # 软件名称
|
||
# "multi_demand_name_required": "", # 需求内容
|
||
# "multi_temp_optional": "" # 占位
|
||
# },
|
||
# "multi_domain_category": "", # 领域分类
|
||
# "question_type": "", # 5W2H分类
|
||
# "retrieve_keywords": "" # 检索语义
|
||
# },
|
||
|
||
# "rewrite_result": "" # 改写结果
|
||
# }
|
||
|
||
# dst_data["query"] = text
|
||
# dst_data["dialogue_status"] = 1
|
||
# vertical_category = chain_vertical.invoke(dst_data["query"])
|
||
# if vertical_category == '闲聊':
|
||
# dst_data["response"] = "闲聊服务只提供给内测用户"
|
||
# else:
|
||
# vertical_category = chain_vertical.invoke(dst_data["query"])
|
||
# dst_data["intent_category_slot"]["intent_category"] = chain_intention.invoke(dst_data["query"])
|
||
# dst_data["question_type"] = chain_5W2H.invoke(dst_data["query"])
|
||
# dst_data["domain_category"] = software_name
|
||
# dst_data["intent_category_slot"]["software_name_required"] = software_name
|
||
# input_str_name_rewrite = chain_full_name_extension.invoke({"query": dst_data["query"], "soft_name": software_name})
|
||
# _, demand = get_keywords_v4(input_str_name_rewrite)
|
||
# dst_data["intent_category_slot"]["demand_name_required"] = demand
|
||
# retrievers, input_csv_path = (domain_mapping[dst_data["domain_category"]][:3],
|
||
# domain_mapping[dst_data["domain_category"]][3])
|
||
# index_keywords = interface_search(demand, *retrievers)
|
||
# dst_data["retrieve_keywords"] = index_keywords
|
||
# result_rewrite = chain_retrieval_rewrite.invoke({"query": input_str_name_rewrite,
|
||
# "question_type": dst_data["question_type"],
|
||
# "intention_type": dst_data["intent_category_slot"][
|
||
# "intent_category"],
|
||
# "keywords": index_keywords})
|
||
# dst_data["rewrite_result"] = result_rewrite
|
||
# return dst_data
|
||
|
||
if __name__ == "__main__":
    # Optional: run the service directly for local debugging
    # (listens on all interfaces, port 3334).
    uvicorn.run(app, port=3334, host="0.0.0.0")
|
||
|