优化意图识别API,移除同步意图识别器,改为使用异步意图识别器,更新相关逻辑以支持异步处理,增强错误处理和日志记录,同时更新请求和响应模型以适应新的API结构。
This commit is contained in:
@@ -21,7 +21,7 @@ from typing import List, Dict, Any
|
|||||||
from langchain.output_parsers import PydanticOutputParser
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.intent_recognition import IntentRecognizer, AsyncIntentRecognizer
|
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
||||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||||
from rag2_0.intent_recognition.DataModels import Classification
|
from rag2_0.intent_recognition.DataModels import Classification
|
||||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
@@ -78,8 +78,6 @@ class QueryRewriteProcessor:
|
|||||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||||
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||||
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||||
|
|
||||||
self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name)
|
|
||||||
# 使用asyncio.run()运行异步create方法
|
# 使用asyncio.run()运行异步create方法
|
||||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
|
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
|
||||||
api_key=self.api_key,
|
api_key=self.api_key,
|
||||||
@@ -198,12 +196,12 @@ class QueryRewriteProcessor:
|
|||||||
|
|
||||||
while retry_count <= max_retries:
|
while retry_count <= max_retries:
|
||||||
try:
|
try:
|
||||||
# 使用process_query方法处理查询
|
result = asyncio.run(self.recognizer_async.process_query_async(query,
|
||||||
result = self.recognizer.process_query(query,
|
|
||||||
conversation_context=conversation_context,
|
conversation_context=conversation_context,
|
||||||
chat_history=chat_history,
|
chat_history=chat_history,
|
||||||
previous_slots=previous_slots,
|
previous_slots=previous_slots,
|
||||||
enable_query_expansion=True)
|
enable_query_expansion=True))
|
||||||
|
|
||||||
# 提取分类信息
|
# 提取分类信息
|
||||||
classification = result["classification"]
|
classification = result["classification"]
|
||||||
original_query = result["rewrite"]["rewrite"]
|
original_query = result["rewrite"]["rewrite"]
|
||||||
@@ -238,9 +236,9 @@ class QueryRewriteProcessor:
|
|||||||
if slot_filling and "filled_data" in slot_filling:
|
if slot_filling and "filled_data" in slot_filling:
|
||||||
# 格式化槽位填充结果
|
# 格式化槽位填充结果
|
||||||
slot_filling_str = json.dumps({
|
slot_filling_str = json.dumps({
|
||||||
"是否完整": slot_filling.get("is_complete", False),
|
"is_complete": slot_filling.get("is_complete", False),
|
||||||
"缺失槽位": slot_filling.get("missing_slots", {}),
|
"missing_slots": slot_filling.get("missing_slots", {}),
|
||||||
"填充数据": slot_filling.get("filled_data", {})
|
"filled_data": slot_filling.get("filled_data", {})
|
||||||
}, ensure_ascii=False, indent=2)
|
}, ensure_ascii=False, indent=2)
|
||||||
|
|
||||||
# 处理成功,返回结果
|
# 处理成功,返回结果
|
||||||
@@ -442,9 +440,12 @@ def main():
|
|||||||
for idx, query in enumerate(examples):
|
for idx, query in enumerate(examples):
|
||||||
if query.strip() == "":
|
if query.strip() == "":
|
||||||
continue
|
continue
|
||||||
|
query="储能C1软件如何新建工程?"
|
||||||
|
conversation_context="当前使用软件:配网计价通D3软件"
|
||||||
# 在调试模式下使用完整的参数
|
# 在调试模式下使用完整的参数
|
||||||
print(json.dumps(processor.process_query(
|
print(json.dumps(processor.process_query(
|
||||||
query,
|
query,
|
||||||
|
conversation_context=conversation_context,
|
||||||
enable_retrieval=True
|
enable_retrieval=True
|
||||||
), ensure_ascii=False, indent=2))
|
), ensure_ascii=False, indent=2))
|
||||||
|
|
||||||
|
|||||||
@@ -186,7 +186,6 @@ class DifyComparisonTester:
|
|||||||
要求
|
要求
|
||||||
1、分析待评估的回答与标准答案的匹配程度(包括内容、步骤、主体等)
|
1、分析待评估的回答与标准答案的匹配程度(包括内容、步骤、主体等)
|
||||||
2、如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
|
2、如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
|
||||||
3、只要大体描述一致,即使缺失了一些步骤,也应判定为"正确"。
|
|
||||||
3、如果待评估的回答存在明显的错误信息,应判定为"错误"。
|
3、如果待评估的回答存在明显的错误信息,应判定为"错误"。
|
||||||
4、请严格按json格式输出:
|
4、请严格按json格式输出:
|
||||||
{{
|
{{
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from gevent import monkey
|
|
||||||
monkey.patch_all()
|
|
||||||
|
|
||||||
from flask import Flask, request, Response
|
|
||||||
import os
|
import os
|
||||||
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from typing import Dict, List, Any, Optional
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
from gevent.lock import RLock
|
|
||||||
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import logging
|
import logging
|
||||||
# 加载环境变量
|
# 加载环境变量
|
||||||
@@ -16,7 +16,7 @@ load_dotenv()
|
|||||||
|
|
||||||
import sys
|
import sys
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.intent_recognition import IntentRecognizer
|
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
||||||
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
@@ -31,32 +31,87 @@ logging.getLogger('openai').setLevel(logging.WARNING)
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
app = Flask(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# 定义请求模型
|
||||||
|
class IntentRecognizeRequest(BaseModel):
    # Request body of the POST /intent_recognize endpoint.
    # Documented with comments rather than a docstring so the description
    # emitted into the generated OpenAPI schema stays unchanged.

    # The user's raw question (required).
    query: str
    # Free-text background of the current conversation; empty when absent.
    conversation_context: str = ""
    # Earlier dialogue turns; None when the client sends no history.
    chat_history: Optional[List] = None
    # Slot data carried over from previous turns, either as a string or as a
    # dict.  None is part of the declared type so that an explicit JSON
    # ``null`` validates, matching the None default (previously the
    # annotation was ``str | Dict`` only, so ``null`` was rejected).
    previous_slots: str | Dict | None = None
|
||||||
|
|
||||||
|
# 定义槽位填充响应模型
|
||||||
|
class SlotFillingResponse(BaseModel):
    # Slot-filling part of the intent recognition response.
    # Documented with comments rather than a docstring so the description
    # emitted into the generated OpenAPI schema stays unchanged.

    # Whether every required slot has been filled.
    is_complete: bool = Field(default=False)
    # Required slots that are still missing.
    missing_slots: Dict[str, Any] = Field(default_factory=dict)
    # Slot values that have been filled so far.
    filled_data: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
class QueryExpandResponse(BaseModel):
    # Query-expansion part of the intent recognition response.
    # Documented with comments rather than a docstring so the description
    # emitted into the generated OpenAPI schema stays unchanged.
    #
    # The dict-valued fields use the builtin ``dict`` as default factory.
    # ``typing.Dict`` is only a typing alias: calling it raises
    # "TypeError: Type Dict cannot be instantiated", which made every
    # construction that relied on a default value fail.

    # Every expanded query collected into one list.
    all: List[str] = Field(default_factory=list)
    # Result of the step-back prompt expansion.
    step_back: Dict[str, Any] = Field(default_factory=dict)
    # Result of the follow-up question expansion.
    follow_up: Dict[str, Any] = Field(default_factory=dict)
    # Result of the HyDE (hypothetical document) expansion.
    hyde: Dict[str, Any] = Field(default_factory=dict)
    # Result of the multi-question expansion.
    multi_questions: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
# 定义响应模型
|
||||||
|
class IntentRecognizeResponse(BaseModel):
    # Response body of the POST /intent_recognize endpoint.
    # Documented with comments rather than a docstring so the description
    # emitted into the generated OpenAPI schema stays unchanged.

    # The query exactly as it was received.
    source_query: str
    # Keywords extracted from the received query.
    source_query_keys: List[str]
    # Top-level intent category.
    vertical_classification: str
    # Second-level intent category.
    sub_classification: str
    # The rewritten query.
    rewrite_query: str
    # Matched terms, one dict per term.
    keywords: List[Dict[str, str]] = Field(default_factory=list)
    # True when a slot-filling result is present.
    has_slot_filling: bool = Field(default=False)
    # Slot-filling details; an empty model by default.
    slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
    # Query-expansion details; an empty model by default.
    query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
|
||||||
|
|
||||||
|
# 创建FastAPI应用
|
||||||
|
app = FastAPI(
|
||||||
|
title="意图识别服务",
|
||||||
|
description="基于LLM的意图识别和问题改写服务",
|
||||||
|
version="2.0"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 添加CORS中间件
|
||||||
|
# CORS policy: every origin, method and header is allowed.
# Credentials are disabled because a wildcard origin cannot be combined with
# credentialed requests: browsers refuse "Access-Control-Allow-Origin: *" on
# such responses, and FastAPI's CORS documentation requires explicit origins
# when allow_credentials is True.  To support cookies / Authorization headers
# from browsers, list the trusted origins explicitly and re-enable the flag.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
||||||
|
|
||||||
|
# 全局变量存储AsyncIntentRecognizer实例
|
||||||
|
_instance = None
|
||||||
|
|
||||||
|
# 应用启动事件
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def startup_event():
|
||||||
|
global _instance
|
||||||
|
# 初始化AsyncIntentRecognizer实例
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
base_url = os.getenv("OPENAI_API_BASE")
|
base_url = os.getenv("OPENAI_API_BASE")
|
||||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||||
_instance = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
_instance = await AsyncIntentRecognizer.create(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||||
|
logger.info("AsyncIntentRecognizer初始化完成")
|
||||||
|
|
||||||
@app.route('/intent_recognize', methods=['POST'])
|
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
|
||||||
def intent_recognize():
|
async def intent_recognize(request: IntentRecognizeRequest):
|
||||||
try:
|
try:
|
||||||
data = request.get_json(force=True)
|
if not request.query:
|
||||||
query = data.get('query')
|
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||||
conversation_context = data.get('conversation_context', "")
|
|
||||||
chat_history = data.get('chat_history', None)
|
|
||||||
previous_slots = data.get('previous_slots', None)
|
|
||||||
if not query:
|
|
||||||
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
|
|
||||||
|
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
result = _instance.process_query(query=query,
|
result = await _instance.process_query_async(
|
||||||
conversation_context=conversation_context,
|
query=request.query,
|
||||||
chat_history=chat_history,
|
conversation_context=request.conversation_context,
|
||||||
previous_slots=previous_slots,
|
chat_history=request.chat_history,
|
||||||
|
previous_slots=request.previous_slots,
|
||||||
use_jieba=False,
|
use_jieba=False,
|
||||||
enable_query_expansion=True)
|
enable_query_expansion=True
|
||||||
|
)
|
||||||
|
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
||||||
@@ -67,41 +122,61 @@ def intent_recognize():
|
|||||||
|
|
||||||
# 提取关键词信息
|
# 提取关键词信息
|
||||||
keywords = result["keywords"]
|
keywords = result["keywords"]
|
||||||
keywords_str = ""
|
keywords_list = []
|
||||||
if keywords and keywords.get("terms"):
|
if keywords and keywords.get("terms"):
|
||||||
term_details = []
|
|
||||||
for term in keywords["terms"]:
|
for term in keywords["terms"]:
|
||||||
term_info = {
|
keywords_list.append({
|
||||||
"名称": term["name"],
|
"名称": term["name"]
|
||||||
}
|
})
|
||||||
term_details.append(term_info)
|
|
||||||
keywords_str = term_details
|
|
||||||
|
|
||||||
# 提取槽位填充信息
|
# 提取槽位填充信息
|
||||||
slot_filling = result.get("slot_filling", {})
|
slot_filling = result.get("slot_filling", {})
|
||||||
|
|
||||||
response_result = {
|
# 构建响应
|
||||||
"source_query": query,
|
response = IntentRecognizeResponse(
|
||||||
"source_query_keys": result["query_keys"],
|
source_query=request.query,
|
||||||
"vertical_classification": classification["vertical_classification"],
|
source_query_keys=result["query_keys"],
|
||||||
"sub_classification": classification["sub_classification"],
|
vertical_classification=classification["vertical_classification"],
|
||||||
"rewrite_query": result["rewrite"]["rewrite"],
|
sub_classification=classification["sub_classification"],
|
||||||
"keywords": keywords_str,
|
rewrite_query=result["rewrite"]["rewrite"],
|
||||||
"has_slot_filling": len(slot_filling)!=0,
|
keywords=keywords_list,
|
||||||
"slot_filling": {
|
has_slot_filling=len(slot_filling) != 0,
|
||||||
"is_complete": slot_filling.get("is_complete", False),
|
slot_filling=SlotFillingResponse(
|
||||||
"missing_slots": slot_filling.get("missing_slots", {}),
|
is_complete=slot_filling.get("is_complete", False),
|
||||||
"filled_data": slot_filling.get("filled_data", {})
|
missing_slots=slot_filling.get("missing_slots", {}),
|
||||||
},
|
filled_data=slot_filling.get("filled_data", {})
|
||||||
"query_expand": result["query_expand"]
|
),
|
||||||
}
|
query_expand=QueryExpandResponse(
|
||||||
return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
all=result["query_expand"]["all"],
|
||||||
|
step_back=result["query_expand"]["step_back"],
|
||||||
|
follow_up=result["query_expand"]["follow_up"],
|
||||||
|
hyde=result["query_expand"]["hyde"],
|
||||||
|
multi_questions=result["query_expand"]["multi_questions"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return response
|
||||||
|
|
||||||
|
except HTTPException as e:
|
||||||
|
raise e
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"意图识别出错: {str(e)}", exc_info=True)
|
logger.error(f"意图识别出错: {str(e)}", exc_info=True)
|
||||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
raise HTTPException(status_code=500, detail=str(e))
|
||||||
|
|
||||||
|
# 添加健康检查端点
|
||||||
|
@app.get("/health", summary="健康检查")
async def health_check():
    # Health-check endpoint: always answers with a constant payload.
    # (A comment is used instead of a docstring so the OpenAPI description
    # of the route is not altered.)
    payload = {"status": "ok"}
    return payload
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# 开发环境使用Flask内置服务器
|
# 使用uvicorn启动服务
|
||||||
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
import uvicorn
|
||||||
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
uvicorn.run(
|
||||||
app.run(host="0.0.0.0", port=8001, threaded=True)
|
"rag2_0.dify.intent_recognition_api:app",
|
||||||
|
host="0.0.0.0",
|
||||||
|
port=8001,
|
||||||
|
reload=False, # 开发环境启用热重载
|
||||||
|
workers=1 # 生产环境可以增加worker数量
|
||||||
|
)
|
||||||
|
# 生产环境可以使用以下命令启动:
|
||||||
|
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10
|
||||||
@@ -225,8 +225,6 @@ class DataProblemSlots(SlotBase):
|
|||||||
class FileExtensionConsultingSlots(SlotBase):
|
class FileExtensionConsultingSlots(SlotBase):
|
||||||
file_extension: str = Field(default="", description="文件后缀名")
|
file_extension: str = Field(default="", description="文件后缀名")
|
||||||
operation_purpose: str = Field(default="", description="操作目的(了解对应软件,对应工程)")
|
operation_purpose: str = Field(default="", description="操作目的(了解对应软件,对应工程)")
|
||||||
file_source: Optional[str] = Field(default="", description="文件来源场景")
|
|
||||||
related_software: Optional[str] = Field(default="", description="相关软件名称")
|
|
||||||
|
|
||||||
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
|
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
|
||||||
"""检查必填槽位是否都存在"""
|
"""检查必填槽位是否都存在"""
|
||||||
@@ -239,7 +237,7 @@ class FileExtensionConsultingSlots(SlotBase):
|
|||||||
|
|
||||||
# 3.2 软件锁类
|
# 3.2 软件锁类
|
||||||
class SoftwareLockSlots(SlotBase):
|
class SoftwareLockSlots(SlotBase):
|
||||||
lock_type: str = Field(default="", description="锁类型(单机锁、网络锁)")
|
lock_type: str = Field(default="单机锁", description="锁类型(单机锁、网络锁)")
|
||||||
operation_purpose: str = Field(default="", description="操作目的(查询锁信息、无法激活、无法注册)")
|
operation_purpose: str = Field(default="", description="操作目的(查询锁信息、无法激活、无法注册)")
|
||||||
lock_number: Optional[str] = Field(default="", description="软件锁编号/注册号")
|
lock_number: Optional[str] = Field(default="", description="软件锁编号/注册号")
|
||||||
|
|
||||||
@@ -259,7 +257,6 @@ class InstallationDownloadSlots(SlotBase):
|
|||||||
file_name: str = Field(default="", description="文件名,与software_name二选一")
|
file_name: str = Field(default="", description="文件名,与software_name二选一")
|
||||||
operation_stage: str = Field(default="", description="操作阶段(下载、安装等)")
|
operation_stage: str = Field(default="", description="操作阶段(下载、安装等)")
|
||||||
os_version: Optional[str] = Field(default="", description="操作系统版本")
|
os_version: Optional[str] = Field(default="", description="操作系统版本")
|
||||||
package_source: Optional[str] = Field(default="", description="安装包来源/版本号")
|
|
||||||
|
|
||||||
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
|
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
|
||||||
"""检查必填槽位是否都存在"""
|
"""检查必填槽位是否都存在"""
|
||||||
|
|||||||
@@ -38,839 +38,6 @@ from .DataModels import (
|
|||||||
from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever
|
from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever
|
||||||
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
|
from rag2_0.tool.ModelTool import XinferenceReRankerModel, OpenAiLLM, SiliconFlowReRankerModel
|
||||||
|
|
||||||
|
|
||||||
class IntentRecognizer:
|
|
||||||
"""
|
|
||||||
意图识别和问题改写类
|
|
||||||
"""
|
|
||||||
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo", vector_index_dir: str = None):
|
|
||||||
"""
|
|
||||||
初始化意图识别器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
api_key: OpenAI API密钥,如果为None则从环境变量获取
|
|
||||||
base_url: OpenAI API基础URL,如果为None则使用默认URL
|
|
||||||
model_name: 要使用的模型名称
|
|
||||||
vector_index_dir: 向量索引目录,如果为None则使用默认目录
|
|
||||||
"""
|
|
||||||
# 初始化LLM
|
|
||||||
llm_params = {
|
|
||||||
"temperature": 0.2, # 降低随机性,使结果更确定
|
|
||||||
"top_p": 0.7,
|
|
||||||
"model": model_name
|
|
||||||
}
|
|
||||||
|
|
||||||
# 如果提供了API密钥,则使用提供的密钥
|
|
||||||
if api_key:
|
|
||||||
llm_params["api_key"] = api_key
|
|
||||||
|
|
||||||
# 如果提供了自定义URL,则使用提供的URL
|
|
||||||
if base_url:
|
|
||||||
llm_params["base_url"] = base_url
|
|
||||||
|
|
||||||
self._llm = OpenAiLLM(**llm_params)
|
|
||||||
|
|
||||||
# 加载suffix关键词
|
|
||||||
self._suffix_keywords = self._load_suffix_keywords()
|
|
||||||
|
|
||||||
# 初始化向量检索器
|
|
||||||
self._noun_retriever = ProfessionalNounRetriever(api_key=api_key, index_dir=vector_index_dir)
|
|
||||||
|
|
||||||
def _load_suffix_keywords(self, filepath: str = None) -> List[str]:
|
|
||||||
"""
|
|
||||||
加载后缀关键词列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
filepath: 后缀关键词文件路径,默认为None使用默认路径
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
后缀关键词列表
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
# 如果未指定路径,使用默认路径
|
|
||||||
if filepath is None:
|
|
||||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
||||||
filepath = os.path.join(current_dir, "..", "..", "data", "nouns", "suffix_keywords.json")
|
|
||||||
|
|
||||||
# 读取JSON文件
|
|
||||||
with open(filepath, "r", encoding="utf-8") as f:
|
|
||||||
suffix_data = json.load(f)
|
|
||||||
|
|
||||||
# 添加额外的固定后缀
|
|
||||||
return suffix_data
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"加载后缀关键词失败: {e}") from e
|
|
||||||
|
|
||||||
def _classify_intent(self, query: str, conversation_context: str = "",
|
|
||||||
chat_history: List[Dict[str, str]] = None,
|
|
||||||
previous_slots: Dict[str, Any] = None) -> Classification:
|
|
||||||
"""
|
|
||||||
对用户输入进行意图分类
|
|
||||||
|
|
||||||
Args:
|
|
||||||
content: 用户输入内容
|
|
||||||
keywords: 匹配到的关键词列表
|
|
||||||
rewrite: 重写的问题
|
|
||||||
Returns:
|
|
||||||
分类结果
|
|
||||||
"""
|
|
||||||
classification_start_time = time.time()
|
|
||||||
classification_parser = PydanticOutputParser(pydantic_object=Classification)
|
|
||||||
formatted_prompt = classification_prompt.format(user_input=query,
|
|
||||||
classification_info=classification_info,
|
|
||||||
output_format=classification_parser.get_format_instructions(),
|
|
||||||
conversation_context=conversation_context,
|
|
||||||
chat_history=json.dumps(chat_history, ensure_ascii=False))
|
|
||||||
# 解析输出
|
|
||||||
try:
|
|
||||||
# 调用LLM
|
|
||||||
response = self._llm.invoke(formatted_prompt, False)
|
|
||||||
|
|
||||||
classification_end_time = time.time()
|
|
||||||
classification_time = classification_end_time - classification_start_time
|
|
||||||
logging.info(f"意图分类耗时统计 - 总耗时: {classification_time:.2f}秒")
|
|
||||||
|
|
||||||
# 尝试直接解析JSON响应
|
|
||||||
parsed_output = classification_parser.parse(response.content.strip())
|
|
||||||
return parsed_output
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"解析分类结果时出错: {e}") from e
|
|
||||||
|
|
||||||
def _tokenize_with_jieba(self, query: str) -> List[str]:
|
|
||||||
"""
|
|
||||||
使用jieba分词器对查询进行分词
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
分词后的词语列表
|
|
||||||
"""
|
|
||||||
# 使用jieba进行分词
|
|
||||||
seg_list = jieba.cut(query, cut_all=False)
|
|
||||||
|
|
||||||
# 过滤掉停用词和标点符号
|
|
||||||
filtered_tokens = []
|
|
||||||
for token in seg_list:
|
|
||||||
# 过滤掉空格和标点符号
|
|
||||||
if token.strip() and not re.match(r'^[^\w\s]$', token):
|
|
||||||
filtered_tokens.append(token)
|
|
||||||
|
|
||||||
return filtered_tokens
|
|
||||||
|
|
||||||
def _extract_keywords_with_llm(self, query: str, use_jieba: bool = False) -> List[Term]:
|
|
||||||
"""
|
|
||||||
使用LLM从用户查询中提取专业关键词
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询
|
|
||||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
提取的术语列表
|
|
||||||
"""
|
|
||||||
# 如果使用jieba分词
|
|
||||||
if use_jieba:
|
|
||||||
# 先使用jieba分词
|
|
||||||
tokens = self._tokenize_with_jieba(query)
|
|
||||||
|
|
||||||
# 构建术语列表
|
|
||||||
terms = []
|
|
||||||
for token in tokens:
|
|
||||||
if len(token) > 1: # 过滤掉单字词
|
|
||||||
terms.append(Term(name=token, synonymous=[], description=""))
|
|
||||||
|
|
||||||
return terms
|
|
||||||
else:
|
|
||||||
# 使用LLM提取关键词
|
|
||||||
# 准备提示词
|
|
||||||
formatted_prompt = extract_nouns_prompt.replace("{content}", query)
|
|
||||||
terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
|
|
||||||
formatted_prompt = formatted_prompt.replace("{output_format}", terms_list_parser.get_format_instructions())
|
|
||||||
|
|
||||||
# 调用LLM
|
|
||||||
response = self._llm.invoke(formatted_prompt, False)
|
|
||||||
|
|
||||||
# 尝试使用Pydantic解析器解析TermList
|
|
||||||
parsed_output = terms_list_parser.parse(response.content)
|
|
||||||
return parsed_output.terms
|
|
||||||
|
|
||||||
|
|
||||||
def _rerank_matched_terms(self, query_key: str, matched_terms: set, top_k: int = 2, rerank_score:float = 0.6) -> List[Term]:
|
|
||||||
"""
|
|
||||||
对召回的专业术语进行重排序,按与用户查询的相关性排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询
|
|
||||||
matched_terms: 匹配到的专业术语集合
|
|
||||||
query_keys: 用户查询中提取的关键词列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
重排序后的专业术语列表
|
|
||||||
"""
|
|
||||||
if not matched_terms:
|
|
||||||
return []
|
|
||||||
|
|
||||||
if len(matched_terms) <= top_k:
|
|
||||||
return list(matched_terms)
|
|
||||||
|
|
||||||
try:
|
|
||||||
# 将每个术语转换为可用于重排序的文本表示
|
|
||||||
# term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) + "|" + "描述:" + term.description for term in matched_terms]
|
|
||||||
term_texts = ["名称:" + term.name + "|" + "同义词:" + ";".join(term.synonymous) for term in matched_terms]
|
|
||||||
|
|
||||||
# 使用重排序模型
|
|
||||||
xinference_reranker = XinferenceReRankerModel()
|
|
||||||
rerank_results = xinference_reranker.rerank(query_key, term_texts, top_k=top_k)
|
|
||||||
|
|
||||||
# 将matched_terms转换为列表以便按索引访问
|
|
||||||
matched_terms_list = list(matched_terms)
|
|
||||||
|
|
||||||
# 根据重排序结果获取排序后的术语列表
|
|
||||||
reranked_terms = [matched_terms_list[result["index"]] for result in rerank_results if result["score"] >= rerank_score]
|
|
||||||
|
|
||||||
return reranked_terms
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"_rerank_matched_terms重排失败:{e}") from e
|
|
||||||
|
|
||||||
def _match_keywords(self, query: str, use_jieba: bool = False) -> Tuple[TermList, List[str]]:
|
|
||||||
"""
|
|
||||||
从用户问题中匹配关键词,结合LLM提取和向量检索
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户问题
|
|
||||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
匹配到的关键词列表
|
|
||||||
"""
|
|
||||||
start_time = time.time()
|
|
||||||
query_keys=[]
|
|
||||||
# 步骤1: 使用LLM提取查询中的关键词
|
|
||||||
try:
|
|
||||||
llm_start_time = time.time()
|
|
||||||
extracted_terms = self._extract_keywords_with_llm(query, use_jieba)
|
|
||||||
for term in extracted_terms:
|
|
||||||
query_keys.append(term.name)
|
|
||||||
llm_end_time = time.time()
|
|
||||||
llm_time = llm_end_time - llm_start_time
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"LLM关键词提取失败: {e}") from e
|
|
||||||
|
|
||||||
matched_terms = [] # 存储匹配到的Term对象
|
|
||||||
# 步骤2: 使用向量检索找到相似的专业名词
|
|
||||||
try:
|
|
||||||
vector_start_time = time.time()
|
|
||||||
# 对matched_terms中的每个关键字进行向量检索
|
|
||||||
for current_key in query_keys:
|
|
||||||
vector_results = self._noun_retriever.query(current_key, top_k=5, use_intersection=False)
|
|
||||||
current_key_terms = set()
|
|
||||||
# 添加向量检索结果
|
|
||||||
for result in vector_results:
|
|
||||||
if isinstance(result.get('synonymous', []), str):
|
|
||||||
result['synonymous'] = result['synonymous'].split(';')
|
|
||||||
term = Term(
|
|
||||||
name=result.get('name'),
|
|
||||||
synonymous=result.get('synonymous', []),
|
|
||||||
description=result.get('description', '')
|
|
||||||
)
|
|
||||||
current_key_terms.add(term)
|
|
||||||
if len(current_key_terms) > 0:
|
|
||||||
reranked_terms = self._rerank_matched_terms(current_key, current_key_terms)
|
|
||||||
matched_terms.extend(reranked_terms)
|
|
||||||
vector_end_time = time.time()
|
|
||||||
vector_time = vector_end_time - vector_start_time
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"向量检索关键词时出错: {e}") from e
|
|
||||||
|
|
||||||
# 提取所有Term对象的名称并排序
|
|
||||||
# 将set类型的matched_terms转换为TermList类型
|
|
||||||
term_list = TermList(terms=list(matched_terms))
|
|
||||||
end_time = time.time()
|
|
||||||
total_time = end_time - start_time
|
|
||||||
|
|
||||||
# 输出整合的时间日志
|
|
||||||
logging.info(f"关键词匹配耗时统计 - 总耗时: {total_time:.2f}秒, 问题关键词提取: {llm_time:.2f}秒, 向量检索+重排序: {vector_time:.2f}秒")
|
|
||||||
|
|
||||||
return term_list, query_keys
|
|
||||||
|
|
||||||
def _rewrite_query(self, query: str, keywords: TermList, query_keys:List[str], chat_history: List[Dict[str, str]] = None, context: str = "") -> QueryRewrite:
|
|
||||||
"""
|
|
||||||
对用户问题进行改写
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户原始问题
|
|
||||||
keywords: 匹配到的关键词列表
|
|
||||||
query_keys: 用户查询中提取的关键词列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
改写结果
|
|
||||||
"""
|
|
||||||
|
|
||||||
rewrite_start_time = time.time()
|
|
||||||
# 准备问题改写提示
|
|
||||||
terms_dict = [term.model_dump(exclude={"description"}) for term in keywords.terms]
|
|
||||||
# terms_dict = [term.model_dump() for term in keywords.terms]
|
|
||||||
keywords_str = json.dumps(terms_dict, ensure_ascii=False)
|
|
||||||
query_rewrite_parser = PydanticOutputParser(pydantic_object=QueryRewrite)
|
|
||||||
# formatted_prompt = query_rewrite_prompt.format(query=query,
|
|
||||||
# output_format=query_rewrite_parser.get_format_instructions(),
|
|
||||||
# keywords=keywords_str)
|
|
||||||
formatted_prompt = query_rewrite_prompt_pro.format(query=query,
|
|
||||||
output_format=query_rewrite_parser.get_format_instructions(),
|
|
||||||
keywords=keywords_str,
|
|
||||||
chat_history=chat_history,
|
|
||||||
context=context)
|
|
||||||
# 解析输出
|
|
||||||
try:
|
|
||||||
# 调用LLM
|
|
||||||
response = self._llm.invoke(formatted_prompt, False)
|
|
||||||
|
|
||||||
# 尝试直接解析JSON响应
|
|
||||||
parsed_output = query_rewrite_parser.parse(response.content)
|
|
||||||
rewrite_end_time = time.time()
|
|
||||||
rewrite_time = rewrite_end_time - rewrite_start_time
|
|
||||||
logging.info(f"问题改写耗时统计 - 总耗时: {rewrite_time:.2f}秒")
|
|
||||||
return parsed_output
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"解析问题改写结果时出错: {e}") from e
|
|
||||||
|
|
||||||
def _judge_define_suffix(self, input_str: str) -> Tuple[bool, List[str]]:
|
|
||||||
"""
|
|
||||||
判断输入字符串是否包含定义的后缀,并返回所有匹配到的后缀名列表
|
|
||||||
|
|
||||||
Args:
|
|
||||||
input_str: 输入字符串
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[bool, List[str]]: (是否包含定义的后缀, 匹配到的后缀名列表)
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 构建正则表达式模式,匹配大小写不敏感且前面可能带有.
|
|
||||||
pattern = r'(?:\.?)(' + '|'.join(re.escape(field.get('name')) for field in self._suffix_keywords) + r')'
|
|
||||||
|
|
||||||
# 使用 re.IGNORECASE 标志来忽略大小写,findall找到所有匹配
|
|
||||||
matches = re.finditer(pattern, input_str, re.IGNORECASE)
|
|
||||||
matched_suffixes = [match.group(1) for match in matches]
|
|
||||||
|
|
||||||
return bool(matched_suffixes), matched_suffixes
|
|
||||||
|
|
||||||
def process_query(self, query: str, conversation_context: str = "",
|
|
||||||
chat_history: List[Dict[str, str]] = None,
|
|
||||||
previous_slots: Dict[str, Any] = None,
|
|
||||||
use_jieba: bool = False,
|
|
||||||
enable_query_expansion: bool = False) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
处理用户问题的完整流程
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户原始问题
|
|
||||||
conversation_context: 会话背景信息
|
|
||||||
chat_history: 历史对话记录,格式为[{"user": "content"}, {"assistant": "content"}]
|
|
||||||
previous_slots: 历史槽位信息
|
|
||||||
use_jieba: 是否使用jieba分词辅助提取关键词
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
包含分类、关键词、改写和槽位填充结果的字典
|
|
||||||
"""
|
|
||||||
# 是否是扩展名
|
|
||||||
# is_suffix, matched_suffixes = self._judge_define_suffix(query)
|
|
||||||
# if is_suffix:
|
|
||||||
# # 将所有匹配到的后缀名作为Term添加到结果中
|
|
||||||
# suffix_terms = []
|
|
||||||
# for suffix in matched_suffixes:
|
|
||||||
# term_dict = next((item for item in self._suffix_keywords if item['name'].lower() == suffix.lower()), None)
|
|
||||||
# if term_dict:
|
|
||||||
# suffix_term = Term(
|
|
||||||
# name=term_dict.get('name'),
|
|
||||||
# synonymous=term_dict.get('synonymous', []),
|
|
||||||
# description=json.dumps(term_dict.get('description', ''), ensure_ascii=False)
|
|
||||||
# )
|
|
||||||
# suffix_terms.append(suffix_term)
|
|
||||||
|
|
||||||
# return Classification(vertical_classification="安装下载", sub_classification="查询"), TermList(terms=suffix_terms), QueryRewrite(rewrite=query), matched_suffixes
|
|
||||||
|
|
||||||
if chat_history is None:
|
|
||||||
chat_history = []
|
|
||||||
if previous_slots is None:
|
|
||||||
previous_slots = {}
|
|
||||||
|
|
||||||
# 步骤: 并行执行提问扩展
|
|
||||||
if enable_query_expansion:
|
|
||||||
# 创建线程和结果容器
|
|
||||||
threads_and_results = [
|
|
||||||
# 5.1: 后退提示
|
|
||||||
self._run_in_thread(self._generate_step_back_prompt, args=(query, chat_history, conversation_context)),
|
|
||||||
|
|
||||||
# 5.2: Follow Up Questions
|
|
||||||
self._run_in_thread(self._generate_follow_up_questions, args=(query, chat_history, conversation_context)),
|
|
||||||
|
|
||||||
# 5.3: HyDE
|
|
||||||
self._run_in_thread(self._generate_hypothetical_document, args=(query, chat_history, conversation_context)),
|
|
||||||
|
|
||||||
# 5.4: 多问题查询
|
|
||||||
self._run_in_thread(self._generate_multi_questions, args=(query, chat_history, conversation_context))
|
|
||||||
]
|
|
||||||
|
|
||||||
# 步骤1: 匹配关键词
|
|
||||||
keywords_terms, query_keys = self._match_keywords(query, use_jieba)
|
|
||||||
|
|
||||||
# 步骤2: 问题改写
|
|
||||||
rewrite = self._rewrite_query(
|
|
||||||
query=query,
|
|
||||||
keywords=keywords_terms,
|
|
||||||
query_keys=query_keys,
|
|
||||||
chat_history=chat_history,
|
|
||||||
context=conversation_context
|
|
||||||
)
|
|
||||||
|
|
||||||
# 步骤3: 进行意图识别和槽位填充
|
|
||||||
# result = self._process_intent_and_slot(rewrite.rewrite, conversation_context, chat_history, previous_slots)
|
|
||||||
# result.update({"keywords": keywords_terms.model_dump(),
|
|
||||||
# "rewrite": rewrite.model_dump(),
|
|
||||||
# "query_keys": query_keys})
|
|
||||||
# return result
|
|
||||||
# 步骤3: 进行意图分类
|
|
||||||
classification = self._classify_intent(rewrite.rewrite, conversation_context, chat_history, previous_slots)
|
|
||||||
|
|
||||||
# 步骤4: 进行槽位填充
|
|
||||||
# 如果是有效分类,进行槽位填充
|
|
||||||
slot_filling_result = {}
|
|
||||||
if classification.vertical_classification not in ["其他", "闲聊"] and classification.sub_classification not in ["其他", "闲聊"]:
|
|
||||||
slot_filling_result = self._fill_slots(rewrite.rewrite, classification, conversation_context, chat_history, previous_slots)
|
|
||||||
|
|
||||||
if not enable_query_expansion:
|
|
||||||
return {
|
|
||||||
"classification": classification.model_dump(),
|
|
||||||
"keywords": keywords_terms.model_dump(),
|
|
||||||
"rewrite": rewrite.model_dump(),
|
|
||||||
"query_keys": query_keys,
|
|
||||||
"slot_filling": slot_filling_result
|
|
||||||
}
|
|
||||||
|
|
||||||
# 等待所有线程完成
|
|
||||||
start_time = time.time()
|
|
||||||
for thread, _ in threads_and_results:
|
|
||||||
thread.join()
|
|
||||||
end_time = time.time()
|
|
||||||
logging.info(f"问题扩展环节耗时统计 - 总耗时: {end_time - start_time:.2f}秒")
|
|
||||||
|
|
||||||
# 收集结果
|
|
||||||
step_back_result = threads_and_results[0][1][0] if threads_and_results[0][1] else StepBackPrompt(original_query=query, step_back_query=query)
|
|
||||||
follow_up_result = threads_and_results[1][1][0] if threads_and_results[1][1] else FollowUpQuestions(original_query=query, follow_up_query=query)
|
|
||||||
hyde_result = threads_and_results[2][1][0] if threads_and_results[2][1] else HypotheticalDocument(original_query=query, hypothetical_answer="")
|
|
||||||
multi_questions_result = threads_and_results[3][1][0] if threads_and_results[3][1] else MultiQuestions(original_query=query, sub_questions=[query])
|
|
||||||
all_questions=multi_questions_result.sub_questions
|
|
||||||
all_questions.append(query)
|
|
||||||
all_questions.append(step_back_result.step_back_query)
|
|
||||||
all_questions.append(follow_up_result.follow_up_query)
|
|
||||||
all_questions.append(hyde_result.hypothetical_answer)
|
|
||||||
all_questions = list(set(all_questions))
|
|
||||||
|
|
||||||
query_expand={"all":all_questions,
|
|
||||||
"step_back":step_back_result.model_dump(),
|
|
||||||
"follow_up":follow_up_result.model_dump(),
|
|
||||||
"hyde":hyde_result.model_dump(),
|
|
||||||
"multi_questions":multi_questions_result.model_dump()}
|
|
||||||
# 返回所有结果
|
|
||||||
return {
|
|
||||||
"classification": classification.model_dump(),
|
|
||||||
"keywords": keywords_terms.model_dump(),
|
|
||||||
"rewrite": rewrite.model_dump(),
|
|
||||||
"query_keys": query_keys,
|
|
||||||
"slot_filling": slot_filling_result,
|
|
||||||
"query_expand": query_expand
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _fill_slots(self, query: str, classification: Classification, conversation_context: str = "",
|
|
||||||
chat_history: List[Dict[str, str]] = None,
|
|
||||||
previous_slots: Dict[str, Any] = None,) -> Dict[str, Any]:
|
|
||||||
"""
|
|
||||||
根据分类结果对问题进行槽位填充
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户原始问题
|
|
||||||
classification: 意图分类结果
|
|
||||||
keywords: 匹配的关键词列表
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
填充后的槽位数据模型
|
|
||||||
"""
|
|
||||||
# 根据分类结果选择对应的数据模型
|
|
||||||
slot_model = self._get_slot_model(classification)
|
|
||||||
if not slot_model:
|
|
||||||
raise RuntimeError("未找到匹配的槽位模型")
|
|
||||||
fill_slots_start_time = time.time()
|
|
||||||
# 使用LLM进行槽位填充
|
|
||||||
filled_slots = self._fill_slots_with_llm(query, classification, slot_model, conversation_context, chat_history, previous_slots)
|
|
||||||
fill_slots_end_time = time.time()
|
|
||||||
fill_slots_time = fill_slots_end_time - fill_slots_start_time
|
|
||||||
logging.info(f"槽位填充耗时统计 - 总耗时: {fill_slots_time:.2f}秒")
|
|
||||||
|
|
||||||
# 检查必填槽位是否都已填充
|
|
||||||
is_complete, missing_slots = filled_slots.check_required_slots()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"is_complete": is_complete,
|
|
||||||
"missing_slots": missing_slots,
|
|
||||||
"filled_data": filled_slots.model_dump()
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_slot_model(self, classification: Classification) -> Optional[type]:
|
|
||||||
"""
|
|
||||||
根据分类结果获取对应的槽位模型类,用于统一提示词处理
|
|
||||||
|
|
||||||
Args:
|
|
||||||
classification: 意图分类结果
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
对应的槽位模型类
|
|
||||||
"""
|
|
||||||
# 软件问题
|
|
||||||
if classification.vertical_classification == "软件问题":
|
|
||||||
if classification.sub_classification == "软件功能":
|
|
||||||
return SoftwareFunctionSlots
|
|
||||||
elif classification.sub_classification == "故障排查":
|
|
||||||
return SoftwareTroubleShootingSlots
|
|
||||||
|
|
||||||
# 业务问题
|
|
||||||
elif classification.vertical_classification == "业务问题":
|
|
||||||
if classification.sub_classification == "专业咨询":
|
|
||||||
return ProfessionalConsultingSlots
|
|
||||||
elif classification.sub_classification == "数据问题":
|
|
||||||
return DataProblemSlots
|
|
||||||
|
|
||||||
# 安装下载注册
|
|
||||||
elif classification.vertical_classification == "安装下载注册":
|
|
||||||
if classification.sub_classification == "后缀名咨询":
|
|
||||||
return FileExtensionConsultingSlots
|
|
||||||
elif classification.sub_classification == "软件锁类":
|
|
||||||
return SoftwareLockSlots
|
|
||||||
elif classification.sub_classification == "安装下载类":
|
|
||||||
return InstallationDownloadSlots
|
|
||||||
elif classification.sub_classification == "问题排查类":
|
|
||||||
return ProblemDiagnosisSlots
|
|
||||||
|
|
||||||
# 其他
|
|
||||||
elif classification.vertical_classification == "其他":
|
|
||||||
return OtherSlots
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
def _fill_slots_with_llm(self, query: str,
                         classification: Classification,
                         slot_model_class: type,
                         conversation_context: str = "",
                         chat_history: List[Dict[str, str]] = None,
                         previous_slots: Dict[str, Any] = None) -> Any:
    """
    Fill slots by prompting the LLM and parsing its reply into the slot model.

    Args:
        query: The user's original question.
        classification: Intent classification result; its vertical and sub
            classification are inserted into the prompt.
        slot_model_class: Slot model class the reply is parsed into.
        conversation_context: Background information about the conversation.
        chat_history: Previous dialogue turns; serialised to JSON for the
            prompt (``None`` is serialised as ``null``).
        previous_slots: Slot values from earlier turns; serialised to JSON
            for the prompt (``None`` is serialised as ``null``).

    Returns:
        The parsed slot model instance. If the LLM call or the parsing raises,
        the failure is logged and a default-constructed ``slot_model_class()``
        is returned instead.
    """
    # Build the prompt, including the parser's format instructions.
    slot_parser = PydanticOutputParser(pydantic_object=slot_model_class)

    formatted_prompt = slot_filling_prompt.format(
        query=query,
        vertical_classification=classification.vertical_classification,
        sub_classification=classification.sub_classification,
        output_format=slot_parser.get_format_instructions(),
        conversation_context=conversation_context,
        chat_history=json.dumps(chat_history, ensure_ascii=False),
        previous_slots=json.dumps(previous_slots, ensure_ascii=False),
    )
    try:
        # NOTE(review): the meaning of the second positional argument
        # (False) is defined by the LLM wrapper, not visible here.
        response = self._llm.invoke(formatted_prompt, False)
        return slot_parser.parse(response.content)
    except Exception as e:
        # Best-effort fallback: keep returning an empty model, but record the
        # failure so it is no longer silently swallowed.
        logging.error(f"槽位填充失败: {e}", exc_info=True)
        return slot_model_class()
|
|
||||||
|
|
||||||
def _generate_step_back_prompt(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> StepBackPrompt:
    """
    Ask the LLM for a step-back version of the query.

    Args:
        query: The user's original question.
        chat_history: Previous dialogue turns; an empty or missing history is
            sent to the prompt as "[]".
        conversation_context: Background information about the conversation.

    Returns:
        The parsed StepBackPrompt. If the LLM call or the parsing raises, the
        error is logged and a StepBackPrompt whose step-back query equals the
        original query is returned.
    """
    started_at = time.time()

    # Build the prompt.
    step_back_parser = PydanticOutputParser(pydantic_object=StepBackPrompt)
    history_json = json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]"
    formatted_prompt = step_back_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=step_back_parser.get_format_instructions(),
    )

    try:
        # Call the LLM and parse its reply.
        llm_reply = self._llm.invoke(formatted_prompt, False)
        parsed_output = step_back_parser.parse(llm_reply.content)

        step_back_time = time.time() - started_at
        logging.debug(f"后退提示生成耗时统计 - 总耗时: {step_back_time:.2f}秒")
        return parsed_output
    except Exception as e:
        # Fall back to the original query as the step-back query.
        logging.error(f"后退提示生成失败: {e}", exc_info=True)
        return StepBackPrompt(original_query=query, step_back_query=query)
|
|
||||||
|
|
||||||
def _generate_follow_up_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> FollowUpQuestions:
    """
    Ask the LLM for a follow-up question derived from the query.

    Args:
        query: The user's original question.
        chat_history: Previous dialogue turns; an empty or missing history is
            sent to the prompt as "[]".
        conversation_context: Background information about the conversation.

    Returns:
        The parsed FollowUpQuestions. If the LLM call or the parsing raises,
        the error is logged and a FollowUpQuestions whose follow-up query
        equals the original query is returned.
    """
    started_at = time.time()

    # Build the prompt.
    follow_up_parser = PydanticOutputParser(pydantic_object=FollowUpQuestions)
    history_json = json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]"
    formatted_prompt = follow_up_questions_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=follow_up_parser.get_format_instructions(),
    )

    try:
        # Call the LLM and parse its reply.
        llm_reply = self._llm.invoke(formatted_prompt, False)
        parsed_output = follow_up_parser.parse(llm_reply.content)

        follow_up_time = time.time() - started_at
        logging.debug(f"后续问题生成耗时统计 - 总耗时: {follow_up_time:.2f}秒")
        return parsed_output
    except Exception as e:
        # Fall back to the original query as the follow-up question.
        logging.error(f"后续问题生成失败: {e}", exc_info=True)
        return FollowUpQuestions(original_query=query, follow_up_query=query)
|
|
||||||
|
|
||||||
def _generate_hypothetical_document(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> HypotheticalDocument:
    """
    Ask the LLM for a hypothetical answer document for the query.

    Args:
        query: The user's original question.
        chat_history: Previous dialogue turns; an empty or missing history is
            sent to the prompt as "[]".
        conversation_context: Background information about the conversation.

    Returns:
        The parsed HypotheticalDocument. If the LLM call or the parsing
        raises, the error is logged and a HypotheticalDocument with an empty
        hypothetical answer is returned.
    """
    started_at = time.time()

    # Build the prompt.
    hyde_parser = PydanticOutputParser(pydantic_object=HypotheticalDocument)
    history_json = json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]"
    formatted_prompt = hyde_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=hyde_parser.get_format_instructions(),
    )

    try:
        # Call the LLM and parse its reply.
        llm_reply = self._llm.invoke(formatted_prompt, False)
        parsed_output = hyde_parser.parse(llm_reply.content)

        hyde_time = time.time() - started_at
        logging.debug(f"假设性文档生成耗时统计 - 总耗时: {hyde_time:.2f}秒")
        return parsed_output
    except Exception as e:
        # Fall back to an empty hypothetical answer.
        logging.error(f"假设性文档生成失败: {e}", exc_info=True)
        return HypotheticalDocument(original_query=query, hypothetical_answer="")
|
|
||||||
|
|
||||||
def _generate_multi_questions(self, query: str, chat_history: List[Dict[str, str]] = None, conversation_context: str = "") -> MultiQuestions:
    """
    Ask the LLM to restate the query as several sub-questions.

    Args:
        query: The user's original question.
        chat_history: Previous dialogue turns; an empty or missing history is
            sent to the prompt as "[]".
        conversation_context: Background information about the conversation.

    Returns:
        The parsed MultiQuestions. If the LLM call or the parsing raises, the
        error is logged and a MultiQuestions whose only sub-question is the
        original query is returned.
    """
    started_at = time.time()

    # Build the prompt.
    multi_questions_parser = PydanticOutputParser(pydantic_object=MultiQuestions)
    history_json = json.dumps(chat_history, ensure_ascii=False) if chat_history else "[]"
    formatted_prompt = multi_questions_prompt.format(
        query=query,
        chat_history=history_json,
        conversation_context=conversation_context,
        output_format=multi_questions_parser.get_format_instructions(),
    )

    try:
        # Call the LLM and parse its reply.
        llm_reply = self._llm.invoke(formatted_prompt, False)
        parsed_output = multi_questions_parser.parse(llm_reply.content)

        multi_questions_time = time.time() - started_at
        logging.debug(f"多角度问题生成耗时统计 - 总耗时: {multi_questions_time:.2f}秒")
        return parsed_output
    except Exception as e:
        # Fall back to the original query as the only sub-question.
        logging.error(f"多角度问题生成失败: {e}",exc_info=True)
        return MultiQuestions(original_query=query, sub_questions=[query])
|
|
||||||
|
|
||||||
def _run_in_thread(self, func, args=(), kwargs={}):
|
|
||||||
"""
|
|
||||||
在线程中执行函数并返回结果
|
|
||||||
|
|
||||||
Args:
|
|
||||||
func: 要执行的函数
|
|
||||||
args: 函数的位置参数
|
|
||||||
kwargs: 函数的关键字参数
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(thread, result_container): 线程对象和存放结果的容器
|
|
||||||
"""
|
|
||||||
result_container = []
|
|
||||||
|
|
||||||
def thread_target():
|
|
||||||
try:
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
result_container.append(result)
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"线程执行函数 {func.__name__} 时出错: {e}", exc_info=True)
|
|
||||||
result_container.append(None)
|
|
||||||
|
|
||||||
thread = threading.Thread(target=thread_target)
|
|
||||||
thread.start()
|
|
||||||
return thread, result_container
|
|
||||||
|
|
||||||
|
|
||||||
def _process_intent_and_slot(self, user_input: str, conversation_context: str = "",
                             chat_history: List[Dict[str, str]] = None,
                             previous_slots: Dict[str, Any] = None) -> Dict[str, Any]:
    """
    Run intent recognition and slot filling with one combined prompt.

    Args:
        user_input: The current user input.
        conversation_context: Background information about the conversation.
        chat_history: Previous dialogue turns, in the form
            [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}].
            ``None`` is treated as an empty list.
        previous_slots: Slot values from earlier turns. ``None`` is treated
            as an empty dict.

    Returns:
        A dict with ``classification`` (the dumped classification) and
        ``slot_filling`` (``is_complete``, ``missing_slots``, ``filled_data``).

    Raises:
        RuntimeError: If parsing the LLM reply or any fallback step fails; the
            original exception is chained as the cause. Exceptions raised by
            the LLM call itself are not wrapped.
    """
    # Normalise the optional inputs.
    chat_history = [] if chat_history is None else chat_history
    previous_slots = {} if previous_slots is None else previous_slots

    # Build the slot-mapping document that goes into the prompt.
    slot_mapping_doc = generate_slot_mapping_doc()

    # Build the prompt.
    parser = PydanticOutputParser(pydantic_object=IntentAndSlotResult)
    formatted_prompt = intent_and_slot_prompt.format(
        conversation_context=conversation_context,
        chat_history=json.dumps(chat_history, ensure_ascii=False),
        previous_slots=json.dumps(previous_slots, ensure_ascii=False),
        user_input=user_input,
        slot_mapping_doc=slot_mapping_doc,
        output_format=parser.get_format_instructions(),
        classification_info=classification_info,
    )

    # Call the LLM and time the call.
    llm_start_time = time.time()
    response = self._llm.invoke(formatted_prompt + output_example, False)
    llm_time = time.time() - llm_start_time

    try:
        # Parse the LLM reply into the combined result model.
        parsed = parser.parse(response.content)
        classification = parsed.classification
        slot_filling = parsed.slots
        is_complete, missing_slots = slot_filling.check_required_slots()
        expected_slot_model = self._get_slot_model(classification)

        if expected_slot_model is None:
            # Rare: the LLM returned a vertical/sub classification pair that
            # maps to no slot model. Classify again and refill the slots.
            classification = self._classify_intent(user_input, conversation_context, chat_history, previous_slots)
            refilled = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
            fallback = {
                "classification": classification.model_dump(),
                "slot_filling": refilled,
            }
            logging.warning("重新分类与槽点填充")
            return fallback

        if expected_slot_model.__name__ != type(slot_filling).__name__:
            # The returned slots do not match the classification; keep the
            # classification and refill only the slots.
            refilled = self._fill_slots(user_input, classification, conversation_context, chat_history, previous_slots)
            fallback = {
                "classification": classification.model_dump(),
                "slot_filling": refilled,
            }
            logging.warning("重新填充槽点")
            return fallback

        logging.info(f"意图识别+槽位LLM调用耗时: {llm_time:.2f}秒")

        # Classification and slots agree: build the final result.
        return {
            "classification": classification.model_dump(),
            "slot_filling": {
                "is_complete": is_complete,
                "missing_slots": missing_slots,
                "filled_data": slot_filling.model_dump(),
            },
        }
    except Exception as e:
        raise RuntimeError(f"process_intent_and_slot error:{e}") from e
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AsyncIntentRecognizer:
|
class AsyncIntentRecognizer:
|
||||||
"""
|
"""
|
||||||
异步意图识别和问题改写类
|
异步意图识别和问题改写类
|
||||||
@@ -976,7 +143,7 @@ class AsyncIntentRecognizer:
|
|||||||
formatted_prompt = classification_prompt.format(user_input=query,
|
formatted_prompt = classification_prompt.format(user_input=query,
|
||||||
classification_info=classification_info,
|
classification_info=classification_info,
|
||||||
output_format=classification_parser.get_format_instructions(),
|
output_format=classification_parser.get_format_instructions(),
|
||||||
conversation_context=conversation_context,
|
# conversation_context=conversation_context,
|
||||||
chat_history=json.dumps(chat_history, ensure_ascii=False))
|
chat_history=json.dumps(chat_history, ensure_ascii=False))
|
||||||
# 解析输出
|
# 解析输出
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -56,12 +56,9 @@ classification_info="""【垂直领域分类】:
|
|||||||
1. 其他"""
|
1. 其他"""
|
||||||
|
|
||||||
classification_prompt="""
|
classification_prompt="""
|
||||||
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容,将其归类为以下垂直领域之一:
|
用户正在使用电力造价软件或想询问电力造价领域相关知识,你需要根据用户的输入内容集合历史对话(如果存在),将其归类为以下垂直领域之一:
|
||||||
{classification_info}
|
{classification_info}
|
||||||
|
|
||||||
## 【会话背景信息】
|
|
||||||
{conversation_context}
|
|
||||||
|
|
||||||
## 【历史对话记录】
|
## 【历史对话记录】
|
||||||
{chat_history}
|
{chat_history}
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
|
|
||||||
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
|
from .ProfessionalNounVector import ProfessionalNounVectorizer, ProfessionalNounRetriever
|
||||||
from .IntentRecognition import IntentRecognizer, AsyncIntentRecognizer
|
from .IntentRecognition import AsyncIntentRecognizer
|
||||||
from .DataModels import Term, TermList, Classification, QueryRewrite
|
from .DataModels import Term, TermList, Classification, QueryRewrite
|
||||||
|
|||||||
Reference in New Issue
Block a user