新增异步意图识别器和相关功能,优化意图识别和槽位填充逻辑,支持异步处理和多线程检索,改进API调用的错误处理和日志记录,增强文档检索和关键词提取功能。
This commit is contained in:
@@ -8,7 +8,9 @@ Description: 模型工具类
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
from openai import AsyncOpenAI
|
||||
import httpx
|
||||
import asyncio
|
||||
import time
|
||||
import logging # 导入 logging 模块
|
||||
from langchain.embeddings.base import Embeddings
|
||||
@@ -41,12 +43,34 @@ class SiliconFlowEmbeddings(Embeddings):
|
||||
data = response.json()
|
||||
return [item["embedding"] for item in data["data"]]
|
||||
|
||||
async def _embed_async(self, input: List[str]) -> List[List[float]]:
|
||||
"""异步嵌入方法"""
|
||||
payload = {
|
||||
"model": self.model,
|
||||
"input": input,
|
||||
"encoding_format": "float"
|
||||
}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(self.url, json=payload, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
return [item["embedding"] for item in data["data"]]
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
return self._embed(texts)
|
||||
|
||||
async def embed_documents_async(self, texts: List[str]) -> List[List[float]]:
|
||||
"""异步嵌入多个文档"""
|
||||
return await self._embed_async(texts)
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
return self._embed([text])[0]
|
||||
|
||||
async def embed_query_async(self, text: str) -> List[float]:
|
||||
"""异步嵌入单个查询"""
|
||||
result = await self._embed_async([text])
|
||||
return result[0]
|
||||
|
||||
class SiliconFlowReRankerModel:
|
||||
@staticmethod
|
||||
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
||||
@@ -84,6 +108,44 @@ class SiliconFlowReRankerModel:
|
||||
except requests.exceptions.RequestException as e:
|
||||
logging.error(f"重排序请求失败: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def rerank_async(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
||||
"""
|
||||
使用硅流重排模型对文档进行异步重新排序
|
||||
|
||||
Args:
|
||||
query: 用户查询文本
|
||||
documents: 需要重新排序的文档列表
|
||||
top_k: 返回排序后的前k个文档
|
||||
|
||||
Returns:
|
||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
||||
"""
|
||||
url = "https://api.siliconflow.cn/v1/rerank"
|
||||
payload = {
|
||||
"model": "BAAI/bge-reranker-v2-m3",
|
||||
"query": query,
|
||||
"documents": documents,
|
||||
"top_n": top_k,
|
||||
"max_chunks_per_doc": 1024,
|
||||
"overlap_tokens": 80,
|
||||
"return_documents": True
|
||||
}
|
||||
api_key = APIKeyManager.get_api_key()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=payload, headers=headers)
|
||||
response.raise_for_status()
|
||||
results = response.json()
|
||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
||||
except httpx.RequestError as e:
|
||||
logging.error(f"异步重排序请求失败: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
class XinferenceReRankerModel:
|
||||
"""重排模型封装"""
|
||||
@@ -122,6 +184,39 @@ class XinferenceReRankerModel:
|
||||
logging.error(f"XinferenceReRankerModel重排序请求失败: {str(e)}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
async def rerank_async(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
||||
"""
|
||||
使用重排序模型对文档进行异步重新排序
|
||||
|
||||
Args:
|
||||
query: 用户查询文本
|
||||
documents: 需要重新排序的文档列表
|
||||
top_k: 返回排序后的前k个文档
|
||||
|
||||
Returns:
|
||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
||||
"""
|
||||
url = "http://172.20.0.145:9995/v1/rerank"
|
||||
|
||||
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
|
||||
headers = {
|
||||
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
response.raise_for_status() # 检查响应状态
|
||||
results = response.json()
|
||||
|
||||
# 返回重排序后的文档列表
|
||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
||||
|
||||
except httpx.RequestError as e:
|
||||
logging.error(f"XinferenceReRankerModel异步重排序请求失败: {str(e)}")
|
||||
return []
|
||||
|
||||
|
||||
class OpenAiLLM:
|
||||
@@ -189,6 +284,47 @@ class OpenAiLLM:
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"OpenAiLLM:invoke:error:{str(e)}.api_key:{api_key}") from e
|
||||
|
||||
async def invoke_async(self, user_prompt="你是谁?", need_retry=True):
|
||||
"""异步调用OpenAI API"""
|
||||
max_retries = 3
|
||||
retry_count = 0
|
||||
if "timeout" not in self._kwargs:
|
||||
timeout = httpx.Timeout(300.0)
|
||||
self._kwargs["timeout"] = timeout
|
||||
|
||||
if need_retry:
|
||||
while retry_count < max_retries:
|
||||
try:
|
||||
api_key = APIKeyManager.get_api_key()
|
||||
# 使用异步客户端
|
||||
async with AsyncOpenAI(api_key=api_key, base_url=self._url) as client:
|
||||
# 创建异步Completion请求
|
||||
completion = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[{'role': 'user', 'content': user_prompt}],
|
||||
**self._kwargs
|
||||
)
|
||||
return completion.choices[0].message
|
||||
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count == max_retries:
|
||||
raise RuntimeError(f"OpenAiLLM:invoke_async:error:{str(e)}.api_key:{api_key}") from e
|
||||
else:
|
||||
await asyncio.sleep(5*retry_count) # 异步等待
|
||||
else:
|
||||
try:
|
||||
api_key = APIKeyManager.get_api_key()
|
||||
async with AsyncOpenAI(api_key=api_key, base_url=self._url) as client:
|
||||
completion = await client.chat.completions.create(
|
||||
model=self._model,
|
||||
messages=[{'role': 'user', 'content': user_prompt}],
|
||||
**self._kwargs
|
||||
)
|
||||
return completion.choices[0].message
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"OpenAiLLM:invoke_async:error:{str(e)}.api_key:{api_key}") from e
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试重排模型
|
||||
reranker = SiliconFlowReRankerModel()
|
||||
@@ -202,4 +338,25 @@ if __name__ == "__main__":
|
||||
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
|
||||
print("-" * 50)
|
||||
|
||||
# 异步测试示例
|
||||
async def test_async():
|
||||
# 测试异步嵌入
|
||||
api_key = APIKeyManager.get_api_key()
|
||||
embeddings = SiliconFlowEmbeddings(api_key=api_key)
|
||||
query_embedding = await embeddings.embed_query_async("测试查询")
|
||||
print(f"异步嵌入向量维度: {len(query_embedding)}")
|
||||
|
||||
# 测试异步重排序
|
||||
results = await SiliconFlowReRankerModel.rerank_async(query, documents)
|
||||
print(f"异步重排序结果数量: {len(results)}")
|
||||
|
||||
# 测试异步LLM调用
|
||||
llm = OpenAiLLM()
|
||||
response = await llm.invoke_async("你好,请简单介绍一下自己")
|
||||
print(f"异步LLM响应: {response.content}")
|
||||
|
||||
# 如果需要运行异步测试,取消下面的注释
|
||||
# import asyncio
|
||||
# asyncio.run(test_async())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user