Files
DM_rewrite_3.31/fast_api_main.py
2025-04-03 17:23:37 +08:00

166 lines
6.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
===================================
@Author: WenZ
@Company: BooWay
@Project: booway_dm
===================================
"""
from utils import judge_define_suffix, match_suffix, retrieve_relevant_software
from dialogue_management import QuestionInfo, DialogType, DialogInfo, QuestionType, NLUInfo, ScenInfo, TalkInfo, \
ChatRecord
from utils import stop_word_processing
from utils import get_keywords, get_keywords_v2, get_keywords_v3
from vector_load import interface_search
import spacy
import zh_core_web_sm, zh_core_web_md, zh_core_web_lg, zh_core_web_trf
# Alternative spaCy Chinese pipelines, left disabled; only the `trf` pipeline
# below is actually loaded.
# nlp_sm = zh_core_web_sm.load()
# nlp_md = zh_core_web_md.load()
# nlp_lg = zh_core_web_lg.load()
# Loaded once at import time; analyze_input passes this pipeline to
# stop_word_processing.
nlp_trf = zh_core_web_trf.load()
# Courtesy words handed to stop_word_processing (third argument) when the
# request text is pre-processed in analyze_input.
polite_words = {
    "你好",
    "您好",
    "请",
    "请问",
    "谢谢",
    "不客气",
    "麻烦",
    "打扰",
    "拜托",
    "辛苦",
    "劳驾",
}
from chains_ceshi import suffix_answers
from chains_ceshi import Vertical_classification
from chains_ceshi import intention_judge
from chains_ceshi import domain_judge
from chains_ceshi import judge_5W2H
from chains_rewrite import software_name_rewrite
from chains_rewrite import query_function_rewrite
from chains_rewrite import operation_guidance_rewrite
from chains_rewrite import troubleshooting_rewrite
from chains_rewrite import access_rewrite
from kg_management import retriever_txt_faiss1
from kg_management import retriever_txt_faiss2
from kg_management import retriever_txt_faiss3
from kg_management import retriever_txt_faiss4
from kg_management import retriever_txt_faiss5
from kg_management import retriever_txt_faiss6
from kg_management import retriever_txt_faiss7
from kg_management import retriever_txt_faiss8
from kg_management import retriever_txt_faiss9
from kg_management import input_index_csv_path
from kg_management import xizang_input_csv_path
from kg_management import cuceng_input_csv_path
from kg_management import jigai_input_csv_path
from kg_management import process_domain_category
# Per-domain retrieval resources, keyed by the domain label that
# analyze_input stores in nlu_info.domain_category.
# Each value is a 4-tuple: three retrievers (passed together to
# interface_search) followed by that domain's input CSV path.
domain_mapping = {
    "西藏造价软件Z1": (
        retriever_txt_faiss1,
        retriever_txt_faiss2,
        retriever_txt_faiss3,
        xizang_input_csv_path,
    ),
    "新型储能计价通C1": (
        retriever_txt_faiss4,
        retriever_txt_faiss5,
        retriever_txt_faiss6,
        cuceng_input_csv_path,
    ),
    "技改检修计价通T1": (
        retriever_txt_faiss7,
        retriever_txt_faiss8,
        retriever_txt_faiss9,
        jigai_input_csv_path,
    ),
}
# Chain objects are built once at import time and reused by every request.
# NOTE(review): chain_vertical and the function/guidance/troubleshooting/access
# rewrite chains are not referenced in the visible part of this file — confirm
# whether they are still needed.
chain_suffix_answers = suffix_answers() # answers file-suffix questions
chain_vertical = Vertical_classification() # vertical / open-domain classification
chain_intention = intention_judge() # intent classification
chain_domain = domain_judge() # domain classification
chain_5W2H = judge_5W2H() # 5W2H question-type classification
chains_name_rewrite = software_name_rewrite() # query rewrite: software-name rewrite
chain_function_rewrite = query_function_rewrite() # query rewrite: software feature query
chain_guidance_rewrite = operation_guidance_rewrite() # query rewrite: software operation guidance
chain_troubleshooting_rewrite = troubleshooting_rewrite() # query rewrite: software troubleshooting
chain_access_rewrite = access_rewrite() # query rewrite: software download and installation
from chains_rewrite import to_normal_rewrite
from chains_rewrite import retrieval_rewrite
# Rewrite chains used by analyze_input: chain_normal_rewrite is applied to the
# software-name-rewritten query, and chain_retrieval_rewrite produces the final
# rewrite stored in nlu_info.rewrite.
chain_normal_rewrite = to_normal_rewrite()
chain_retrieval_rewrite = retrieval_rewrite()
from utils import normalize_text
# Synonym table passed to normalize_text in analyze_input.
# NOTE(review): presumably maps a canonical term to the variants that should be
# normalized to it — confirm against utils.normalize_text.
synonym_dict = {
    "费率": [
        "费费率",
    ],
    "下载": [
        "获取",
        "安装",
        "下载下来",
        "装上",
    ],
}
from fastapi import FastAPI, Request
from pydantic import BaseModel
from typing import List, Union
import uvicorn
# ASGI application object; routes are registered on it via decorators and it is
# served by uvicorn in the __main__ guard at the bottom of the file.
app = FastAPI()
class QueryRequest(BaseModel):
    """Request body accepted by the ``/analyze`` endpoint."""

    # Raw user text; analyze_input strips surrounding whitespace before use.
    input_str: str
@app.post("/analyze")
async def analyze_input(data: QueryRequest):
    """Run the NLU pipeline on one user utterance.

    Two paths are taken, selected by ``judge_define_suffix``:

    * File-suffix questions: the suffix is extracted, the related software is
      looked up, and ``chain_suffix_answers`` produces the answer text.
    * All other questions: the text is cleaned, classified into a domain,
      rewritten, and enriched with retrieval keywords.

    Args:
        data: Request body; ``data.input_str`` is the raw user text.

    Returns:
        A dict that always has a ``"response"`` key. On the successful paths
        it also carries ``"nlu_info"`` (the populated ``NLUInfo``); the
        fallback paths (nothing found, unknown or unsupported domain) return
        only the ``"response"`` message.
    """
    # NOTE(review): the chain `.invoke(...)` calls look synchronous; if they
    # block, they will stall the event loop inside this `async def` — confirm
    # and consider a plain `def` endpoint or a threadpool.
    input_str = data.input_str.strip()

    if judge_define_suffix(input_str):
        nlu_info = NLUInfo(vertical_category="软件咨询")
        nlu_info.intent_category = "查询功能"
        nlu_info.domain_category = "后缀名查询"

        suffix_name = match_suffix(input_str)
        nlu_info.retrieve_keywords = suffix_name

        suffix_to_software = retrieve_relevant_software(suffix_name)
        # An int result is the lookup's "nothing found" signal.
        if isinstance(suffix_to_software, int):
            return {"response": "booway助手:未查到相关知识"}

        nlu_info.rewrite = f"{suffix_name}是什么文件?用什么软件打开?"

        # The knowledge snippet may be a list of entries or a single value.
        if isinstance(suffix_to_software, list):
            query_kg = '\n'.join(suffix_to_software)
        else:
            query_kg = suffix_to_software

        result = chain_suffix_answers.invoke({"query": input_str, "kg": query_kg})
        return {"response": f"booway助手:{result}",
                "nlu_info": nlu_info}

    # TODO: multi-turn dialogue handling
    input_str_stoped = stop_word_processing(input_str, nlp_trf, polite_words)
    input_str_syn = normalize_text(input_str_stoped, synonym_dict)

    vertical_category = chain_domain.invoke(input_str_syn)
    if vertical_category == "未知":
        return {"response": "booway助手:闲聊服务只提供给内测用户"}

    # The classifier output is not guaranteed to be one of the configured
    # domains; look it up defensively (and before the remaining chain calls)
    # instead of letting a KeyError surface as an HTTP 500.
    domain_resources = domain_mapping.get(vertical_category)
    if domain_resources is None:
        return {"response": "booway助手:未查到相关知识"}
    # First three entries are the retrievers; the fourth (CSV path) is not
    # needed here.
    retrievers = domain_resources[:3]

    nlu_info = NLUInfo(vertical_category="软件咨询")
    nlu_info.domain_category = vertical_category

    input_str_name_rewrite = chains_name_rewrite.invoke({
        "query": input_str,
        "software_name": nlu_info.domain_category
    })
    input_str_rewrite = chain_normal_rewrite.invoke(input_str_name_rewrite)
    temp_retriever = get_keywords_v2(input_str_rewrite)

    nlu_info.question_type = chain_5W2H.invoke(input_str_rewrite)
    nlu_info.intent_category = chain_intention.invoke(input_str)

    index_keywords = interface_search(temp_retriever, *retrievers)
    nlu_info.rewrite = chain_retrieval_rewrite.invoke({
        "query": input_str_rewrite,
        "question_type": nlu_info.question_type,
        "intention_type": nlu_info.intent_category,
        "keywords": index_keywords
    })
    nlu_info.retrieve_keywords = get_keywords_v3(nlu_info.rewrite)

    return {"response": "success",
            "nlu_info": nlu_info}
# Optional: entry point for local debugging (serves `app` with uvicorn on all
# interfaces, port 3333).
if __name__ == "__main__":
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=3333,
    )