优化意图识别API,移除同步意图识别器,改为使用异步意图识别器,更新相关逻辑以支持异步处理,增强错误处理和日志记录,同时更新请求和响应模型以适应新的API结构。

This commit is contained in:
2025-07-07 17:51:10 +08:00
parent b9bff7f512
commit 1f3e97d081
7 changed files with 146 additions and 910 deletions
+10 -9
View File
@@ -21,7 +21,7 @@ from typing import List, Dict, Any
from langchain.output_parsers import PydanticOutputParser from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.intent_recognition import IntentRecognizer, AsyncIntentRecognizer from rag2_0.intent_recognition import AsyncIntentRecognizer
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
from rag2_0.intent_recognition.DataModels import Classification from rag2_0.intent_recognition.DataModels import Classification
from rag2_0.tool.ModelTool import OpenAiLLM from rag2_0.tool.ModelTool import OpenAiLLM
@@ -78,8 +78,6 @@ class QueryRewriteProcessor:
self.api_key = api_key or os.getenv("OPENAI_API_KEY") self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = base_url or os.getenv("OPENAI_API_BASE") self.base_url = base_url or os.getenv("OPENAI_API_BASE")
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name)
# 使用asyncio.run()运行异步create方法 # 使用asyncio.run()运行异步create方法
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create( self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
api_key=self.api_key, api_key=self.api_key,
@@ -198,12 +196,12 @@ class QueryRewriteProcessor:
while retry_count <= max_retries: while retry_count <= max_retries:
try: try:
# 使用process_query方法处理查询 result = asyncio.run(self.recognizer_async.process_query_async(query,
result = self.recognizer.process_query(query,
conversation_context=conversation_context, conversation_context=conversation_context,
chat_history=chat_history, chat_history=chat_history,
previous_slots=previous_slots, previous_slots=previous_slots,
enable_query_expansion=True) enable_query_expansion=True))
# 提取分类信息 # 提取分类信息
classification = result["classification"] classification = result["classification"]
original_query = result["rewrite"]["rewrite"] original_query = result["rewrite"]["rewrite"]
@@ -238,9 +236,9 @@ class QueryRewriteProcessor:
if slot_filling and "filled_data" in slot_filling: if slot_filling and "filled_data" in slot_filling:
# 格式化槽位填充结果 # 格式化槽位填充结果
slot_filling_str = json.dumps({ slot_filling_str = json.dumps({
"是否完整": slot_filling.get("is_complete", False), "is_complete": slot_filling.get("is_complete", False),
"缺失槽位": slot_filling.get("missing_slots", {}), "missing_slots": slot_filling.get("missing_slots", {}),
"填充数据": slot_filling.get("filled_data", {}) "filled_data": slot_filling.get("filled_data", {})
}, ensure_ascii=False, indent=2) }, ensure_ascii=False, indent=2)
# 处理成功,返回结果 # 处理成功,返回结果
@@ -442,9 +440,12 @@ def main():
for idx, query in enumerate(examples): for idx, query in enumerate(examples):
if query.strip() == "": if query.strip() == "":
continue continue
query="储能C1软件如何新建工程?"
conversation_context="当前使用软件:配网计价通D3软件"
# 在调试模式下使用完整的参数 # 在调试模式下使用完整的参数
print(json.dumps(processor.process_query( print(json.dumps(processor.process_query(
query, query,
conversation_context=conversation_context,
enable_retrieval=True enable_retrieval=True
), ensure_ascii=False, indent=2)) ), ensure_ascii=False, indent=2))
-1
View File
@@ -186,7 +186,6 @@ class DifyComparisonTester:
要求 要求
1、分析待评估的回答与标准答案的匹配程度(包括内容、步骤、主体等) 1、分析待评估的回答与标准答案的匹配程度(包括内容、步骤、主体等)
2、如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确" 2、如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"
3、只要大体描述一致,即使缺失了一些步骤,也应判定为"正确"
3、如果待评估的回答存在明显的错误信息,应判定为"错误" 3、如果待评估的回答存在明显的错误信息,应判定为"错误"
4、请严格按json格式输出: 4、请严格按json格式输出:
{{ {{
+132 -57
View File
@@ -1,14 +1,14 @@
from gevent import monkey
monkey.patch_all()
from flask import Flask, request, Response
import os import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv from dotenv import load_dotenv
import json import json
import time import time
from gevent.lock import RLock
import datetime import datetime
import logging import logging
# 加载环境变量 # 加载环境变量
@@ -16,7 +16,7 @@ load_dotenv()
import sys import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.intent_recognition import IntentRecognizer from rag2_0.intent_recognition import AsyncIntentRecognizer
logging.basicConfig( logging.basicConfig(
@@ -31,32 +31,87 @@ logging.getLogger('openai').setLevel(logging.WARNING)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
app = Flask(__name__)
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
_instance = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
@app.route('/intent_recognize', methods=['POST'])
def intent_recognize(): # 定义请求模型
class IntentRecognizeRequest(BaseModel):
query: str
conversation_context: str = ""
chat_history: Optional[List] = None
previous_slots: str | Dict = None
# 定义槽位填充响应模型
class SlotFillingResponse(BaseModel):
is_complete: bool = False
missing_slots: Dict[str, Any] = Field(default_factory=dict)
filled_data: Dict[str, Any] = Field(default_factory=dict)
class QueryExpandResponse(BaseModel):
all: List[str] = Field(default_factory=list)
step_back: Dict[str, Any] = Field(default_factory=Dict)
follow_up: Dict[str, Any] = Field(default_factory=Dict)
hyde: Dict[str, Any] = Field(default_factory=Dict)
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
# 定义响应模型
class IntentRecognizeResponse(BaseModel):
source_query: str
source_query_keys: List[str]
vertical_classification: str
sub_classification: str
rewrite_query: str
keywords: List[Dict[str, str]] = Field(default_factory=list)
has_slot_filling: bool = False
slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
# 创建FastAPI应用
app = FastAPI(
title="意图识别服务",
description="基于LLM的意图识别和问题改写服务",
version="2.0"
)
# 添加CORS中间件
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 全局变量存储AsyncIntentRecognizer实例
_instance = None
# 应用启动事件
@app.on_event("startup")
async def startup_event():
global _instance
# 初始化AsyncIntentRecognizer实例
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
_instance = await AsyncIntentRecognizer.create(api_key=api_key, base_url=base_url, model_name=model_name)
logger.info("AsyncIntentRecognizer初始化完成")
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
async def intent_recognize(request: IntentRecognizeRequest):
try: try:
data = request.get_json(force=True) if not request.query:
query = data.get('query') raise HTTPException(status_code=400, detail="缺少query参数")
conversation_context = data.get('conversation_context', "")
chat_history = data.get('chat_history', None)
previous_slots = data.get('previous_slots', None)
if not query:
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
start_time = time.time() start_time = time.time()
result = _instance.process_query(query=query, result = await _instance.process_query_async(
conversation_context=conversation_context, query=request.query,
chat_history=chat_history, conversation_context=request.conversation_context,
previous_slots=previous_slots, chat_history=request.chat_history,
use_jieba=False, previous_slots=request.previous_slots,
enable_query_expansion=True) use_jieba=False,
enable_query_expansion=True
)
end_time = time.time() end_time = time.time()
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z") current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
@@ -67,41 +122,61 @@ def intent_recognize():
# 提取关键词信息 # 提取关键词信息
keywords = result["keywords"] keywords = result["keywords"]
keywords_str = "" keywords_list = []
if keywords and keywords.get("terms"): if keywords and keywords.get("terms"):
term_details = []
for term in keywords["terms"]: for term in keywords["terms"]:
term_info = { keywords_list.append({
"名称": term["name"], "名称": term["name"]
} })
term_details.append(term_info)
keywords_str = term_details
# 提取槽位填充信息 # 提取槽位填充信息
slot_filling = result.get("slot_filling", {}) slot_filling = result.get("slot_filling", {})
response_result = { # 构建响应
"source_query": query, response = IntentRecognizeResponse(
"source_query_keys": result["query_keys"], source_query=request.query,
"vertical_classification": classification["vertical_classification"], source_query_keys=result["query_keys"],
"sub_classification": classification["sub_classification"], vertical_classification=classification["vertical_classification"],
"rewrite_query": result["rewrite"]["rewrite"], sub_classification=classification["sub_classification"],
"keywords": keywords_str, rewrite_query=result["rewrite"]["rewrite"],
"has_slot_filling": len(slot_filling)!=0, keywords=keywords_list,
"slot_filling": { has_slot_filling=len(slot_filling) != 0,
"is_complete": slot_filling.get("is_complete", False), slot_filling=SlotFillingResponse(
"missing_slots": slot_filling.get("missing_slots", {}), is_complete=slot_filling.get("is_complete", False),
"filled_data": slot_filling.get("filled_data", {}) missing_slots=slot_filling.get("missing_slots", {}),
}, filled_data=slot_filling.get("filled_data", {})
"query_expand": result["query_expand"] ),
} query_expand=QueryExpandResponse(
return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8') all=result["query_expand"]["all"],
step_back=result["query_expand"]["step_back"],
follow_up=result["query_expand"]["follow_up"],
hyde=result["query_expand"]["hyde"],
multi_questions=result["query_expand"]["multi_questions"]
)
)
return response
except HTTPException as e:
raise e
except Exception as e: except Exception as e:
logger.error(f"意图识别出错: {str(e)}",exc_info=True) logger.error(f"意图识别出错: {str(e)}", exc_info=True)
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500) raise HTTPException(status_code=500, detail=str(e))
# 添加健康检查端点
@app.get("/health", summary="健康检查")
async def health_check():
return {"status": "ok"}
if __name__ == "__main__": if __name__ == "__main__":
# 开发环境使用Flask内置服务 # 使用uvicorn启动服务
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app import uvicorn
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app uvicorn.run(
app.run(host="0.0.0.0", port=8001, threaded=True) "rag2_0.dify.intent_recognition_api:app",
host="0.0.0.0",
port=8001,
reload=False, # 开发环境启用热重载
workers=1 # 生产环境可以增加worker数量
)
# 生产环境可以使用以下命令启动:
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10
+1 -4
View File
@@ -225,8 +225,6 @@ class DataProblemSlots(SlotBase):
class FileExtensionConsultingSlots(SlotBase): class FileExtensionConsultingSlots(SlotBase):
file_extension: str = Field(default="", description="文件后缀名") file_extension: str = Field(default="", description="文件后缀名")
operation_purpose: str = Field(default="", description="操作目的(了解对应软件,对应工程)") operation_purpose: str = Field(default="", description="操作目的(了解对应软件,对应工程)")
file_source: Optional[str] = Field(default="", description="文件来源场景")
related_software: Optional[str] = Field(default="", description="相关软件名称")
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]: def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
"""检查必填槽位是否都存在""" """检查必填槽位是否都存在"""
@@ -239,7 +237,7 @@ class FileExtensionConsultingSlots(SlotBase):
# 3.2 软件锁类 # 3.2 软件锁类
class SoftwareLockSlots(SlotBase): class SoftwareLockSlots(SlotBase):
lock_type: str = Field(default="", description="锁类型(单机锁、网络锁)") lock_type: str = Field(default="单机锁", description="锁类型(单机锁、网络锁)")
operation_purpose: str = Field(default="", description="操作目的(查询锁信息、无法激活、无法注册)") operation_purpose: str = Field(default="", description="操作目的(查询锁信息、无法激活、无法注册)")
lock_number: Optional[str] = Field(default="", description="软件锁编号/注册号") lock_number: Optional[str] = Field(default="", description="软件锁编号/注册号")
@@ -259,7 +257,6 @@ class InstallationDownloadSlots(SlotBase):
file_name: str = Field(default="", description="文件名,与software_name二选一") file_name: str = Field(default="", description="文件名,与software_name二选一")
operation_stage: str = Field(default="", description="操作阶段(下载、安装等)") operation_stage: str = Field(default="", description="操作阶段(下载、安装等)")
os_version: Optional[str] = Field(default="", description="操作系统版本") os_version: Optional[str] = Field(default="", description="操作系统版本")
package_source: Optional[str] = Field(default="", description="安装包来源/版本号")
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]: def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
"""检查必填槽位是否都存在""" """检查必填槽位是否都存在"""
+1 -834
View File
@@ -38,839 +38,6 @@ from .DataModels import (
from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
class IntentRecognizer:
"""
意图识别和问题改写类
"""
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
"""
初始化意图识别器
Args:
api_key: OpenAI API密钥,如果为None则从环境变量获取
base_url: OpenAI API基础URL,如果为None则使用默认URL
model_name: 要使用的模型名称
vector_index_dir: 向量索引目录,如果为None则使用默认目录
"""
# 初始化LLM
llm_params = {
"temperature": 0.2, # 降低随机性,使结果更确定
"top_p": 0.7,
"model": model_name
}
# 如果提供了API密钥,则使用提供的密钥
if api_key:
llm_params["api_key"] = api_key
# 如果提供了自定义URL,则使用提供的URL
if base_url:
llm_params["base_url"] = base_url
self._llm = OpenAiLLM(**llm_params)
# 加载suffix关键词
self._suffix_keywords = self._load_suffix_keywords()
# 初始化向量检索器
self._noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
"""
加载后缀关键词列表
Args:
filepath: 后缀关键词文件路径,默认为None使用默认路径
Returns:
后缀关键词列表
"""
try:
# 如果未指定路径,使用默认路径
if filepath is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
# 读取JSON文件
with open(filepath, "r", encoding="utf-8") as f:
suffix_data = json.load(f)
# 添加额外的固定后缀
return suffix_data
except Exception as e:
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
def _classify_intent(self, query: str, conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None) -> Classification:
"""
对用户输入进行意图分类
Args:
content: 用户输入内容
keywords: 匹配到的关键词列表
rewrite: 重写的问题
Returns:
分类结果
"""
classification_start_time = time.time()
classification_parser = PydanticOutputParser(pydantic_object=Classification)
formatted_prompt = classification_prompt.format(user_input=query,
classification_info=classification_info,
output_format=classification_parser.get_format_instructions(),
conversation_context=conversation_context,
chat_history=json.dumps(chat_history, ensure_ascii=False))
# 解析输出
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
classification_end_time = time.time()
classification_time = classification_end_time - classification_start_time
logging.info(f"意图分类耗时统计 - 总耗时: {classification_time:.2f}")
# 尝试直接解析JSON响应
parsed_output = classification_parser.parse(response.content.strip())
return parsed_output
except Exception as e:
raise RuntimeError(f"解析分类结果时出错: {e}") from e
def _tokenize_with_jieba(self, query: str) -> List[str]:
"""
使用jieba分词器对查询进行分词
Args:
query: 用户查询
Returns:
分词后的词语列表
"""
# 使用jieba进行分词
seg_list = jieba.cut(query, cut_all=False)
# 过滤掉停用词和标点符号
filtered_tokens = []
for token in seg_list:
# 过滤掉空格和标点符号
if token.strip() and not re.match(r'^[^\w\s]$', token):
filtered_tokens.append(token)
return filtered_tokens
def _extract_keywords_with_llm(self, query: str, use_jieba: bool = False) -> List[Term]:
"""
使用LLM从用户查询中提取专业关键词
Args:
query: 用户查询
use_jieba: 是否使用jieba分词辅助提取关键词
Returns:
提取的术语列表
"""
# 如果使用jieba分词
if use_jieba:
# 先使用jieba分词
tokens = self._tokenize_with_jieba(query)
# 构建术语列表
terms = []
for token in tokens:
if len(token) > 1: # 过滤掉单字词
terms.append(Term(name=token, synonymous=[], description=""))
return terms
else:
# 使用LLM提取关键词
# 准备提示词
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
formatted_prompt = formatted_prompt.replace("{output_format}", terms_list_parser.get_format_instructions())
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 尝试使用Pydantic解析器解析TermList
parsed_output = terms_list_parser.parse(response.content)
return parsed_output.terms
def _rerank_matched_terms(self, query_key: str, matched_terms: set, top_k: int = 2, rerank_score:float = 0.6) -> List[Term]:
"""
对召回的专业术语进行重排序,按与用户查询的相关性排序
Args:
query: 用户查询
matched_terms: 匹配到的专业术语集合
query_keys: 用户查询中提取的关键词列表
Returns:
重排序后的专业术语列表
"""
if not matched_terms:
return []
if len(matched_terms) <= top_k:
return list(matched_terms)
try:
# 将每个术语转换为可用于重排序的文本表示
# term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description for term in matched_terms]
term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) for term in matched_terms]
# 使用重排序模型
xinference_reranker = XinferenceReRankerModel()
rerank_results = xinference_reranker.rerank(query_key, term_texts, top_k=top_k)
# 将matched_terms转换为列表以便按索引访问
matched_terms_list = list(matched_terms)
# 根据重排序结果获取排序后的术语列表
reranked_terms = [matched_terms_list[result["index"]] for result in rerank_results if result["score"] >= rerank_score]
return reranked_terms
except Exception as e:
raise RuntimeError(f"_rerank_matched_terms重排失败:{e}") from e
def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]:
"""
从用户问题中匹配关键词,结合LLM提取和向量检索
Args:
query: 用户问题
use_jieba: 是否使用jieba分词辅助提取关键词
Returns:
匹配到的关键词列表
"""
start_time = time.time()
query_keys=[]
# 步骤1: 使用LLM提取查询中的关键词
try:
llm_start_time = time.time()
extracted_terms = self._extract_keywords_with_llm(query, use_jieba)
for term in extracted_terms:
query_keys.append(term.name)
llm_end_time = time.time()
llm_time = llm_end_time - llm_start_time
except Exception as e:
raise RuntimeError(f"LLM关键词提取失败: {e}") from e
matched_terms = [] # 存储匹配到的Term对象
# 步骤2: 使用向量检索找到相似的专业名词
try:
vector_start_time = time.time()
# 对matched_terms中的每个关键字进行向量检索
for current_key in query_keys:
vector_results = self._noun_retriever.query(current_key, top_k=5, use_intersection=False)
current_key_terms = set()
# 添加向量检索结果
for result in vector_results:
if isinstance(result.get('synonymous', []), str):
result['synonymous'] = result['synonymous'].split(';')
term = Term(
name=result.get('name'),
synonymous=result.get('synonymous', []),
description=result.get('description', '')
)
current_key_terms.add(term)
if len(current_key_terms) > 0:
reranked_terms = self._rerank_matched_terms(current_key, current_key_terms)
matched_terms.extend(reranked_terms)
vector_end_time = time.time()
vector_time = vector_end_time - vector_start_time
except Exception as e:
raise RuntimeError(f"向量检索关键词时出错: {e}") from e
# 提取所有Term对象的名称并排序
# 将set类型的matched_terms转换为TermList类型
term_list = TermList(terms=list(matched_terms))
end_time = time.time()
total_time = end_time - start_time
# 输出整合的时间日志
logging.info(f"关键词匹配耗时统计 - 总耗时: {total_time:.2f}秒, 问题关键词提取: {llm_time:.2f}秒, 向量检索+重排序: {vector_time:.2f}")
return term_list, query_keys
def _rewrite_query(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite:
"""
对用户问题进行改写
Args:
query: 用户原始问题
keywords: 匹配到的关键词列表
query_keys: 用户查询中提取的关键词列表
Returns:
改写结果
"""
rewrite_start_time = time.time()
# 准备问题改写提示
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
# terms_dict = [term.model_dump() for term in keywords.terms]
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
# formatted_prompt = query_rewrite_prompt.format(query=query,
# output_format=query_rewrite_parser.get_format_instructions(),
# keywords=keywords_str)
formatted_prompt = query_rewrite_prompt_pro.format(query=query,
output_format=query_rewrite_parser.get_format_instructions(),
keywords=keywords_str,
chat_history=chat_history,
context=context)
# 解析输出
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 尝试直接解析JSON响应
parsed_output = query_rewrite_parser.parse(response.content)
rewrite_end_time = time.time()
rewrite_time = rewrite_end_time - rewrite_start_time
logging.info(f"问题改写耗时统计 - 总耗时: {rewrite_time:.2f}")
return parsed_output
except Exception as e:
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
def _judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
"""
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
Args:
input_str: 输入字符串
Returns:
Tuple[bool, List[str]]: (是否包含定义的后缀, 匹配到的后缀名列表)
"""
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self._suffix_keywords) + r')'
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
matches = re.finditer(pattern, input_str, re.IGNORECASE)
matched_suffixes = [match.group(1) for match in matches]
return bool(matched_suffixes), matched_suffixes
def process_query(self, query: str, conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None,
use_jieba: bool = False,
enable_query_expansion: bool = False) -> Dict[str, Any]:
"""
处理用户问题的完整流程
Args:
query: 用户原始问题
conversation_context: 会话背景信息
chat_history: 历史对话记录,格式为[{"user": "content"}, {"assistant": "content"}]
previous_slots: 历史槽位信息
use_jieba: 是否使用jieba分词辅助提取关键词
Returns:
包含分类、关键词、改写和槽位填充结果的字典
"""
# 是否是扩展名
# is_suffix, matched_suffixes = self._judge_define_suffix(query)
# if is_suffix:
# # 将所有匹配到的后缀名作为Term添加到结果中
# suffix_terms = []
# for suffix in matched_suffixes:
# term_dict = next((item for item in self._suffix_keywords if item['name'].lower() == suffix.lower()), None)
# if term_dict:
# suffix_term = Term(
# name=term_dict.get('name'),
# synonymous=term_dict.get('synonymous', []),
# description=json.dumps(term_dict.get('description', ''), ensure_ascii=False)
# )
# suffix_terms.append(suffix_term)
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
if chat_history is None:
chat_history = []
if previous_slots is None:
previous_slots = {}
# 步骤: 并行执行提问扩展
if enable_query_expansion:
# 创建线程和结果容器
threads_and_results = [
# 5.1: 后退提示
self._run_in_thread(self._generate_step_back_prompt, args=(query, chat_history, conversation_context)),
# 5.2: Follow Up Questions
self._run_in_thread(self._generate_follow_up_questions, args=(query, chat_history, conversation_context)),
# 5.3: HyDE
self._run_in_thread(self._generate_hypothetical_document, args=(query, chat_history, conversation_context)),
# 5.4: 多问题查询
self._run_in_thread(self._generate_multi_questions, args=(query, chat_history, conversation_context))
]
# 步骤1: 匹配关键词
keywords_terms, query_keys = self._match_keywords(query, use_jieba)
# 步骤2: 问题改写
rewrite = self._rewrite_query(
query=query,
keywords=keywords_terms,
query_keys=query_keys,
chat_history=chat_history,
context=conversation_context
)
# 步骤3: 进行意图识别和槽位填充
# result = self._process_intent_and_slot(rewrite.rewrite, conversation_context, chat_history, previous_slots)
# result.update({"keywords": keywords_terms.model_dump(),
# "rewrite": rewrite.model_dump(),
# "query_keys": query_keys})
# return result
# 步骤3: 进行意图分类
classification = self._classify_intent(rewrite.rewrite, conversation_context, chat_history, previous_slots)
# 步骤4: 进行槽位填充
# 如果是有效分类,进行槽位填充
slot_filling_result = {}
if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
slot_filling_result = self._fill_slots(rewrite.rewrite, classification, conversation_context, chat_history, previous_slots)
if not enable_query_expansion:
return {
"classification": classification.model_dump(),
"keywords": keywords_terms.model_dump(),
"rewrite": rewrite.model_dump(),
"query_keys": query_keys,
"slot_filling": slot_filling_result
}
# 等待所有线程完成
start_time = time.time()
for thread, _ in threads_and_results:
thread.join()
end_time = time.time()
logging.info(f"问题扩展环节耗时统计 - 总耗时: {end_time - start_time:.2f}")
# 收集结果
step_back_result = threads_and_results[0][1][0] if threads_and_results[0][1] else StepBackPrompt(original_query=query, step_back_query=query)
follow_up_result = threads_and_results[1][1][0] if threads_and_results[1][1] else FollowUpQuestions(original_query=query, follow_up_query=query)
hyde_result = threads_and_results[2][1][0] if threads_and_results[2][1] else HypotheticalDocument(original_query=query, hypothetical_answer="")
multi_questions_result = threads_and_results[3][1][0] if threads_and_results[3][1] else MultiQuestions(original_query=query, sub_questions=[query])
all_questions=multi_questions_result.sub_questions
all_questions.append(query)
all_questions.append(step_back_result.step_back_query)
all_questions.append(follow_up_result.follow_up_query)
all_questions.append(hyde_result.hypothetical_answer)
all_questions = list(set(all_questions))
query_expand={"all":all_questions,
"step_back":step_back_result.model_dump(),
"follow_up":follow_up_result.model_dump(),
"hyde":hyde_result.model_dump(),
"multi_questions":multi_questions_result.model_dump()}
# 返回所有结果
return {
"classification": classification.model_dump(),
"keywords": keywords_terms.model_dump(),
"rewrite": rewrite.model_dump(),
"query_keys": query_keys,
"slot_filling": slot_filling_result,
"query_expand": query_expand
}
def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None,) -> Dict[str, Any]:
"""
根据分类结果对问题进行槽位填充
Args:
query: 用户原始问题
classification: 意图分类结果
keywords: 匹配的关键词列表
Returns:
填充后的槽位数据模型
"""
# 根据分类结果选择对应的数据模型
slot_model = self._get_slot_model(classification)
if not slot_model:
raise RuntimeError("未找到匹配的槽位模型")
fill_slots_start_time = time.time()
# 使用LLM进行槽位填充
filled_slots = self._fill_slots_with_llm(query, classification, slot_model, conversation_context, chat_history, previous_slots)
fill_slots_end_time = time.time()
fill_slots_time = fill_slots_end_time - fill_slots_start_time
logging.info(f"槽位填充耗时统计 - 总耗时: {fill_slots_time:.2f}")
# 检查必填槽位是否都已填充
is_complete, missing_slots = filled_slots.check_required_slots()
return {
"is_complete": is_complete,
"missing_slots": missing_slots,
"filled_data": filled_slots.model_dump()
}
def _get_slot_model(self, classification: Classification) -> Optional[type]:
"""
根据分类结果获取对应的槽位模型类,用于统一提示词处理
Args:
classification: 意图分类结果
Returns:
对应的槽位模型类
"""
# 软件问题
if classification.vertical_classification == "软件问题":
if classification.sub_classification == "软件功能":
return SoftwareFunctionSlots
elif classification.sub_classification == "故障排查":
return SoftwareTroubleShootingSlots
# 业务问题
elif classification.vertical_classification == "业务问题":
if classification.sub_classification == "专业咨询":
return ProfessionalConsultingSlots
elif classification.sub_classification == "数据问题":
return DataProblemSlots
# 安装下载注册
elif classification.vertical_classification == "安装下载注册":
if classification.sub_classification == "后缀名咨询":
return FileExtensionConsultingSlots
elif classification.sub_classification == "软件锁类":
return SoftwareLockSlots
elif classification.sub_classification == "安装下载类":
return InstallationDownloadSlots
elif classification.sub_classification == "问题排查类":
return ProblemDiagnosisSlots
# 其他
elif classification.vertical_classification == "其他":
return OtherSlots
return None
def _fill_slots_with_llm(self, query: str,
classification: Classification,
slot_model_class: type,
conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None) -> Any:
"""
使用LLM进行槽位填充
Args:
query: 用户原始问题
classification: 意图分类结果
slot_model_class: 槽位模型类
Returns:
填充后的槽位数据模型实例
"""
# 准备提示词
slot_parser = PydanticOutputParser(pydantic_object=slot_model_class)
formatted_prompt = slot_filling_prompt.format(
query=query,
vertical_classification=classification.vertical_classification,
sub_classification=classification.sub_classification,
output_format=slot_parser.get_format_instructions(),
conversation_context=conversation_context,
chat_history=json.dumps(chat_history,ensure_ascii=False),
previous_slots=json.dumps(previous_slots,ensure_ascii=False),
)
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 尝试解析LLM响应
parsed_output = slot_parser.parse(response.content)
return parsed_output
except Exception as e:
# 如果解析失败,创建一个空的模型实例
empty_instance = slot_model_class()
return empty_instance
def _generate_step_back_prompt(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> StepBackPrompt:
"""
生成后退提示
Args:
query: 用户原始问题
chat_history: 历史对话记录
conversation_context: 会话背景信息
Returns:
后退提示结果
"""
step_back_start_time = time.time()
# 准备提示词
step_back_parser = PydanticOutputParser(pydantic_object=StepBackPrompt)
formatted_prompt = step_back_prompt.format(
query=query,
chat_history=json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]",
conversation_context=conversation_context,
output_format=step_back_parser.get_format_instructions()
)
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 解析输出
parsed_output = step_back_parser.parse(response.content)
step_back_end_time = time.time()
step_back_time = step_back_end_time - step_back_start_time
logging.debug(f"后退提示生成耗时统计 - 总耗时: {step_back_time:.2f}")
return parsed_output
except Exception as e:
# 如果解析失败,返回原始查询作为后退提示
logging.error(f"后退提示生成失败: {e}", exc_info=True)
return StepBackPrompt(original_query=query, step_back_query=query)
def _generate_follow_up_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> FollowUpQuestions:
"""
生成后续问题
Args:
query: 用户原始问题
chat_history: 历史对话记录
conversation_context: 会话背景信息
Returns:
后续问题结果
"""
follow_up_start_time = time.time()
# 准备提示词
follow_up_parser = PydanticOutputParser(pydantic_object=FollowUpQuestions)
formatted_prompt = follow_up_questions_prompt.format(
query=query,
chat_history=json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]",
conversation_context=conversation_context,
output_format=follow_up_parser.get_format_instructions()
)
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 解析输出
parsed_output = follow_up_parser.parse(response.content)
follow_up_end_time = time.time()
follow_up_time = follow_up_end_time - follow_up_start_time
logging.debug(f"后续问题生成耗时统计 - 总耗时: {follow_up_time:.2f}")
return parsed_output
except Exception as e:
# 如果解析失败,返回原始查询作为后续问题
logging.error(f"后续问题生成失败: {e}", exc_info=True)
return FollowUpQuestions(original_query=query, follow_up_query=query)
def _generate_hypothetical_document(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> HypotheticalDocument:
"""
生成假设性文档
Args:
query: 用户原始问题
chat_history: 历史对话记录
conversation_context: 会话背景信息
Returns:
假设性文档结果
"""
hyde_start_time = time.time()
# 准备提示词
hyde_parser = PydanticOutputParser(pydantic_object=HypotheticalDocument)
formatted_prompt = hyde_prompt.format(
query=query,
chat_history=json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]",
conversation_context=conversation_context,
output_format=hyde_parser.get_format_instructions()
)
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 解析输出
parsed_output = hyde_parser.parse(response.content)
hyde_end_time = time.time()
hyde_time = hyde_end_time - hyde_start_time
logging.debug(f"假设性文档生成耗时统计 - 总耗时: {hyde_time:.2f}")
return parsed_output
except Exception as e:
# 如果解析失败,返回空的假设性回答
logging.error(f"假设性文档生成失败: {e}", exc_info=True)
return HypotheticalDocument(original_query=query, hypothetical_answer="")
def _generate_multi_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> MultiQuestions:
"""
生成多角度问题
Args:
query: 用户原始问题
chat_history: 历史对话记录
conversation_context: 会话背景信息
Returns:
多角度问题结果
"""
multi_questions_start_time = time.time()
# 准备提示词
multi_questions_parser = PydanticOutputParser(pydantic_object=MultiQuestions)
formatted_prompt = multi_questions_prompt.format(
query=query,
chat_history=json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]",
conversation_context=conversation_context,
output_format=multi_questions_parser.get_format_instructions()
)
try:
# 调用LLM
response = self._llm.invoke(formatted_prompt, False)
# 解析输出
parsed_output = multi_questions_parser.parse(response.content)
multi_questions_end_time = time.time()
multi_questions_time = multi_questions_end_time - multi_questions_start_time
logging.debug(f"多角度问题生成耗时统计 - 总耗时: {multi_questions_time:.2f}")
return parsed_output
except Exception as e:
# 如果解析失败,返回原始查询作为唯一子问题
logging.error(f"多角度问题生成失败: {e}",exc_info=True)
return MultiQuestions(original_query=query, sub_questions=[query])
def _run_in_thread(self, func, args=(), kwargs={}):
"""
在线程中执行函数并返回结果
Args:
func: 要执行的函数
args: 函数的位置参数
kwargs: 函数的关键字参数
Returns:
(thread, result_container): 线程对象和存放结果的容器
"""
result_container = []
def thread_target():
try:
result = func(*args, **kwargs)
result_container.append(result)
except Exception as e:
logging.error(f"线程执行函数 {func.__name__} 时出错: {e}", exc_info=True)
result_container.append(None)
thread = threading.Thread(target=thread_target)
thread.start()
return thread, result_container
def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, Any] = None) -> Dict[str, Any]:
"""
使用统一提示词同时进行意图识别和槽位填充
Args:
user_input: 当前用户输入
conversation_context: 会话背景信息
chat_history: 历史对话记录,格式为[{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}]
previous_slots: 历史槽位信息
Returns:
包含意图分类和槽位填充结果的字典
"""
# 初始化默认值
if chat_history is None:
chat_history = []
if previous_slots is None:
previous_slots = {}
# 生成槽位映射文档
slot_mapping_doc = generate_slot_mapping_doc()
# 准备提示词
parser = PydanticOutputParser(pydantic_object=IntentAndSlotResult)
formatted_prompt = intent_and_slot_prompt.format(
conversation_context=conversation_context,
chat_history=json.dumps(chat_history, ensure_ascii=False),
previous_slots=json.dumps(previous_slots, ensure_ascii=False),
user_input=user_input,
slot_mapping_doc=slot_mapping_doc,
output_format=parser.get_format_instructions(),
classification_info=classification_info
)
# 调用LLM
llm_start_time = time.time()
response = self._llm.invoke(formatted_prompt + output_example, False)
llm_end_time = time.time()
llm_time = llm_end_time - llm_start_time
try:
# 解析LLM响应为JSON
result_json = parser.parse(response.content)
classification=result_json.classification
slot_filling=result_json.slots
is_complete, missing_slots = slot_filling.check_required_slots()
expected_slot_model = self._get_slot_model(classification)
# 添加容错处理,发生概率较低,但仍需处理
if expected_slot_model is None:
# 添加容错处理,应对LLM返回错误分类信息,一级分类跟二级分类错乱
# 重新分类
classification = self._classify_intent(user_input, conversation_context, chat_history, previous_slots)
fill_slots = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
result = {
"classification": classification.model_dump(),
"slot_filling": fill_slots
}
logging.warning(f"重新分类与槽点填充")
return result
elif expected_slot_model.__name__ != type(slot_filling).__name__:
# 添加容错处理,应对LLM槽位与分类不匹配。重新填充槽位
slot_filling = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
result = {
"classification": classification.model_dump(),
"slot_filling": slot_filling
}
logging.warning(f"重新填充槽点")
return result
logging.info(f"意图识别+槽位LLM调用耗时: {llm_time:.2f}")
# 构建最终结果
result = {
"classification": classification.model_dump(),
"slot_filling": {
"is_complete": is_complete,
"missing_slots": missing_slots,
"filled_data": slot_filling.model_dump()
}
}
return result
except Exception as e:
raise RuntimeError(f"process_intent_and_slot error:{e}") from e
class AsyncIntentRecognizer: class AsyncIntentRecognizer:
""" """
异步意图识别和问题改写类 异步意图识别和问题改写类
@@ -976,7 +143,7 @@ class AsyncIntentRecognizer:
formatted_prompt = classification_prompt.format(user_input=query, formatted_prompt = classification_prompt.format(user_input=query,
classification_info=classification_info, classification_info=classification_info,
output_format=classification_parser.get_format_instructions(), output_format=classification_parser.get_format_instructions(),
conversation_context=conversation_context, # conversation_context=conversation_context,
chat_history=json.dumps(chat_history, ensure_ascii=False)) chat_history=json.dumps(chat_history, ensure_ascii=False))
# 解析输出 # 解析输出
try: try:
+1 -4
View File
@@ -56,12 +56,9 @@ classification_info="""【垂直领域分类】:
1. 其他""" 1. 其他"""
classification_prompt=""" classification_prompt="""
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一: 用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一:
{classification_info} {classification_info}
## 【会话背景信息】
{conversation_context}
## 【历史对话记录】 ## 【历史对话记录】
{chat_history} {chat_history}
+1 -1
View File
@@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
from .IntentRecognition import IntentRecognizer, AsyncIntentRecognizer from .IntentRecognition import AsyncIntentRecognizer
from .DataModels import Term, TermList, Classification, QueryRewrite from .DataModels import Term, TermList, Classification, QueryRewrite