微调意图改写接口

This commit is contained in:
2025-07-22 11:35:05 +08:00
parent 75c7c19f53
commit 0a2d6c2020
2 changed files with 16 additions and 18 deletions
+7 -13
View File
@@ -2,7 +2,7 @@ import os
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, List, Any, Optional
import asyncio
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
# 定义请求模型
class IntentRecognizeRequest(BaseModel):
query: str
enable_query_expansion: bool = False
conversation_context: Dict = None
chat_history: Optional[List] = None
previous_slots: str | Dict = None
@@ -48,11 +49,9 @@ class SlotFillingResponse(BaseModel):
filled_data: Dict[str, Any] = Field(default_factory=dict)
class QueryExpandResponse(BaseModel):
# 必须包含的all字段
all: List[str] = Field(default_factory=list)
step_back: Dict[str, Any] = Field(default_factory=Dict)
follow_up: Dict[str, Any] = Field(default_factory=Dict)
# hyde: Dict[str, Any] = Field(default_factory=Dict)
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
model_config = ConfigDict(extra='allow')
# 定义响应模型
class IntentRecognizeResponse(BaseModel):
@@ -104,6 +103,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
if not request.query:
raise HTTPException(status_code=400, detail="缺少query参数")
enable_query_expansion = request.enable_query_expansion
start_time = time.time()
current_softname = request.conversation_context.get("current_softname", "")
result = await _instance.process_query_async(
@@ -112,7 +112,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
chat_history=request.chat_history,
previous_slots=request.previous_slots,
use_jieba=True,
enable_query_expansion=True,
enable_query_expansion=enable_query_expansion,
cur_soft_name=current_softname
)
@@ -149,13 +149,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
missing_slots=slot_filling.get("missing_slots", {}),
filled_data=slot_filling.get("filled_data", {})
),
query_expand=QueryExpandResponse(
all=result["query_expand"]["all"],
step_back=result["query_expand"]["step_back"],
follow_up=result["query_expand"]["follow_up"],
# hyde=result["query_expand"]["hyde"],
multi_questions=result["query_expand"]["multi_questions"]
)
query_expand=QueryExpandResponse(**result["query_expand"])
)
return response