Patch notes (free text ahead of the first "diff --git" header is skipped by patch tools):
  * intent_recognition_api.py: the request model gains an `enable_query_expansion`
    flag (default False) that is forwarded to `process_query_async` instead of the
    hard-coded True. `QueryExpandResponse` now declares only `all` and accepts any
    extra keys, so the response is built with `**result["query_expand"]`.
  * IntentRecognition.py: `all_questions` starts from a fresh list instead of
    aliasing (and mutating) `multi_questions_result.sub_questions`; `query_expand`
    exposes flat values plus `wiki_title`, `original_query` and `rewrite_query`.
  * Review fix: duplicates are removed with `dict.fromkeys` rather than `set`, so
    the order in which the queries were collected survives and the `all` list is
    deterministic between runs.

diff --git a/rag2_0/dify/intent_recognition_api.py b/rag2_0/dify/intent_recognition_api.py
index ae577b4..3f82846 100755
--- a/rag2_0/dify/intent_recognition_api.py
+++ b/rag2_0/dify/intent_recognition_api.py
@@ -2,7 +2,7 @@ import os
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse
 from fastapi.middleware.cors import CORSMiddleware
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, ConfigDict
 from typing import Dict, List, Any, Optional
 import asyncio
 
@@ -37,6 +37,7 @@ logger = logging.getLogger(__name__)
 # 定义请求模型
 class IntentRecognizeRequest(BaseModel):
     query: str
+    enable_query_expansion: bool = False
     conversation_context: Dict = None
     chat_history: Optional[List] = None
     previous_slots: str | Dict = None
@@ -48,11 +49,9 @@ class SlotFillingResponse(BaseModel):
     filled_data: Dict[str, Any] = Field(default_factory=dict)
 
 class QueryExpandResponse(BaseModel):
+    # `all` is the one field that must always be present; other keys pass through via extra='allow'.
     all: List[str] = Field(default_factory=list)
-    step_back: Dict[str, Any] = Field(default_factory=Dict)
-    follow_up: Dict[str, Any] = Field(default_factory=Dict)
-    # hyde: Dict[str, Any] = Field(default_factory=Dict)
-    multi_questions: Dict[str, Any] = Field(default_factory=Dict)
+    model_config = ConfigDict(extra='allow')
 
 # 定义响应模型
 class IntentRecognizeResponse(BaseModel):
@@ -104,6 +103,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
         if not request.query:
             raise HTTPException(status_code=400, detail="缺少query参数")
 
+        enable_query_expansion = request.enable_query_expansion
         start_time = time.time()
         current_softname = request.conversation_context.get("current_softname", "")
         result = await _instance.process_query_async(
@@ -112,7 +112,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
             chat_history=request.chat_history,
             previous_slots=request.previous_slots,
             use_jieba=True,
-            enable_query_expansion=True,
+            enable_query_expansion=enable_query_expansion,
             cur_soft_name=current_softname
         )
 
@@ -149,13 +149,7 @@ async def intent_recognize(request: IntentRecognizeRequest):
                 missing_slots=slot_filling.get("missing_slots", {}),
                 filled_data=slot_filling.get("filled_data", {})
             ),
-            query_expand=QueryExpandResponse(
-                all=result["query_expand"]["all"],
-                step_back=result["query_expand"]["step_back"],
-                follow_up=result["query_expand"]["follow_up"],
-                # hyde=result["query_expand"]["hyde"],
-                multi_questions=result["query_expand"]["multi_questions"]
-            )
+            query_expand=QueryExpandResponse(**result["query_expand"])
         )
 
         return response
diff --git a/rag2_0/intent_recognition/IntentRecognition.py b/rag2_0/intent_recognition/IntentRecognition.py
index 2b5dc31..6514585 100755
--- a/rag2_0/intent_recognition/IntentRecognition.py
+++ b/rag2_0/intent_recognition/IntentRecognition.py
@@ -492,19 +492,23 @@ class AsyncIntentRecognizer:
         wiki_result = query_expand_results[2] if query_expand_results[2] else []
         multi_questions_result = query_expand_results[3] if query_expand_results[3] else MultiQuestions(original_query=query, sub_questions=[query])
 
-        all_questions = multi_questions_result.sub_questions
+        all_questions = []
         all_questions.append(query)
         all_questions.append(rewrite.rewrite)
+        all_questions.extend(wiki_result)
         all_questions.extend(step_back_result.step_back_query)
         all_questions.append(follow_up_result.follow_up_query)
-        all_questions.extend(wiki_result)
-        all_questions = list(set(all_questions))
+        all_questions.extend(multi_questions_result.sub_questions)
+        all_questions = list(dict.fromkeys(all_questions))  # de-duplicate while keeping collection order
 
         query_expand = {
             "all": all_questions,
-            "step_back": step_back_result.model_dump(),
-            "follow_up": follow_up_result.model_dump(),
-            "multi_questions": multi_questions_result.model_dump(),
+            "step_back": step_back_result.step_back_query,
+            "follow_up": [follow_up_result.follow_up_query],
+            "multi_questions": multi_questions_result.sub_questions,
+            "wiki_title": wiki_result,
+            "original_query": query,
+            "rewrite_query": rewrite.rewrite
         }
 
         # 返回所有结果