diff --git a/rag2_0/intent_recognition/IntentRecognition.py b/rag2_0/intent_recognition/IntentRecognition.py
index 0fa30dd..1fca056 100755
--- a/rag2_0/intent_recognition/IntentRecognition.py
+++ b/rag2_0/intent_recognition/IntentRecognition.py
@@ -19,8 +19,6 @@ import jieba
import time
import threading
-from langchain_openai import ChatOpenAI
-
from .PromptTemplates import (classification_prompt, query_rewrite_prompt_pro,
extract_nouns_prompt, classification_info,
slot_filling_prompt, step_back_prompt,
@@ -34,10 +32,7 @@ from .DataModels import (
StepBackPrompt, HypotheticalDocument
)
from .ProfessionalNounVector import ProfessionalNounRetriever, AsyncProfessionalNounRetriever
-from rag2_0.tool.APIKeyManager import APIKeyManager
-
-TEMPERATURE = 0.4
-TOP_P = 0.7
+from rag2_0.tool.ModelTool import OpenAiLLM
class AsyncIntentRecognizer:
SOFT_WIKI_PATH = "data/wiki_data"
@@ -64,7 +59,17 @@ class AsyncIntentRecognizer:
model_name: 要使用的模型名称
vector_index_dir: 向量索引目录,如果为None则使用默认目录
"""
+ base_url = os.getenv("OPENAI_API_BASE")
+ model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
+ # 初始化LLM
+ llm_params = {
+ "temperature": 0.4, # 降低随机性,使结果更确定
+ "top_p": 0.7,
+ "model": model_name,
+ "base_url": base_url
+ }
+ self._llm = OpenAiLLM(**llm_params)
# 加载suffix关键词
self._suffix_keywords = self._load_suffix_keywords()
# 加载软件词条名称库
@@ -189,15 +194,7 @@ class AsyncIntentRecognizer:
# 解析输出
try:
# 异步调用LLM
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P
- )
- llm.with_structured_output(Classification)
- response = await llm.ainvoke(formatted_prompt)
+ response = await self._llm.ainvoke(formatted_prompt, extra_body={"enable_thinking": False})
# 尝试直接解析JSON响应
response.content = response.content.strip()
@@ -264,17 +261,8 @@ class AsyncIntentRecognizer:
terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
formatted_prompt = formatted_prompt.replace("{output_format}", terms_list_parser.get_format_instructions())
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P
- )
- llm.with_structured_output(TermList)
-
# 异步调用LLM
- response = await llm.ainvoke(formatted_prompt)
+ response = await self._llm.ainvoke(formatted_prompt, extra_body={"enable_thinking": False})
# 尝试使用Pydantic解析器解析TermList
response.content = response.content.strip()
@@ -356,16 +344,7 @@ class AsyncIntentRecognizer:
"""
try:
-
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P
- )
-
- response = await llm.ainvoke(prompt, response_format={"type": "json_object"})
+ response = await self._llm.ainvoke(prompt, response_format={"type": "json_object"}, extra_body={"enable_thinking": False})
response.content = response.content.strip()
clean_output = re.sub(r'.*?', '', response.content, flags=re.DOTALL)
parsed_output = JsonOutputParser().parse(clean_output)
@@ -405,17 +384,8 @@ class AsyncIntentRecognizer:
context=context)
# 解析输出
try:
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P
- )
- llm.with_structured_output(QueryRewrite)
-
# 异步调用LLM
- response = await llm.ainvoke(formatted_prompt)
+ response = await self._llm.ainvoke(formatted_prompt, extra_body={"enable_thinking": False})
response.content = response.content.strip()
clean_output = re.sub(r'.*?', '', response.content, flags=re.DOTALL)
parsed_output = query_rewrite_parser.parse(clean_output)
@@ -659,18 +629,8 @@ class AsyncIntentRecognizer:
previous_slots=json.dumps(previous_slots,ensure_ascii=False),
)
try:
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P
- )
- llm.with_structured_output(slot_model_class)
-
-
# 异步调用LLM
- response = await llm.ainvoke(formatted_prompt)
+ response = await self._llm.ainvoke(formatted_prompt, extra_body={"enable_thinking": False})
response.content = response.content.strip()
clean_output = re.sub(r'.*?', '', response.content, flags=re.DOTALL)
# 尝试解析LLM响应
@@ -704,17 +664,10 @@ class AsyncIntentRecognizer:
)
try:
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P
- )
- llm.with_structured_output(StepBackPrompt)
-
# 异步调用LLM
- response = await llm.ainvoke(formatted_prompt)
+ response = await self._llm.ainvoke(formatted_prompt, extra_body={"enable_thinking": False})
+
+ # 解析输出
response.content = response.content.strip()
clean_output = re.sub(r'.*?', '', response.content, flags=re.DOTALL)
parsed_output = step_back_parser.parse(clean_output)
@@ -770,18 +723,9 @@ class AsyncIntentRecognizer:
"""
try:
-
- llm = ChatOpenAI(
- api_key=APIKeyManager.get_api_key(),
- openai_api_base=os.getenv("OPENAI_API_BASE"),
- model_name=os.getenv("MODEL_NAME"),
- temperature=TEMPERATURE,
- top_p=TOP_P,
- )
-
# 异步调用LLM
start_time = time.time()
- response = await llm.ainvoke(prompt, response_format={"type": "json_object"})
+ response = await self._llm.ainvoke(prompt, response_format={"type": "json_object"}, extra_body={"enable_thinking": False})
end_time = time.time()
# 解析JSON响应
diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py
index 0c27a95..d223f42 100755
--- a/rag2_0/tool/ModelTool.py
+++ b/rag2_0/tool/ModelTool.py
@@ -217,7 +217,7 @@ class OpenAiLLM:
except Exception as e:
raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}.api_key:{api_key}") from e
- async def invoke_async(self, user_prompt="你是谁?", need_retry=True, **extra_kwargs):
+ async def ainvoke(self, user_prompt="你是谁?", **extra_kwargs):
"""异步调用OpenAI API"""
max_retries = 3
retry_count = 0
@@ -231,38 +231,17 @@ class OpenAiLLM:
timeout = httpx.Timeout(300.0)
kwargs["timeout"] = timeout
- if need_retry:
- while retry_count < max_retries:
- try:
- api_key = APIKeyManager.get_api_key()
- # 使用异步客户端
- async with AsyncOpenAI(api_key=api_key, base_url=self._url) as client:
- # 创建异步Completion请求
- completion = await client.chat.completions.create(
- model=self._model,
- messages=[{'role': 'user', 'content': user_prompt}],
- **kwargs
- )
- return completion.choices[0].message
-
- except Exception as e:
- retry_count += 1
- if retry_count == max_retries:
- raise RuntimeError(f"OpenAiLLM:invoke_async:error:{str(e)}.api_key:{api_key}") from e
- else:
- await asyncio.sleep(5*retry_count) # 异步等待
- else:
- try:
- api_key = APIKeyManager.get_api_key()
- async with AsyncOpenAI(api_key=api_key, base_url=self._url) as client:
- completion = await client.chat.completions.create(
- model=self._model,
- messages=[{'role': 'user', 'content': user_prompt}],
- **kwargs
- )
- return completion.choices[0].message
- except Exception as e:
- raise RuntimeError(f"OpenAiLLM:invoke_async:error:{str(e)}.api_key:{api_key}") from e
+ try:
+ api_key = APIKeyManager.get_api_key()
+ async with AsyncOpenAI(api_key=api_key, base_url=self._url) as client:
+ completion = await client.chat.completions.create(
+ model=self._model,
+ messages=[{'role': 'user', 'content': user_prompt}],
+ **kwargs
+ )
+ return completion.choices[0].message
+ except Exception as e:
+ raise RuntimeError(f"OpenAiLLM:ainvoke:error:{str(e)}") from e
if __name__ == "__main__":
# 测试重排模型
@@ -291,7 +270,7 @@ if __name__ == "__main__":
# 测试异步LLM调用
llm = OpenAiLLM()
- response = await llm.invoke_async("你好,请简单介绍一下自己")
+ response = await llm.ainvoke("你好,请简单介绍一下自己")
print(f"异步LLM响应: {response.content}")
# 如果需要运行异步测试,取消下面的注释