Files
QueryRewrite/rag2_0/api/intent_recognition_api.py
T
2025-08-28 14:30:07 +08:00

185 lines
6.1 KiB
Python
Executable File

import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, List, Any, Optional
import asyncio
from dotenv import load_dotenv
import json
import time
import datetime
import logging
# Load environment variables via python-dotenv. This runs before the project
# import below, presumably because that module reads configuration at import
# time (assumption -- TODO confirm).
load_dotenv()
import sys
# Put the current working directory on sys.path so the top-level ``rag2_0``
# package can be imported. This assumes the process is started from the
# directory that contains rag2_0/ (e.g. the project root).
sys.path.append(os.getcwd())
from rag2_0.intent_recognition import AsyncIntentRecognizer
# Logging: console + one file per process-start day under data/logs/ (path is
# relative to the current working directory). The directory must exist before
# FileHandler opens the file.
os.makedirs('data/logs', exist_ok=True)
_log_file_path = f'data/logs/intent_recognition_{datetime.date.today().strftime("%Y%m%d")}.log'
logging.basicConfig(
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler(_log_file_path, encoding='utf-8'),
    ],
    format='%(asctime)s - %(process)d - %(thread)d - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
# Silence per-request chatter from the HTTP/OpenAI client libraries.
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
logger = logging.getLogger(__name__)
# Request model for POST /intent_recognize.
class IntentRecognizeRequest(BaseModel):
    # User query to classify and rewrite (required).
    query: str
    # Forwarded to the recognizer as ``enable_query_expansion``.
    enable_query_expansion: bool = True
    # The optional fields below are annotated as nullable explicitly: under
    # pydantic v2 a default of None does NOT make a field accept JSON null, so
    # ``"conversation_context": null`` used to be rejected with 422.
    conversation_context: Optional[Dict] = None
    chat_history: Optional[List] = None
    previous_slots: str | Dict | None = None
# Slot-filling section of the intent-recognition response.
class SlotFillingResponse(BaseModel):
    # Completion flag as reported by the recognizer; False when absent.
    is_complete: bool = Field(default=False)
    # Missing / filled slot data, passed through from the recognizer; empty by default.
    missing_slots: dict[str, Any] = Field(default_factory=dict)
    filled_data: dict[str, Any] = Field(default_factory=dict)
class QueryExpandResponse(BaseModel):
    # Keep any additional keys returned by the recognizer next to "all".
    model_config = ConfigDict(extra='allow')

    # The "all" list is always present in the output (empty by default). The
    # attribute name mirrors the JSON key, which is why it shadows the builtin
    # inside the class namespace.
    all: list[str] = Field(default_factory=list)
# Response model for POST /intent_recognize (field order = JSON output order).
class IntentRecognizeResponse(BaseModel):
    # Original query and the keys extracted from it.
    source_query: str
    source_query_keys: list[str]
    # Two-level classification produced by the recognizer.
    vertical_classification: str
    sub_classification: str
    # Rewritten query text.
    rewrite_query: str
    # One small dict per extracted keyword.
    keywords: list[dict[str, str]] = Field(default_factory=list)
    # Whether slot-filling data accompanies this response, plus that data.
    has_slot_filling: bool = Field(default=False)
    slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
    # Query-expansion results.
    query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
    # Extra info dict passed through from the recognizer result.
    dinge_qingdan_info: dict[str, Any] = Field(default_factory=dict)
# FastAPI application instance.
app = FastAPI(
    version="2.0",
    title="意图识别服务",
    description="基于LLM的意图识别和问题改写服务",
)

# Fully permissive CORS: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with credentials lets any site make
# credentialed requests -- confirm this is intended outside development.
app.add_middleware(
    CORSMiddleware,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_credentials=True,
    allow_origins=["*"],
)
# Shared AsyncIntentRecognizer for all requests; None until the startup hook
# in this block has run.
_instance = None


@app.on_event("startup")
async def startup_event():
    """Create the shared AsyncIntentRecognizer once the application starts."""
    global _instance
    recognizer = await AsyncIntentRecognizer.create()
    _instance = recognizer
    logger.info("AsyncIntentRecognizer初始化完成")
