优化意图识别API,移除同步意图识别器,改为使用异步意图识别器,更新相关逻辑以支持异步处理,增强错误处理和日志记录,同时更新请求和响应模型以适应新的API结构。
This commit is contained in:
@@ -21,7 +21,7 @@ from typing import List, Dict, Any
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from pydantic import BaseModel, Field
|
||||
sys.path.append(os.getcwd())
|
||||
from rag2_0.intent_recognition import IntentRecognizer, AsyncIntentRecognizer
|
||||
from rag2_0.intent_recognition import AsyncIntentRecognizer
|
||||
from rag2_0.dify.DifyQueryRetrieval import DifyQueryRetrieval
|
||||
from rag2_0.intent_recognition.DataModels import Classification
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
@@ -78,8 +78,6 @@ class QueryRewriteProcessor:
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
|
||||
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
self.recognizer = IntentRecognizer(api_key=self.api_key, base_url=self.base_url, model_name=self.model_name)
|
||||
# 使用asyncio.run()运行异步create方法
|
||||
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
|
||||
api_key=self.api_key,
|
||||
@@ -198,12 +196,12 @@ class QueryRewriteProcessor:
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 使用process_query方法处理查询
|
||||
result = self.recognizer.process_query(query,
|
||||
result = asyncio.run(self.recognizer_async.process_query_async(query,
|
||||
conversation_context=conversation_context,
|
||||
chat_history=chat_history,
|
||||
previous_slots=previous_slots,
|
||||
enable_query_expansion=True)
|
||||
enable_query_expansion=True))
|
||||
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
original_query = result["rewrite"]["rewrite"]
|
||||
@@ -238,9 +236,9 @@ class QueryRewriteProcessor:
|
||||
if slot_filling and "filled_data" in slot_filling:
|
||||
# 格式化槽位填充结果
|
||||
slot_filling_str = json.dumps({
|
||||
"是否完整": slot_filling.get("is_complete", False),
|
||||
"缺失槽位": slot_filling.get("missing_slots", {}),
|
||||
"填充数据": slot_filling.get("filled_data", {})
|
||||
"is_complete": slot_filling.get("is_complete", False),
|
||||
"missing_slots": slot_filling.get("missing_slots", {}),
|
||||
"filled_data": slot_filling.get("filled_data", {})
|
||||
}, ensure_ascii=False, indent=2)
|
||||
|
||||
# 处理成功,返回结果
|
||||
@@ -442,9 +440,12 @@ def main():
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
query="储能C1软件如何新建工程?"
|
||||
conversation_context="当前使用软件:配网计价通D3软件"
|
||||
# 在调试模式下使用完整的参数
|
||||
print(json.dumps(processor.process_query(
|
||||
query,
|
||||
conversation_context=conversation_context,
|
||||
enable_retrieval=True
|
||||
), ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
Reference in New Issue
Block a user