微调意图改写接口
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
# 定义请求模型
|
||||
class IntentRecognizeRequest(BaseModel):
|
||||
query: str
|
||||
enable_query_expansion: bool = False
|
||||
conversation_context: Dict = None
|
||||
chat_history: Optional[List] = None
|
||||
previous_slots: str | Dict = None
|
||||
@@ -48,11 +49,9 @@ class SlotFillingResponse(BaseModel):
|
||||
filled_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class QueryExpandResponse(BaseModel):
|
||||
# 必须包含的all字段
|
||||
all: List[str] = Field(default_factory=list)
|
||||
step_back: Dict[str, Any] = Field(default_factory=Dict)
|
||||
follow_up: Dict[str, Any] = Field(default_factory=Dict)
|
||||
# hyde: Dict[str, Any] = Field(default_factory=Dict)
|
||||
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
# 定义响应模型
|
||||
class IntentRecognizeResponse(BaseModel):
|
||||
@@ -104,6 +103,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
if not request.query:
|
||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||
|
||||
enable_query_expansion = request.enable_query_expansion
|
||||
start_time = time.time()
|
||||
current_softname = request.conversation_context.get("current_softname", "")
|
||||
result = await _instance.process_query_async(
|
||||
@@ -112,7 +112,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
chat_history=request.chat_history,
|
||||
previous_slots=request.previous_slots,
|
||||
use_jieba=True,
|
||||
enable_query_expansion=True,
|
||||
enable_query_expansion=enable_query_expansion,
|
||||
cur_soft_name=current_softname
|
||||
)
|
||||
|
||||
@@ -149,13 +149,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
missing_slots=slot_filling.get("missing_slots", {}),
|
||||
filled_data=slot_filling.get("filled_data", {})
|
||||
),
|
||||
query_expand=QueryExpandResponse(
|
||||
all=result["query_expand"]["all"],
|
||||
step_back=result["query_expand"]["step_back"],
|
||||
follow_up=result["query_expand"]["follow_up"],
|
||||
# hyde=result["query_expand"]["hyde"],
|
||||
multi_questions=result["query_expand"]["multi_questions"]
|
||||
)
|
||||
query_expand=QueryExpandResponse(**result["query_expand"])
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
Reference in New Issue
Block a user