d0261c5997
4.3更新
166 lines
6.4 KiB
Python
166 lines
6.4 KiB
Python
"""
|
||
===================================
|
||
@Author:WenZ
|
||
@Company: BooWay
|
||
@project:booway_dm
|
||
===================================
|
||
"""
|
||
from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
|
||
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \
|
||
ChatRecord
|
||
from utils import stop_word_processing
|
||
from utils import get_keywords, get_keywords_v2, get_keywords_v3
|
||
from vector_load import interface_search
|
||
|
||
import spacy
|
||
import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf
|
||
|
||
# spaCy Chinese pipeline handed to stop_word_processing in the /analyze handler.
# Only the ``trf`` package is loaded; zh_core_web_sm / _md / _lg are imported
# at the top of the file and can be swapped in here (e.g. zh_core_web_sm.load()).
nlp_trf = zh_core_web_trf.load()

# Courtesy words passed to stop_word_processing together with the model above
# (presumably removed from the user's question before classification).
polite_words = {
    "你好", "您好", "请", "请问", "谢谢", "不客气",
    "麻烦", "打扰", "拜托", "辛苦", "劳驾",
}
|
||
|
||
from chains_ceshi import suffix_answers
|
||
from chains_ceshi import Vertical_classification
|
||
from chains_ceshi import intention_judge
|
||
from chains_ceshi import domain_judge
|
||
from chains_ceshi import judge_5W2H
|
||
|
||
from chains_rewrite import software_name_rewrite
|
||
from chains_rewrite import query_function_rewrite
|
||
from chains_rewrite import operation_guidance_rewrite
|
||
from chains_rewrite import troubleshooting_rewrite
|
||
from chains_rewrite import access_rewrite
|
||
|
||
from kg_management import retriever_txt_faiss1
|
||
from kg_management import retriever_txt_faiss2
|
||
from kg_management import retriever_txt_faiss3
|
||
from kg_management import retriever_txt_faiss4
|
||
from kg_management import retriever_txt_faiss5
|
||
from kg_management import retriever_txt_faiss6
|
||
from kg_management import retriever_txt_faiss7
|
||
from kg_management import retriever_txt_faiss8
|
||
from kg_management import retriever_txt_faiss9
|
||
|
||
from kg_management import input_index_csv_path
|
||
from kg_management import xizang_input_csv_path
|
||
from kg_management import cuceng_input_csv_path
|
||
from kg_management import jigai_input_csv_path
|
||
|
||
from kg_management import process_domain_category
|
||
|
||
# Software (domain) name -> data used to answer questions about that software:
#     (retriever, retriever, retriever, input_csv_path)
# Keys must match what the domain classifier (``chain_domain``) returns.
# The first three items are unpacked into ``interface_search``; the fourth is
# the input CSV path configured for that software in kg_management.
domain_mapping = {
    '西藏造价软件Z1': (
        retriever_txt_faiss1,
        retriever_txt_faiss2,
        retriever_txt_faiss3,
        xizang_input_csv_path,
    ),
    '新型储能计价通C1': (
        retriever_txt_faiss4,
        retriever_txt_faiss5,
        retriever_txt_faiss6,
        cuceng_input_csv_path,
    ),
    '技改检修计价通T1': (
        retriever_txt_faiss7,
        retriever_txt_faiss8,
        retriever_txt_faiss9,
        jigai_input_csv_path,
    ),
}
|
||
|
||
# --- Classification chains (instantiated once at import time) ---------------
chain_suffix_answers = suffix_answers()       # answers file-extension questions
chain_vertical = Vertical_classification()    # vertical vs. open-domain split
chain_intention = intention_judge()           # intent classification
chain_domain = domain_judge()                 # domain (software) classification
chain_5W2H = judge_5W2H()                     # 5W2H question-type classification

# --- Query-rewrite chains ---------------------------------------------------
# NOTE(review): chain_vertical and the function/guidance/troubleshooting/access
# rewrite chains are not used by the visible /analyze handler.
chains_name_rewrite = software_name_rewrite()              # rewrite: software name
chain_function_rewrite = query_function_rewrite()          # rewrite: feature lookup
chain_guidance_rewrite = operation_guidance_rewrite()      # rewrite: operation guidance
chain_troubleshooting_rewrite = troubleshooting_rewrite()  # rewrite: troubleshooting
chain_access_rewrite = access_rewrite()                    # rewrite: download & install
|
||
|
||
from chains_rewrite import to_normal_rewrite
|
||
from chains_rewrite import retrieval_rewrite
|
||
|
||
# Second rewrite step in /analyze: applied to the output of the
# software-name rewrite chain.
chain_normal_rewrite = to_normal_rewrite()

# Final retrieval-oriented rewrite in /analyze; invoked with the rewritten
# query, its 5W2H type, the intent and the keywords from interface_search.
chain_retrieval_rewrite = retrieval_rewrite()
|
||
|
||
|
||
from utils import normalize_text
|
||
|
||
# Canonical term -> list of variants. Passed to ``normalize_text`` after
# stop-word processing in the /analyze handler (presumably each variant is
# replaced by its canonical key — confirm against utils.normalize_text).
synonym_dict = {
    "费率": [
        "费费率",
    ],
    "下载": [
        "获取",
        "安装",
        "下载下来",
        "装上",
    ],
}
|
||
|
||
from fastapi import FastAPI, Request
|
||
from pydantic import BaseModel
|
||
from typing import List, Union
|
||
import uvicorn
|
||
|
||
# FastAPI application; the /analyze endpoint below is registered on it.
app = FastAPI()


