Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d0788ac32a | |||
| 0a2d6c2020 |
@@ -2,7 +2,7 @@ import os
|
|||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from typing import Dict, List, Any, Optional
|
from typing import Dict, List, Any, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
|
|||||||
# 定义请求模型
|
# 定义请求模型
|
||||||
class IntentRecognizeRequest(BaseModel):
|
class IntentRecognizeRequest(BaseModel):
|
||||||
query: str
|
query: str
|
||||||
|
enable_query_expansion: bool = True
|
||||||
conversation_context: Dict = None
|
conversation_context: Dict = None
|
||||||
chat_history: Optional[List] = None
|
chat_history: Optional[List] = None
|
||||||
previous_slots: str | Dict = None
|
previous_slots: str | Dict = None
|
||||||
@@ -48,11 +49,9 @@ class SlotFillingResponse(BaseModel):
|
|||||||
filled_data: Dict[str, Any] = Field(default_factory=dict)
|
filled_data: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
class QueryExpandResponse(BaseModel):
|
class QueryExpandResponse(BaseModel):
|
||||||
|
# 必须包含的all字段
|
||||||
all: List[str] = Field(default_factory=list)
|
all: List[str] = Field(default_factory=list)
|
||||||
step_back: Dict[str, Any] = Field(default_factory=Dict)
|
model_config = ConfigDict(extra='allow')
|
||||||
follow_up: Dict[str, Any] = Field(default_factory=Dict)
|
|
||||||
# hyde: Dict[str, Any] = Field(default_factory=Dict)
|
|
||||||
multi_questions: Dict[str, Any] = Field(default_factory=Dict)
|
|
||||||
|
|
||||||
# 定义响应模型
|
# 定义响应模型
|
||||||
class IntentRecognizeResponse(BaseModel):
|
class IntentRecognizeResponse(BaseModel):
|
||||||
@@ -104,6 +103,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
|||||||
if not request.query:
|
if not request.query:
|
||||||
raise HTTPException(status_code=400, detail="缺少query参数")
|
raise HTTPException(status_code=400, detail="缺少query参数")
|
||||||
|
|
||||||
|
enable_query_expansion = request.enable_query_expansion
|
||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
current_softname = request.conversation_context.get("current_softname", "")
|
current_softname = request.conversation_context.get("current_softname", "")
|
||||||
result = await _instance.process_query_async(
|
result = await _instance.process_query_async(
|
||||||
@@ -112,7 +112,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
|||||||
chat_history=request.chat_history,
|
chat_history=request.chat_history,
|
||||||
previous_slots=request.previous_slots,
|
previous_slots=request.previous_slots,
|
||||||
use_jieba=True,
|
use_jieba=True,
|
||||||
enable_query_expansion=True,
|
enable_query_expansion=enable_query_expansion,
|
||||||
cur_soft_name=current_softname
|
cur_soft_name=current_softname
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -149,13 +149,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
|
|||||||
missing_slots=slot_filling.get("missing_slots", {}),
|
missing_slots=slot_filling.get("missing_slots", {}),
|
||||||
filled_data=slot_filling.get("filled_data", {})
|
filled_data=slot_filling.get("filled_data", {})
|
||||||
),
|
),
|
||||||
query_expand=QueryExpandResponse(
|
query_expand=QueryExpandResponse(**result["query_expand"])
|
||||||
all=result["query_expand"]["all"],
|
|
||||||
step_back=result["query_expand"]["step_back"],
|
|
||||||
follow_up=result["query_expand"]["follow_up"],
|
|
||||||
# hyde=result["query_expand"]["hyde"],
|
|
||||||
multi_questions=result["query_expand"]["multi_questions"]
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|||||||
@@ -492,19 +492,23 @@ class AsyncIntentRecognizer:
|
|||||||
wiki_result = query_expand_results[2] if query_expand_results[2] else []
|
wiki_result = query_expand_results[2] if query_expand_results[2] else []
|
||||||
multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query])
|
multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query])
|
||||||
|
|
||||||
all_questions = multi_questions_result.sub_questions
|
all_questions=[]
|
||||||
all_questions.append(query)
|
all_questions.append(query)
|
||||||
all_questions.append(rewrite.rewrite)
|
all_questions.append(rewrite.rewrite)
|
||||||
|
all_questions.extend(wiki_result)
|
||||||
all_questions.extend(step_back_result.step_back_query)
|
all_questions.extend(step_back_result.step_back_query)
|
||||||
all_questions.append(follow_up_result.follow_up_query)
|
all_questions.append(follow_up_result.follow_up_query)
|
||||||
all_questions.extend(wiki_result)
|
all_questions.extend(multi_questions_result.sub_questions)
|
||||||
all_questions = list(set(all_questions))
|
all_questions = list(set(all_questions))
|
||||||
|
|
||||||
query_expand = {
|
query_expand = {
|
||||||
"all": all_questions,
|
"all": all_questions,
|
||||||
"step_back": step_back_result.model_dump(),
|
"step_back": step_back_result.step_back_query,
|
||||||
"follow_up": follow_up_result.model_dump(),
|
"follow_up": [follow_up_result.follow_up_query],
|
||||||
"multi_questions": multi_questions_result.model_dump(),
|
"multi_questions": multi_questions_result.sub_questions,
|
||||||
|
"wiki_title": wiki_result,
|
||||||
|
"original_query":query,
|
||||||
|
"rewrite_query":rewrite.rewrite
|
||||||
}
|
}
|
||||||
|
|
||||||
# 返回所有结果
|
# 返回所有结果
|
||||||
|
|||||||
Reference in New Issue
Block a user