@app.post("/intent_recognize1")
async def intent_recognize(request: Request):
    # Debug endpoint: logs the raw JSON body and always answers
    # {"message": "success"}.
    # NOTE(review): this function shares its name with the /intent_recognize
    # handler defined later in the module. Both routes are registered at
    # decoration time, but the module-level name ends up bound to the later
    # definition and both routes carry the route name "intent_recognize".
    try:
        data = await request.json()
    except ValueError as e:
        # json.JSONDecodeError / UnicodeDecodeError are ValueErrors: report a
        # client error instead of an unhandled 500.
        raise HTTPException(status_code=400, detail="请求体不是合法的JSON") from e
    # Use the configured logger rather than print() so the payload also lands
    # in the log file with timestamp/process metadata.
    logger.info("intent_recognize1 received: %s", data)
    return {"message": "success"}
@app.post("/intent_recognize", response_model=IntentRecognizeResponse, summary="意图识别", description="识别用户查询的意图并进行问题改写")
async def intent_recognize(request: IntentRecognizeRequest):
    """Recognize the intent of ``request.query`` and return its rewrite.

    Calls ``_instance.process_query_async`` (the recognizer created by the
    startup hook) and maps the returned dict onto ``IntentRecognizeResponse``.

    Args:
        request: Validated request body. ``conversation_context`` may be
            omitted/None; ``current_softname`` is read from it when present.

    Returns:
        IntentRecognizeResponse with classification, rewrite, keywords, slot
        filling, query expansion and ``dinge_qingdan_info`` from the result.

    Raises:
        HTTPException: 400 when ``query`` is empty; 503 when the recognizer
            has not been initialised; 500 (error text as ``detail``) for any
            other failure.
    """
    if not request.query:
        raise HTTPException(status_code=400, detail="缺少query参数")
    if _instance is None:
        # Startup has not completed (or failed); report unavailability rather
        # than failing with AttributeError on None.
        raise HTTPException(status_code=503, detail="意图识别服务未初始化")

    try:
        # conversation_context defaults to None, so calling .get() on it
        # directly crashed for requests that omit it. Fall back to {} only for
        # this lookup; the recognizer still receives what the client sent.
        current_softname = (request.conversation_context or {}).get("current_softname", "")

        start_time = time.perf_counter()  # monotonic clock for durations
        result = await _instance.process_query_async(
            query=request.query,
            conversation_context=request.conversation_context,
            chat_history=request.chat_history,
            previous_slots=request.previous_slots,
            use_jieba=True,
            enable_query_expansion=request.enable_query_expansion,
            cur_soft_name=current_softname,
        )
        logger.info("意图识别耗时: %.2f秒", time.perf_counter() - start_time)

        classification = result["classification"]

        # Keep only each keyword term's name, exposed under the "名称" key.
        keywords = result["keywords"]
        terms = keywords.get("terms") if keywords else None
        keywords_list = [{"名称": term["name"]} for term in terms or []]

        # Slot filling may be absent from the result; also guard against an
        # explicit None, which previously crashed in len().
        slot_filling = result.get("slot_filling") or {}

        return IntentRecognizeResponse(
            source_query=request.query,
            source_query_keys=result["query_keys"],
            vertical_classification=classification["vertical_classification"],
            sub_classification=classification["sub_classification"],
            rewrite_query=result["rewrite"]["rewrite"],
            keywords=keywords_list,
            has_slot_filling=bool(slot_filling),
            slot_filling=SlotFillingResponse(
                is_complete=slot_filling.get("is_complete", False),
                missing_slots=slot_filling.get("missing_slots", {}),
                filled_data=slot_filling.get("filled_data", {}),
            ),
            query_expand=QueryExpandResponse(**result["query_expand"]),
            dinge_qingdan_info=result["dinge_qingdan_info"],
        )
    except HTTPException:
        # Already an HTTP error (e.g. raised by the recognizer): pass through.
        raise
    except Exception as e:
        logger.exception("意图识别出错: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
# Health-check endpoint. It always answers {"status": "ok"} and does not check
# whether the recognizer has been initialised.
@app.get("/health", summary="健康检查")
async def health_check():
    payload = {"status": "ok"}
    return payload
if __name__ == "__main__":
    # Development entry point: serve this module's ``app`` with uvicorn.
    import uvicorn

    uvicorn.run(
        # With reload=True uvicorn must import the app from this string, so it
        # has to name THIS module (rag2_0/api/intent_recognition_api.py). It is
        # resolved against sys.path, so start the process from the directory
        # containing rag2_0/.
        "rag2_0.api.intent_recognition_api:app",
        host="0.0.0.0",
        port=9001,
        reload=True,  # hot reload for development
        workers=1,  # raise for production; has no effect together with reload=True
    )
# For production, start with e.g. (note: this example uses port 8001, not the
# dev port 9001 above):
# uvicorn rag2_0.api.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10