微调意图改写接口

This commit is contained in:
2025-07-22 11:35:05 +08:00
parent 75c7c19f53
commit 0a2d6c2020
2 changed files with 16 additions and 18 deletions
+7 -13
View File
@@ -2,7 +2,7 @@ import os
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, List, Any, Optional from typing import Dict, List, Any, Optional
import asyncio import asyncio
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
# 定义请求模型 # 定义请求模型
class IntentRecognizeRequest(BaseModel): class IntentRecognizeRequest(BaseModel):
query: str query: str
enable_query_expansion: bool = False
conversation_context: Dict = None conversation_context: Dict = None
chat_history: Optional[List] = None chat_history: Optional[List] = None
previous_slots: str | Dict = None previous_slots: str | Dict = None
@@ -48,11 +49,9 @@ class SlotFillingResponse(BaseModel):
filled_data: Dict[str, Any] = Field(default_factory=dict) filled_data: Dict[str, Any] = Field(default_factory=dict)
class QueryExpandResponse(BaseModel): class QueryExpandResponse(BaseModel):
# 必须包含的all字段
all: List[str] = Field(default_factory=list) all: List[str] = Field(default_factory=list)
step_back: Dict[str, Any] = Field(default_factory=Dict) model_config = ConfigDict(extra='allow')
follow_up: Dict[str, Any] = Field(default_factory=Dict)
# hyde: Dict[str, Any] = Field(default_factory=Dict)
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
# 定义响应模型 # 定义响应模型
class IntentRecognizeResponse(BaseModel): class IntentRecognizeResponse(BaseModel):
@@ -104,6 +103,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
if not request.query: if not request.query:
raise HTTPException(status_code=400, detail="缺少query参数") raise HTTPException(status_code=400, detail="缺少query参数")
enable_query_expansion = request.enable_query_expansion
start_time = time.time() start_time = time.time()
current_softname = request.conversation_context.get("current_softname", "") current_softname = request.conversation_context.get("current_softname", "")
result = await _instance.process_query_async( result = await _instance.process_query_async(
@@ -112,7 +112,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
chat_history=request.chat_history, chat_history=request.chat_history,
previous_slots=request.previous_slots, previous_slots=request.previous_slots,
use_jieba=True, use_jieba=True,
enable_query_expansion=True, enable_query_expansion=enable_query_expansion,
cur_soft_name=current_softname cur_soft_name=current_softname
) )
@@ -149,13 +149,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
missing_slots=slot_filling.get("missing_slots", {}), missing_slots=slot_filling.get("missing_slots", {}),
filled_data=slot_filling.get("filled_data", {}) filled_data=slot_filling.get("filled_data", {})
), ),
query_expand=QueryExpandResponse( query_expand=QueryExpandResponse(**result["query_expand"])
all=result["query_expand"]["all"],
step_back=result["query_expand"]["step_back"],
follow_up=result["query_expand"]["follow_up"],
# hyde=result["query_expand"]["hyde"],
multi_questions=result["query_expand"]["multi_questions"]
)
) )
return response return response
@@ -492,19 +492,23 @@ class AsyncIntentRecognizer:
wiki_result = query_expand_results[2] if query_expand_results[2] else [] wiki_result = query_expand_results[2] if query_expand_results[2] else []
multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query]) multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query])
all_questions = multi_questions_result.sub_questions all_questions=[]
all_questions.append(query) all_questions.append(query)
all_questions.append(rewrite.rewrite) all_questions.append(rewrite.rewrite)
all_questions.extend(wiki_result)
all_questions.extend(step_back_result.step_back_query) all_questions.extend(step_back_result.step_back_query)
all_questions.append(follow_up_result.follow_up_query) all_questions.append(follow_up_result.follow_up_query)
all_questions.extend(wiki_result) all_questions.extend(multi_questions_result.sub_questions)
all_questions = list(set(all_questions)) all_questions = list(set(all_questions))
query_expand = { query_expand = {
"all": all_questions, "all": all_questions,
"step_back": step_back_result.model_dump(), "step_back": step_back_result.step_back_query,
"follow_up": follow_up_result.model_dump(), "follow_up": [follow_up_result.follow_up_query],
"multi_questions": multi_questions_result.model_dump(), "multi_questions": multi_questions_result.sub_questions,
"wiki_title": wiki_result,
"original_query":query,
"rewrite_query":rewrite.rewrite
} }
# 返回所有结果 # 返回所有结果