185 lines
6.1 KiB
Python
Executable File
185 lines
6.1 KiB
Python
Executable File
import os
|
|
from fastapi import FastAPI, HTTPException, Request
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from pydantic import BaseModel, Field
|
|
from typing import Dict, List, Any, Optional
|
|
import asyncio
|
|
|
|
from dotenv import load_dotenv
|
|
import json
|
|
import time
|
|
import datetime
|
|
import logging
|
|
# Load settings from a local .env file (if present) into os.environ. This runs
# before the project-local import that follows, in case that module reads the
# environment at import time (not verifiable from here).
load_dotenv()

import sys

# Make project-local packages such as `rag2_0` importable.
# NOTE(review): this only helps when the process is started from the project
# root, because os.getcwd() is whatever directory the server was launched from.
sys.path.append(os.getcwd())
|
|
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
|
|
|
|
|
# Root logging for the whole service: INFO level, one timestamped line per
# record, written through a StreamHandler (stderr by default).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()],
)

# Raise the threshold of the HTTP/LLM client libraries to WARNING so their
# INFO-level chatter does not drown out the service's own logs.
for _name in ("httpx", "openai"):
    logging.getLogger(_name).setLevel(logging.WARNING)
del _name

# Module logger used by the endpoints below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
# Request body for POST /intent_recognize.
class IntentRecognizeRequest(BaseModel):
    """Input for intent recognition.

    Only ``query`` is required. The remaining fields are optional and are
    handed on to the recognizer by the endpoint. They are declared
    ``Optional`` so that an explicit JSON ``null`` validates instead of
    being rejected with a 422.
    """

    # The user's question; the endpoint rejects an empty string with HTTP 400.
    query: str
    # Free-form conversation state; the endpoint reads "current_softname" from it.
    conversation_context: Optional[Dict[str, Any]] = None
    # Prior turns of the conversation (element schema is defined by the recognizer).
    chat_history: Optional[List] = None
    # Slots from a previous turn, either as a dict or as a string
    # (presumably a serialized form of the dict; confirm with the recognizer).
    previous_slots: str | Dict[str, Any] | None = None
|
|
|
|
# Slot-filling section of the intent-recognition response.
class SlotFillingResponse(BaseModel):
    """Slot-filling state reported by the recognizer; every field defaults to "empty"."""

    # Presumably True once every required slot has a value (defaults to False).
    is_complete: bool = False
    # Slots that still need a value; value structure comes from the recognizer.
    missing_slots: Dict[str, Any] = Field(default_factory=dict)
    # Slot values collected so far.
    filled_data: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
# Query-expansion section of the intent-recognition response.
class QueryExpandResponse(BaseModel):
    """Expanded variants of the user's query produced by the recognizer.

    Every field has a default, so the model can be built with no arguments
    (``IntentRecognizeResponse`` relies on that for its ``query_expand`` default).
    The factories must be the builtin ``dict``: ``typing.Dict`` cannot be
    instantiated and raises ``TypeError`` when used as a ``default_factory``.
    """

    # Flat list of all expanded queries.
    all: List[str] = Field(default_factory=list)
    # Per-strategy expansion results; their inner structure comes from the recognizer.
    step_back: Dict[str, Any] = Field(default_factory=dict)
    follow_up: Dict[str, Any] = Field(default_factory=dict)
    # hyde: Dict[str, Any] = Field(default_factory=dict)  # HyDE expansion currently disabled
    multi_questions: Dict[str, Any] = Field(default_factory=dict)
|
|
|
|
# Response body of POST /intent_recognize.
class IntentRecognizeResponse(BaseModel):
    """Result of intent recognition, query rewriting and query expansion.

    Declaration order is significant: it is the key order of the serialized JSON.
    """

    # Echo of the incoming query plus the keys the recognizer extracted from it.
    source_query: str
    source_query_keys: List[str]
    # Classification labels (presumably a top-level vertical and a sub-category).
    vertical_classification: str
    sub_classification: str
    # Rewritten form of the query.
    rewrite_query: str
    # Keyword terms, each a small str->str dict.
    keywords: List[Dict[str, str]] = Field(default_factory=list)
    # Whether a non-empty slot-filling section accompanies this result.
    has_slot_filling: bool = False
    slot_filling: SlotFillingResponse = Field(default_factory=SlotFillingResponse)
    query_expand: QueryExpandResponse = Field(default_factory=QueryExpandResponse)