class QueryRequest(BaseModel):
    # Request body for POST /analyze. (Documented with comments rather than a
    # docstring, because pydantic would copy a docstring into the schema.)
    input_str: str  # raw user question; the handler strips surrounding whitespace
|
||
|
||
def _answer_suffix_query(input_str):
    """Answer a question about a file extension.

    Looks up the software associated with the extension found in
    ``input_str`` and lets ``chain_suffix_answers`` phrase the answer.

    Args:
        input_str: The stripped user question, already accepted by
            ``judge_define_suffix``.

    Returns:
        ``{"response": ...}`` when ``retrieve_relevant_software`` returns an
        int (treated as "nothing found"); otherwise ``{"response": ...,
        "nlu_info": NLUInfo}``.
    """
    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.intent_category = "查询功能"
    nlu_info.domain_category = "后缀名查询"

    suffix_name = match_suffix(input_str)
    nlu_info.retrieve_keywords = suffix_name
    suffix_to_software = retrieve_relevant_software(suffix_name)

    # An int return value is treated as a failed lookup.
    if isinstance(suffix_to_software, int):
        return {"response": "booway助手:未查到相关知识"}

    nlu_info.rewrite = f"{suffix_name}是什么文件?用什么软件打开?"

    # The answer chain receives the knowledge as a single string.
    if isinstance(suffix_to_software, list):
        query_kg = '\n'.join(suffix_to_software)
    else:
        query_kg = suffix_to_software

    result = chain_suffix_answers.invoke({"query": input_str, "kg": query_kg})
    return {"response": f"booway助手:{result}", "nlu_info": nlu_info}


def _analyze_software_query(input_str):
    """Classify and rewrite a software question and extract retrieval keywords.

    The pipeline runs in this order:
    1. Stop-word and synonym normalisation, used only for domain
       classification.
    2. Domain classification.
    3. Software-name rewrite, then normal rewrite.
    4. 5W2H and intent classification.
    5. Keyword search against the domain's retrievers.
    6. A final retrieval rewrite, from which the retrieval keywords are taken.

    Args:
        input_str: The stripped user question.

    Returns:
        ``{"response": ...}`` when the domain is "未知" or has no entry in
        ``domain_mapping``; otherwise ``{"response": "success", "nlu_info":
        NLUInfo}``.
    """
    # TODO: multi-turn dialogue handling.
    # Adjust the stop_word_processing arguments to its real signature if needed.
    input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words)
    input_str_syn = normalize_text(input_str_stoped, synonym_dict)

    # This is the *domain* (software) classifier; its result both gates
    # open-domain chat and becomes the NLU domain category.
    domain_category = chain_domain.invoke(input_str_syn)
    if domain_category == "未知":
        return {"response": "booway助手:闲聊服务只提供给内测用户"}

    # Fail fast instead of raising KeyError (HTTP 500) later, when the
    # classifier returns a software name with no configured retrievers.
    domain_entry = domain_mapping.get(domain_category)
    if domain_entry is None:
        return {"response": "booway助手:未查到相关知识"}
    retrievers = domain_entry[:3]  # entry[3] is the input CSV path (unused here)

    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.domain_category = domain_category

    # The rewrite chains work on the original wording, not the normalised text.
    input_str_name_rewrite = chains_name_rewrite.invoke({
        "query": input_str,
        "software_name": nlu_info.domain_category,
    })
    input_str_rewrite = chain_normal_rewrite.invoke(input_str_name_rewrite)

    query_keywords = get_keywords_v2(input_str_rewrite)
    nlu_info.question_type = chain_5W2H.invoke(input_str_rewrite)
    # NOTE(review): intent is classified on the raw input, while the 5W2H type
    # uses the rewritten query — confirm this asymmetry is intended.
    nlu_info.intent_category = chain_intention.invoke(input_str)

    index_keywords = interface_search(query_keywords, *retrievers)

    nlu_info.rewrite = chain_retrieval_rewrite.invoke({
        "query": input_str_rewrite,
        "question_type": nlu_info.question_type,
        "intention_type": nlu_info.intent_category,
        "keywords": index_keywords,
    })
    nlu_info.retrieve_keywords = get_keywords_v3(nlu_info.rewrite)

    return {"response": "success", "nlu_info": nlu_info}


@app.post("/analyze")
async def analyze_input(data: QueryRequest):
    """Run NLU analysis on a user question.

    File-extension questions (per ``judge_define_suffix``) are answered
    directly. All other questions are classified and rewritten for
    downstream retrieval.

    Args:
        data: Request body; ``data.input_str`` is the user question.

    Returns:
        A dict with a ``"response"`` string and, on success, the populated
        ``NLUInfo`` under ``"nlu_info"``.
    """
    # NOTE(review): the chain ``.invoke`` and spaCy calls in the helpers are
    # not awaited, so they run synchronously on the event loop and block other
    # requests. Consider a plain ``def`` endpoint or a thread pool, after
    # confirming the chains and models are safe to call concurrently.
    input_str = data.input_str.strip()

    if judge_define_suffix(input_str):
        return _answer_suffix_query(input_str)
    return _analyze_software_query(input_str)
|
||
|
||
if __name__ == "__main__":
    # Optional local debugging entry point: serve the app with uvicorn.
    # NOTE(review): host "0.0.0.0" listens on every interface, not only
    # localhost — confirm that is intended for a debugging entry point.
    uvicorn.run(app, port=3333, host="0.0.0.0")
|