# File: QueryRewrite/rag2_0/dify/intent_recognition_api.py
# (Repository viewer metadata: 182 lines, 6.0 KiB, Python, executable file)

import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv
import json
import time
import datetime
import logging
# 加载环境变量
load_dotenv()
import sys
sys.path.append(os.getcwd())
from rag2_0.intent_recognition import AsyncIntentRecognizer
# Root logging setup: INFO level, timestamped records, one stream handler.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
_stream_handler = logging.StreamHandler()
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, handlers=[_stream_handler])

# Raise the threshold of the 'httpx' and 'openai' loggers to WARNING.
for _noisy_logger_name in ('httpx', 'openai'):
    logging.getLogger(_noisy_logger_name).setLevel(logging.WARNING)

# Module-level logger used by the rest of this file.
logger = logging.getLogger(__name__)
# Request model for the /intent_recognize endpoint.
# (Documented with comments rather than a class docstring so the generated
# schema description stays unchanged.)
class IntentRecognizeRequest(BaseModel):
    # The user query to analyse; the endpoint rejects an empty value with 400.
    query: str
    # Forwarded unchanged to the recognizer as ``conversation_context``.
    conversation_context: str = ""
    # Forwarded unchanged to the recognizer as ``chat_history``.
    chat_history: Optional[List] = None
    # Forwarded unchanged to the recognizer as ``previous_slots``.
    # The default is None, so None must be part of the declared type; otherwise
    # a client sending an explicit JSON ``null`` is rejected by validation.
    previous_slots: Optional[str | Dict] = None
# Response model for the slot-filling part of the result.
# (Comments are used instead of a class docstring so the generated schema
# description stays unchanged.)
class SlotFillingResponse(BaseModel):
    # Defaults to False when no value is supplied.
    is_complete: bool = Field(default=False)
    # Defaults to a fresh empty dict per instance.
    missing_slots: Dict[str, Any] = Field(default_factory=dict)
    # Defaults to a fresh empty dict per instance.
    filled_data: Dict[str, Any] = Field(default_factory=dict)
# Response model for the query-expansion part of the result.
# (Comments are used instead of a class docstring so the generated schema
# description stays unchanged.)
class QueryExpandResponse(BaseModel):
    # Defaults to a fresh empty list per instance.
    all: List[str] = Field(default_factory=list)
    # The factories below must be the builtin ``dict``: ``typing.Dict`` is an
    # annotation-only alias and raises TypeError when called, which would break
    # every default construction of this model.
    step_back: Dict[str, Any] = Field(default_factory=dict)
    follow_up: Dict[str, Any] = Field(default_factory=dict)
    hyde: Dict[str, Any] = Field(default_factory=dict)
    multi_questions: Dict[str, Any] = Field(default_factory=dict)
# Response model returned by the /intent_recognize endpoint.
# (Comments are used instead of a class docstring so the generated schema
# description stays unchanged.)
class IntentRecognizeResponse(BaseModel):
    # --- required fields ---
    source_query: str
    source_query_keys: List[str]
    vertical_classification: str
    sub_classification: str
    rewrite_query: str

    # --- optional fields, each with an empty/False default ---
    keywords: List[Dict[str, str]] = Field(default_factory=list)
    has_slot_filling: bool = Field(default=False)
    # Nested models are built fresh per instance through their own defaults.
    slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
    query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
# Create the FastAPI application; the metadata below is passed to FastAPI as
# title / description / version.
_APP_METADATA = {
    "title": "意图识别服务",
    "description": "基于LLM的意图识别和问题改写服务",
    "version": "2.0",
}
app = FastAPI(**_APP_METADATA)
# Register the CORS middleware with wildcard origins, methods and headers.
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive — confirm this is intended outside of development.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Module-level holder for the shared AsyncIntentRecognizer instance; it stays
# None until the startup hook below has run.
_instance = None


# Application startup hook.
@app.on_event("startup")
async def startup_event():
    """Create the shared AsyncIntentRecognizer from environment settings.

    Reads OPENAI_API_KEY, OPENAI_API_BASE and LLM_MODEL_NAME (falling back to
    "gpt-3.5-turbo" for the model name), awaits
    ``AsyncIntentRecognizer.create`` with those values and stores the result
    in the module-level ``_instance``.
    """
    global _instance
    recognizer_settings = {
        "api_key": os.getenv("OPENAI_API_KEY"),
        "base_url": os.getenv("OPENAI_API_BASE"),
        "model_name": os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo"),
    }
    _instance = await AsyncIntentRecognizer.create(**recognizer_settings)
    logger.info("AsyncIntentRecognizer初始化完成")