|
|
|
|
# The FastAPI application object; the __main__ block serves it through the
# import string "rag2_0.dify.intent_recognition_api:app".
app = FastAPI(
    title="意图识别服务",  # "Intent recognition service"
    description="基于LLM的意图识别和问题改写服务",  # "LLM-based intent recognition and query rewriting"
    version="2.0",
)
|
|
|
|
# Cross-origin (CORS) policy.
# Allowed origins come from the comma-separated CORS_ALLOW_ORIGINS environment
# variable; when it is unset or empty the previous behaviour ("*", any origin)
# is kept.
# NOTE(review): with allow_credentials=True and a wildcard origin, Starlette
# echoes the caller's Origin back, so credentialed requests from any site are
# accepted. Set CORS_ALLOW_ORIGINS to an explicit list in production.
_cors_allow_origins = [
    origin.strip()
    for origin in os.getenv("CORS_ALLOW_ORIGINS", "*").split(",")
    if origin.strip()
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_allow_origins or ["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
# Shared AsyncIntentRecognizer, assigned by startup_event() and read by the
# request handlers. It stays None until startup has completed.
_instance = None


# NOTE(review): FastAPI deprecates on_event in favour of lifespan handlers;
# consider migrating when the app construction is touched.
@app.on_event("startup")
async def startup_event():
    """Build the shared AsyncIntentRecognizer before requests are served."""
    global _instance
    recognizer = await AsyncIntentRecognizer.create()
    _instance = recognizer
    logger.info("AsyncIntentRecognizer初始化完成")
|
|
|
|
# NOTE(review): this debug handler shares the Python name `intent_recognize`
# with the real handler defined later, which rebinds the module attribute.
# FastAPI has already registered this route, so both endpoints work, but both
# routes carry the same route name (relevant if url_path_for is ever used).
@app.post("/intent_recognize1")
async def intent_recognize(request: Request):
    """Debug endpoint: log the raw JSON body and acknowledge receipt.

    Raises:
        HTTPException(400): the request body is not valid JSON.
    """
    try:
        data = await request.json()
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # A malformed body is a client error, not an internal 500.
        raise HTTPException(status_code=400, detail="请求体不是合法的JSON") from e
    # Use the configured logger rather than print() so the payload gets a
    # timestamp and level, and follows the service's logging setup.
    logger.info("intent_recognize1 received: %s", data)
    return {"message": "success"}
|
|
|
|
def _build_intent_response(query: str, result: Dict[str, Any]) -> IntentRecognizeResponse:
    """Map the recognizer's result dict onto ``IntentRecognizeResponse``.

    Args:
        query: The original user query, echoed back as ``source_query``.
        result: Return value of ``AsyncIntentRecognizer.process_query_async``.

    Raises:
        KeyError: if ``result`` lacks one of the required keys read below.
    """
    classification = result["classification"]

    # Only each term's name is exposed, under the key "名称" ("name").
    keywords = result["keywords"]
    terms = keywords.get("terms") if keywords else None
    keywords_list = [{"名称": term["name"]} for term in terms or []]

    # Slot filling is optional; a missing or null section means "no slot filling".
    slot_filling = result.get("slot_filling") or {}

    query_expand = result["query_expand"]

    return IntentRecognizeResponse(
        source_query=query,
        source_query_keys=result["query_keys"],
        vertical_classification=classification["vertical_classification"],
        sub_classification=classification["sub_classification"],
        rewrite_query=result["rewrite"]["rewrite"],
        keywords=keywords_list,
        has_slot_filling=bool(slot_filling),
        slot_filling=SlotFillingResponse(
            is_complete=slot_filling.get("is_complete", False),
            missing_slots=slot_filling.get("missing_slots", {}),
            filled_data=slot_filling.get("filled_data", {}),
        ),
        query_expand=QueryExpandResponse(
            all=query_expand["all"],
            step_back=query_expand["step_back"],
            follow_up=query_expand["follow_up"],
            # hyde=query_expand["hyde"],  # HyDE expansion currently disabled
            multi_questions=query_expand["multi_questions"],
        ),
    )


@app.post(
    "/intent_recognize",
    response_model=IntentRecognizeResponse,
    summary="意图识别",
    description="识别用户查询的意图并进行问题改写",
)
async def intent_recognize(request: IntentRecognizeRequest):
    """Recognize the intent of ``request.query`` and return its rewrite and expansions.

    Delegates to the shared AsyncIntentRecognizer created at startup (with
    jieba tokenization and query expansion enabled), then maps its result
    onto ``IntentRecognizeResponse``.

    Raises:
        HTTPException(400): ``query`` is empty.
        HTTPException(503): the recognizer has not been initialized yet.
        HTTPException(500): any other failure while recognizing or building
            the response; the detail is the exception message.
    """
    if not request.query:
        raise HTTPException(status_code=400, detail="缺少query参数")
    if _instance is None:
        # The startup hook has not run (or failed); report unavailability
        # instead of an opaque AttributeError turned into a 500.
        raise HTTPException(status_code=503, detail="意图识别服务未初始化")

    try:
        start_time = time.perf_counter()
        # conversation_context is optional; treat an absent one as empty
        # instead of failing on None.get(...).
        conversation_context = request.conversation_context or {}
        result = await _instance.process_query_async(
            query=request.query,
            conversation_context=conversation_context,
            chat_history=request.chat_history,
            previous_slots=request.previous_slots,
            use_jieba=True,
            enable_query_expansion=True,
            cur_soft_name=conversation_context.get("current_softname", ""),
        )
        elapsed = time.perf_counter() - start_time
        logger.info(f"[{os.getpid()}] 意图识别耗时: {elapsed:.2f}秒")

        return _build_intent_response(request.query, result)

    except HTTPException:
        raise
    except Exception as e:
        logger.exception(f"意图识别出错: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
# Health-check endpoint.
@app.get("/health", summary="健康检查")
async def health_check():
    """Liveness probe: answers ``{"status": "ok"}`` whenever the app can serve a request."""
    payload = {"status": "ok"}
    return payload
|
|
|
|
if __name__ == "__main__":
    # Development entry point: serve the app with uvicorn and hot reload.
    import uvicorn

    uvicorn.run(
        # Passed as an import string (not the object) so uvicorn can re-import
        # the app when reloading.
        "rag2_0.dify.intent_recognition_api:app",
        host="0.0.0.0",
        port=8001,
        reload=True,  # development only: restart on code changes
        # NOTE(review): uvicorn does not combine reload with multiple workers;
        # use the CLI command below for a multi-worker production deployment.
        workers=1,
    )

# Production launch (multiple worker processes, no reload):
# uvicorn rag2_0.dify.intent_recognition_api:app --host 0.0.0.0 --port 8001 --workers 10