优化意图识别API,移除同步意图识别器,改为使用异步意图识别器,更新相关逻辑以支持异步处理,增强错误处理和日志记录,同时更新请求和响应模型以适应新的API结构。
This commit is contained in:
@@ -186,7 +186,6 @@ class DifyComparisonTester:
|
||||
要求
|
||||
1、分析待评估的回答与标准答案的匹配程度(包括内容、步骤、主体等)
|
||||
2、如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
|
||||
3、只要大体描述一致,即使缺失了一些步骤,也应判定为"正确"。
|
||||
3、如果待评估的回答存在明显的错误信息,应判定为"错误"。
|
||||
4、请严格按json格式输出:
|
||||
{{
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from gevent import monkey
|
||||
monkey.patch_all()
|
||||
|
||||
from flask import Flask, request, Response
|
||||
import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import json
|
||||
import time
|
||||
from gevent.lock import RLock
|
||||
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
# 加载环境变量
|
||||
@@ -16,7 +16,7 @@ load_dotenv()
|
||||
|
||||
import sys
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
||||
|
||||
|
||||
logging.basicConfig(
|
||||
@@ -31,32 +31,87 @@ logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
_instance = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
|
||||
@app.route('/intent_recognize', methods=['POST'])
|
||||
def intent_recognize():
|
||||
|
||||
# 定义请求模型
|
||||
class IntentRecognizeRequest(BaseModel):
|
||||
query: str
|
||||
conversation_context: str = ""
|
||||
chat_history: Optional[List] = None
|
||||
previous_slots: str | Dict = None
|
||||
|
||||
# 定义槽位填充响应模型
|
||||
class SlotFillingResponse(BaseModel):
|
||||
is_complete: bool = False
|
||||
missing_slots: Dict[str, Any] = Field(default_factory=dict)
|
||||
filled_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class QueryExpandResponse(BaseModel):
|
||||
all: List[str] = Field(default_factory=list)
|
||||
step_back: Dict[str, Any] = Field(default_factory=Dict)
|
||||
follow_up: Dict[str, Any] = Field(default_factory=Dict)
|
||||
hyde: Dict[str, Any] = Field(default_factory=Dict)
|
||||
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
|
||||
|
||||
# 定义响应模型
|
||||
class IntentRecognizeResponse(BaseModel):
|
||||
source_query: str
|
||||
source_query_keys: List[str]
|
||||
vertical_classification: str
|
||||
sub_classification: str
|
||||
rewrite_query: str
|
||||
keywords: List[Dict[str, str]] = Field(default_factory=list)
|
||||
has_slot_filling: bool = False
|
||||
slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
|
||||
query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
|
||||
|
||||
# 创建FastAPI应用
|
||||
app = FastAPI(
|
||||
title="意图识别服务",
|
||||
description="基于LLM的意图识别和问题改写服务",
|
||||
version="2.0"
|
||||
)
|
||||
|
||||
# 添加CORS中间件
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# 全局变量存储AsyncIntentRecognizer实例
|
||||
_instance = None
|
||||
|
||||
# 应用启动事件
|
||||
@app.on_event("startup")
|
||||
async def startup_event():
|
||||
global _instance
|
||||
# 初始化AsyncIntentRecognizer实例
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
_instance = await AsyncIntentRecognizer.create(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
logger.info("AsyncIntentRecognizer初始化完成")
|
||||
|
||||
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
|
||||
async def intent_recognize(request: IntentRecognizeRequest):
|
||||
try:
|
||||
data = request.get_json(force=True)
|
||||
query = data.get('query')
|
||||
conversation_context = data.get('conversation_context', "")
|
||||
chat_history = data.get('chat_history', None)
|
||||
previous_slots = data.get('previous_slots', None)
|
||||
if not query:
|
||||
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
|
||||
if not request.query:
|
||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
result = _instance.process_query(query=query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
use_jieba=False,
|
||||
enable_query_expansion=True)
|
||||
result = await _instance.process_query_async(
|
||||
query=request.query,
|
||||
conversation_context=request.conversation_context,
|
||||
chat_history=request.chat_history,
|
||||
previous_slots=request.previous_slots,
|
||||
use_jieba=False,
|
||||
enable_query_expansion=True
|
||||
)
|
||||
|
||||
end_time = time.time()
|
||||
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
||||
@@ -67,41 +122,61 @@ def intent_recognize():
|
||||
|
||||
# 提取关键词信息
|
||||
keywords = result["keywords"]
|
||||
keywords_str = ""
|
||||
keywords_list = []
|
||||
if keywords and keywords.get("terms"):
|
||||
term_details = []
|
||||
for term in keywords["terms"]:
|
||||
term_info = {
|
||||
"名称": term["name"],
|
||||
}
|
||||
term_details.append(term_info)
|
||||
keywords_str = term_details
|
||||
keywords_list.append({
|
||||
"名称": term["name"]
|
||||
})
|
||||
|
||||
# 提取槽位填充信息
|
||||
slot_filling = result.get("slot_filling", {})
|
||||
|
||||
response_result = {
|
||||
"source_query": query,
|
||||
"source_query_keys": result["query_keys"],
|
||||
"vertical_classification": classification["vertical_classification"],
|
||||
"sub_classification": classification["sub_classification"],
|
||||
"rewrite_query": result["rewrite"]["rewrite"],
|
||||
"keywords": keywords_str,
|
||||
"has_slot_filling": len(slot_filling)!=0,
|
||||
"slot_filling": {
|
||||
"is_complete": slot_filling.get("is_complete", False),
|
||||
"missing_slots": slot_filling.get("missing_slots", {}),
|
||||
"filled_data": slot_filling.get("filled_data", {})
|
||||
},
|
||||
"query_expand": result["query_expand"]
|
||||
}
|
||||
return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||
# 构建响应
|
||||
response = IntentRecognizeResponse(
|
||||
source_query=request.query,
|
||||
source_query_keys=result["query_keys"],
|
||||
vertical_classification=classification["vertical_classification"],
|
||||
sub_classification=classification["sub_classification"],
|
||||
rewrite_query=result["rewrite"]["rewrite"],
|
||||
keywords=keywords_list,
|
||||
has_slot_filling=len(slot_filling) != 0,
|
||||
slot_filling=SlotFillingResponse(
|
||||
is_complete=slot_filling.get("is_complete", False),
|
||||
missing_slots=slot_filling.get("missing_slots", {}),
|
||||
filled_data=slot_filling.get("filled_data", {})
|
||||
),
|
||||
query_expand=QueryExpandResponse(
|
||||
all=result["query_expand"]["all"],
|
||||
step_back=result["query_expand"]["step_back"],
|
||||
follow_up=result["query_expand"]["follow_up"],
|
||||
hyde=result["query_expand"]["hyde"],
|
||||
multi_questions=result["query_expand"]["multi_questions"]
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except HTTPException as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"意图识别出错: {str(e)}",exc_info=True)
|
||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||
logger.error(f"意图识别出错: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
# 添加健康检查端点
|
||||
@app.get("/health", summary="健康检查")
|
||||
async def health_check():
|
||||
return {"status": "ok"}
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 开发环境使用Flask内置服务器
|
||||
# 生产环境使用gunicorn支持高并发 uv run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
||||
# uv run gunicorn -w 10 -k gevent --preload -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
||||
app.run(host="0.0.0.0", port=8001, threaded=True)
|
||||
# 使用uvicorn启动服务
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"rag2_0.dify.intent_recognition_api:app",
|
||||
host="0.0.0.0",
|
||||
port=8001,
|
||||
reload=False, # 开发环境启用热重载
|
||||
workers=1 # 生产环境可以增加worker数量
|
||||
)
|
||||
# 生产环境可以使用以下命令启动:
|
||||
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10
|
||||
Reference in New Issue
Block a user