微调意图改写接口
This commit is contained in:
@@ -2,7 +2,7 @@ import os
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from typing import Dict, List, Any, Optional
|
||||
import asyncio
|
||||
|
||||
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
|
||||
# 定义请求模型
|
||||
class IntentRecognizeRequest(BaseModel):
|
||||
query: str
|
||||
enable_query_expansion: bool = False
|
||||
conversation_context: Dict = None
|
||||
chat_history: Optional[List] = None
|
||||
previous_slots: str | Dict = None
|
||||
@@ -48,11 +49,9 @@ class SlotFillingResponse(BaseModel):
|
||||
filled_data: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
class QueryExpandResponse(BaseModel):
|
||||
# 必须包含的all字段
|
||||
all: List[str] = Field(default_factory=list)
|
||||
step_back: Dict[str, Any] = Field(default_factory=Dict)
|
||||
follow_up: Dict[str, Any] = Field(default_factory=Dict)
|
||||
# hyde: Dict[str, Any] = Field(default_factory=Dict)
|
||||
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
|
||||
model_config = ConfigDict(extra='allow')
|
||||
|
||||
# 定义响应模型
|
||||
class IntentRecognizeResponse(BaseModel):
|
||||
@@ -104,6 +103,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
if not request.query:
|
||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||
|
||||
enable_query_expansion = request.enable_query_expansion
|
||||
start_time = time.time()
|
||||
current_softname = request.conversation_context.get("current_softname", "")
|
||||
result = await _instance.process_query_async(
|
||||
@@ -112,7 +112,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
chat_history=request.chat_history,
|
||||
previous_slots=request.previous_slots,
|
||||
use_jieba=True,
|
||||
enable_query_expansion=True,
|
||||
enable_query_expansion=enable_query_expansion,
|
||||
cur_soft_name=current_softname
|
||||
)
|
||||
|
||||
@@ -149,13 +149,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
||||
missing_slots=slot_filling.get("missing_slots", {}),
|
||||
filled_data=slot_filling.get("filled_data", {})
|
||||
),
|
||||
query_expand=QueryExpandResponse(
|
||||
all=result["query_expand"]["all"],
|
||||
step_back=result["query_expand"]["step_back"],
|
||||
follow_up=result["query_expand"]["follow_up"],
|
||||
# hyde=result["query_expand"]["hyde"],
|
||||
multi_questions=result["query_expand"]["multi_questions"]
|
||||
)
|
||||
query_expand=QueryExpandResponse(**result["query_expand"])
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -492,19 +492,23 @@ class AsyncIntentRecognizer:
|
||||
wiki_result = query_expand_results[2] if query_expand_results[2] else []
|
||||
multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query])
|
||||
|
||||
all_questions = multi_questions_result.sub_questions
|
||||
all_questions=[]
|
||||
all_questions.append(query)
|
||||
all_questions.append(rewrite.rewrite)
|
||||
all_questions.extend(wiki_result)
|
||||
all_questions.extend(step_back_result.step_back_query)
|
||||
all_questions.append(follow_up_result.follow_up_query)
|
||||
all_questions.extend(wiki_result)
|
||||
all_questions.extend(multi_questions_result.sub_questions)
|
||||
all_questions = list(set(all_questions))
|
||||
|
||||
query_expand = {
|
||||
"all": all_questions,
|
||||
"step_back": step_back_result.model_dump(),
|
||||
"follow_up": follow_up_result.model_dump(),
|
||||
"multi_questions": multi_questions_result.model_dump(),
|
||||
"step_back": step_back_result.step_back_query,
|
||||
"follow_up": [follow_up_result.follow_up_query],
|
||||
"multi_questions": multi_questions_result.sub_questions,
|
||||
"wiki_title": wiki_result,
|
||||
"original_query":query,
|
||||
"rewrite_query":rewrite.rewrite
|
||||
}
|
||||
|
||||
# 返回所有结果
|
||||
|
||||
Reference in New Issue
Block a user