def _build_intent_response(query: str, result: Dict[str, Any]) -> IntentRecognizeResponse:
    """Convert the recognizer's raw result mapping into the API response model.

    Args:
        query: The original user query, echoed back as ``source_query``.
        result: Mapping returned by ``process_query_async``. The keys read here
            are ``classification``, ``keywords``, ``query_keys``, ``rewrite``,
            ``query_expand`` and, optionally, ``slot_filling``.

    Raises:
        KeyError: If a required key is missing from ``result``.
    """
    classification = result["classification"]

    # Only the term names are exposed; the "名称" key is part of the response
    # contract. A falsy ``keywords`` value or missing/empty "terms" yields [].
    keywords = result["keywords"]
    keyword_terms = (keywords.get("terms") if keywords else None) or []
    keywords_list = [{"名称": term["name"]} for term in keyword_terms]

    # Treat an absent *or null* slot_filling entry as "no slot filling"; the
    # previous ``len(...)`` call raised TypeError when the value was None.
    slot_filling = result.get("slot_filling") or {}

    expansion = result["query_expand"]

    return IntentRecognizeResponse(
        source_query=query,
        source_query_keys=result["query_keys"],
        vertical_classification=classification["vertical_classification"],
        sub_classification=classification["sub_classification"],
        rewrite_query=result["rewrite"]["rewrite"],
        keywords=keywords_list,
        has_slot_filling=bool(slot_filling),
        slot_filling=SlotFillingResponse(
            is_complete=slot_filling.get("is_complete", False),
            missing_slots=slot_filling.get("missing_slots", {}),
            filled_data=slot_filling.get("filled_data", {}),
        ),
        query_expand=QueryExpandResponse(
            all=expansion["all"],
            step_back=expansion["step_back"],
            follow_up=expansion["follow_up"],
            hyde=expansion["hyde"],
            multi_questions=expansion["multi_questions"],
        ),
    )


@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
async def intent_recognize(request: IntentRecognizeRequest):
    """Run intent recognition for one query and return the structured result.

    Args:
        request: Validated request body; ``query`` must be non-empty.

    Returns:
        IntentRecognizeResponse built from the recognizer's result.

    Raises:
        HTTPException: 400 when ``query`` is empty, 503 when the recognizer has
            not been initialised yet, 500 (with the error text as detail) for
            any other failure.
    """
    try:
        if not request.query:
            raise HTTPException(status_code=400, detail="缺少query参数")
        # Report a clear "not ready" status instead of an AttributeError-based
        # 500 when the startup hook has not populated the recognizer.
        if _instance is None:
            raise HTTPException(status_code=503, detail="意图识别服务尚未初始化")

        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by system clock adjustments.
        started = time.perf_counter()
        result = await _instance.process_query_async(
            query=request.query,
            conversation_context=request.conversation_context,
            chat_history=request.chat_history,
            previous_slots=request.previous_slots,
            use_jieba=True,
            enable_query_expansion=True,
        )
        elapsed = time.perf_counter() - started
        logger.info("[%s] 意图识别耗时: %.2f秒", os.getpid(), elapsed)

        return _build_intent_response(request.query, result)
    except HTTPException:
        # Deliberate HTTP errors pass through untouched.
        raise
    except Exception as e:
        logger.exception("意图识别出错: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Health-check endpoint: always answers with a fixed "ok" payload.
# (Documented with comments rather than a docstring so the generated OpenAPI
# description of this route stays unchanged.)
@app.get("/health", summary="健康检查")
async def health_check():
    status_payload = {"status": "ok"}
    return status_payload
if __name__ == "__main__":
    # Launch the service with uvicorn when this file is executed as a script.
    import uvicorn

    server_options = {
        "host": "0.0.0.0",
        "port": 8001,
        "reload": True,  # hot reload, intended for development
        "workers": 1,  # may be increased for production
    }
    uvicorn.run("rag2_0.dify.intent_recognition_api:app", **server_options)

# In production the service can be started with the following command instead:
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10