Files
QueryRewrite/rag2_0/tool/ModelTool.py
T
ouyangyouzhang 1a3fa44522 feat: 添加清单定额查询API并优化意图识别模块
新增清单定额查询API服务,支持通过名称和编码查询定额及清单信息
在意图识别模块中添加定额清单信息提取功能,并记录各步骤耗时
将SiliconFlowEmbeddings替换为XinferenceEmbeddings并添加sqlite-vss依赖
优化shell脚本的screen会话检测逻辑
2025-08-20 19:08:29 +08:00

302 lines
12 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: ModelTool.py
Date: 2025-05-15
Author: oyyz
Description: 模型工具类
"""
import asyncio
import logging
import os
import time
from typing import Any, Dict, List, Optional
from urllib.parse import urljoin

import httpx
import requests
from langchain.embeddings.base import Embeddings
from openai import AsyncOpenAI
from openai import OpenAI

from rag2_0.tool.APIKeyManager import APIKeyManager
class XinferenceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` wrapper around Xinference's OpenAI-compatible embeddings API.

    Every call POSTs ``{"model", "input", "encoding_format": "float"}`` to
    ``<XINFERENCE_URL>/v1/embeddings`` and returns the ``embedding`` field of each
    entry in the response's ``data`` list, ordered to match the inputs.
    """

    def __init__(self, api_key: str, model: Optional[str] = None, timeout: float = 300):
        """Configure endpoint, model and auth header.

        Args:
            api_key: Sent as ``Authorization: Bearer <api_key>``.
            model: Embedding model name. ``None`` (the default) reads
                ``EMBEDDING_MODEL_NAME`` (fallback ``"bge-m3"``) when the instance
                is created, not when the module is imported, so environment
                variables loaded after import are honoured.
            timeout: Request timeout in seconds for both sync and async calls.
        """
        self.api_key = api_key
        self.model = model if model is not None else os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")
        self.timeout = timeout
        base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
        # Force exactly one trailing slash so urljoin appends 'v1/embeddings'
        # instead of replacing the last path segment of base_url.
        self.url = urljoin(base_url.rstrip('/') + '/', 'v1/embeddings')
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _build_payload(self, texts: List[str]) -> Dict[str, Any]:
        """Return the JSON request body for *texts*."""
        return {"model": self.model, "input": texts, "encoding_format": "float"}

    @staticmethod
    def _parse_embeddings(data: Dict[str, Any]) -> List[List[float]]:
        """Extract vectors from an OpenAI-style response body.

        Items are ordered by their ``index`` field so vectors line up with the
        inputs. An item without ``index`` falls back to its position in ``data``.
        """
        indexed = sorted(enumerate(data["data"]), key=lambda pair: pair[1].get("index", pair[0]))
        return [item["embedding"] for _, item in indexed]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed *texts* with a blocking HTTP request.

        Returns ``[]`` for empty input without contacting the server.

        Raises:
            requests.HTTPError: on a non-2xx response (via ``raise_for_status``).
        """
        if not texts:
            return []
        response = requests.post(
            self.url,
            json=self._build_payload(texts),
            headers=self.headers,
            timeout=self.timeout,
        )
        response.raise_for_status()
        return self._parse_embeddings(response.json())

    async def _embed_async(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of ``_embed`` using ``httpx.AsyncClient``.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response (via ``raise_for_status``).
        """
        if not texts:
            return []
        async with httpx.AsyncClient(timeout=self.timeout) as client:
            response = await client.post(self.url, json=self._build_payload(texts), headers=self.headers)
            response.raise_for_status()
            return self._parse_embeddings(response.json())

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents (blocking)."""
        return self._embed(texts)

    async def embed_documents_async(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents asynchronously."""
        return await self._embed_async(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string (blocking)."""
        return self._embed([text])[0]

    async def embed_query_async(self, text: str) -> List[float]:
        """Embed a single query string asynchronously."""
        result = await self._embed_async([text])
        return result[0]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """LangChain async hook, overridden to use the native httpx async client."""
        return await self._embed_async(texts)

    async def aembed_query(self, text: str) -> List[float]:
        """LangChain async hook, overridden to use the native httpx async client."""
        return await self.embed_query_async(text)
class XinferenceReRankerModel:
    """Client for the Xinference ``/v1/rerank`` endpoint.

    Both entry points are best-effort. On a transport error, an HTTP error status
    or a malformed response body they log an error and return ``[]`` instead of
    raising.
    """

    @staticmethod
    def _build_request(query: str, documents: List[str], top_k: int, api_key: Optional[str] = None):
        """Return ``(url, payload, headers)`` for one rerank call.

        The endpoint and model come from ``XINFERENCE_URL`` and ``RERANKER_MODEL_NAME``.
        The bearer token is *api_key*, or else the ``XINFERENCE_API_KEY`` environment
        variable. When neither is set, no ``Authorization`` header is sent; the
        previous code sent the literal placeholder ``Bearer <token>``.
        """
        base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
        model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3")
        # Force exactly one trailing slash so urljoin appends instead of replacing.
        rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank')
        params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name}
        headers = {"Content-Type": "application/json"}
        token = api_key if api_key is not None else os.getenv("XINFERENCE_API_KEY")
        if token:
            headers["Authorization"] = f"Bearer {token}"
        return rerank_url, params, headers

    @staticmethod
    def _parse_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten the server's ``results`` list into ``{"document", "score", "index"}`` dicts.

        Server order is preserved. ``index`` is the document's position in the
        input ``documents`` list.
        """
        return [
            {"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]}
            for item in results["results"]
        ]

    @staticmethod
    def rerank(query: str, documents: List[str], top_k: int = 10,
               api_key: Optional[str] = None, timeout: float = 300) -> List[Dict[str, Any]]:
        """Rerank *documents* against *query* with a blocking HTTP request.

        Args:
            query: User query text.
            documents: Candidate documents to rerank.
            top_k: Number of top results requested (sent as ``top_n``).
            api_key: Optional bearer token; defaults to ``XINFERENCE_API_KEY``.
            timeout: Request timeout in seconds.

        Returns:
            List of dicts with ``document`` text, relevance ``score`` and original
            ``index``. Returns ``[]`` on failure or when *documents* is empty.
        """
        if not documents:
            return []
        url, params, headers = XinferenceReRankerModel._build_request(query, documents, top_k, api_key)
        try:
            response = requests.post(url, json=params, headers=headers, timeout=timeout)
            response.raise_for_status()
            return XinferenceReRankerModel._parse_results(response.json())
        except requests.exceptions.RequestException as e:
            logging.error(f"XinferenceReRankerModel重排序请求失败: {str(e)}")
            return []
        except (KeyError, TypeError, ValueError) as e:
            logging.error(f"XinferenceReRankerModel重排序响应格式异常: {str(e)}")
            return []

    @staticmethod
    async def rerank_async(query: str, documents: List[str], top_k: int = 10,
                           api_key: Optional[str] = None, timeout: float = 300) -> List[Dict[str, Any]]:
        """Async counterpart of ``rerank``, with the same arguments and return contract."""
        if not documents:
            return []
        url, params, headers = XinferenceReRankerModel._build_request(query, documents, top_k, api_key)
        try:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(url, json=params, headers=headers)
                response.raise_for_status()
                return XinferenceReRankerModel._parse_results(response.json())
        # httpx.HTTPError covers both RequestError (transport) and HTTPStatusError
        # (raised by raise_for_status); catching only RequestError let HTTP
        # 4xx/5xx escape here while the sync variant returned [].
        except httpx.HTTPError as e:
            logging.error(f"XinferenceReRankerModel异步重排序请求失败: {str(e)}")
            return []
        except (KeyError, TypeError, ValueError) as e:
            logging.error(f"XinferenceReRankerModel异步重排序响应格式异常: {str(e)}")
            return []
class OpenAiLLM:
    """Chat-completion client for an OpenAI-compatible endpoint.

    The constructor keywords ``api_key``, ``base_url`` and ``model`` configure the
    client. Every other constructor keyword is forwarded to
    ``client.chat.completions.create`` on each call, merged with any per-call
    keywords (per-call values win).
    """

    _MAX_RETRIES = 3
    _DEFAULT_TIMEOUT = 300.0  # seconds; applied when no ``timeout`` kwarg is supplied

    def __init__(self, **kwargs):
        """Store connection settings and default request parameters.

        Args:
            api_key: Default API key. A per-call ``api_key`` overrides it. When
                neither is given, ``APIKeyManager.get_api_key()`` is used.
            base_url: Endpoint base URL; defaults to env ``OPENAI_API_BASE``.
            model: Model name; defaults to env ``MODEL_NAME``.
            **kwargs: Extra request parameters (e.g. ``temperature``, ``timeout``).
        """
        self._api_key = kwargs.pop("api_key", None)
        self._url = kwargs.pop("base_url") if "base_url" in kwargs else os.getenv("OPENAI_API_BASE")
        self._model = kwargs.pop("model") if "model" in kwargs else os.getenv("MODEL_NAME")
        self._kwargs = kwargs

    @staticmethod
    def _mask_key(api_key: Optional[str]) -> str:
        """Return a log-safe form of *api_key* showing at most its last 4 characters."""
        if not api_key:
            return str(api_key)
        if len(api_key) <= 8:
            return "***"
        return "***" + api_key[-4:]

    def _resolve_api_key(self, api_key: Optional[str] = None) -> Optional[str]:
        """Return the explicit per-call key, else the constructor key (``None`` if neither)."""
        if api_key is not None:
            return api_key
        return self._api_key

    def _merge_kwargs(self, extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Merge per-call kwargs over constructor kwargs without mutating ``self._kwargs``.

        Adds an ``httpx.Timeout`` of ``_DEFAULT_TIMEOUT`` seconds when no ``timeout``
        was supplied.
        """
        kwargs = {**self._kwargs, **extra_kwargs}
        if "timeout" not in kwargs:
            kwargs["timeout"] = httpx.Timeout(self._DEFAULT_TIMEOUT)
        return kwargs

    def invoke(self, user_prompt="你是谁?", need_retry=True, api_key: Optional[str] = None, **extra_kwargs):
        """Send *user_prompt* as a single user message and return the reply message.

        Args:
            user_prompt: Content of the user message.
            need_retry: When true, make up to 3 attempts, sleeping 5 s and then 10 s
                between them. Otherwise make a single attempt.
            api_key: Key for this call. Falls back to the constructor key, then to
                ``APIKeyManager.get_api_key()``, resolved once before the first attempt.
            **extra_kwargs: Per-call request parameters. They override constructor
                kwargs, and ``model`` here overrides the instance model.

        Returns:
            ``completion.choices[0].message`` from the first successful attempt.

        Raises:
            RuntimeError: when the last attempt fails; the cause is chained. The API
                key in the message is masked.
        """
        kwargs = self._merge_kwargs(extra_kwargs)
        model = kwargs.pop("model", self._model)
        api_key = self._resolve_api_key(api_key)
        if api_key is None:
            api_key = APIKeyManager.get_api_key()
        attempts = self._MAX_RETRIES if need_retry else 1
        for attempt in range(1, attempts + 1):
            try:
                # The context manager closes the underlying HTTP client after each call.
                with OpenAI(api_key=api_key, base_url=self._url) as client:
                    completion = client.chat.completions.create(
                        model=model,
                        messages=[{'role': 'user', 'content': user_prompt}],
                        **kwargs,
                    )
                    return completion.choices[0].message
            except Exception as e:
                if attempt == attempts:
                    raise RuntimeError(
                        f"OpenAiLLM:invoke:error:{str(e)}.api_key:{self._mask_key(api_key)}"
                    ) from e
                logging.warning("OpenAiLLM:invoke: attempt %d/%d failed, retrying: %s", attempt, attempts, e)
                time.sleep(5 * attempt)  # linear backoff: 5 s, 10 s, ...

    async def invoke_async(self, user_prompt="你是谁?", need_retry=True, api_key: Optional[str] = None, **extra_kwargs):
        """Async counterpart of ``invoke``.

        Unlike ``invoke``, when no explicit or constructor key is set, a fresh key
        is fetched from ``APIKeyManager`` on every attempt.

        Raises:
            RuntimeError: when the last attempt fails; the cause is chained. The API
                key in the message is masked.
        """
        kwargs = self._merge_kwargs(extra_kwargs)
        model = kwargs.pop("model", self._model)
        fixed_key = self._resolve_api_key(api_key)
        attempts = self._MAX_RETRIES if need_retry else 1
        # Bound before the loop so the error message cannot hit an unbound name
        # when APIKeyManager.get_api_key() itself raises.
        current_key = fixed_key
        for attempt in range(1, attempts + 1):
            try:
                current_key = fixed_key if fixed_key is not None else APIKeyManager.get_api_key()
                async with AsyncOpenAI(api_key=current_key, base_url=self._url) as client:
                    completion = await client.chat.completions.create(
                        model=model,
                        messages=[{'role': 'user', 'content': user_prompt}],
                        **kwargs,
                    )
                    return completion.choices[0].message
            except Exception as e:
                if attempt == attempts:
                    raise RuntimeError(
                        f"OpenAiLLM:invoke_async:error:{str(e)}.api_key:{self._mask_key(current_key)}"
                    ) from e
                logging.warning("OpenAiLLM:invoke_async: attempt %d/%d failed, retrying: %s", attempt, attempts, e)
                await asyncio.sleep(5 * attempt)  # non-blocking linear backoff
if __name__ == "__main__":
    # Manual smoke test. It needs a reachable Xinference server (and, for the
    # async LLM part, a configured OpenAI-compatible endpoint).
    # The original script referenced the undefined SiliconFlowReRankerModel and
    # crashed with NameError before doing anything.
    reranker = XinferenceReRankerModel()
    # Test case 1: a simple question against a few illustrative documents
    # (an empty list would make the reranker return [] without any request).
    query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?"
    documents = [
        "在【电力经济评价软件】中,点击【文件】-【打开】,选择工程文件即可加载。",
        "【电力经济评价软件】支持将计算结果导出为Excel报表。",
    ]
    results = reranker.rerank(query, documents)
    print(f"测试用例1 - 查询:{query}")
    for idx, item in enumerate(results):
        print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
    print("-" * 50)

    async def test_async():
        """Exercise the async embedding, rerank and LLM paths once each."""
        api_key = APIKeyManager.get_api_key()
        embeddings = XinferenceEmbeddings(api_key=api_key)
        query_embedding = await embeddings.embed_query_async("测试查询")
        print(f"异步嵌入向量维度: {len(query_embedding)}")
        results = await XinferenceReRankerModel.rerank_async(query, documents)
        print(f"异步重排序结果数量: {len(results)}")
        llm = OpenAiLLM()
        response = await llm.invoke_async("你好,请简单介绍一下自己")
        print(f"异步LLM响应: {response.content}")

    # Opt in to the async checks with:  python ModelTool.py --async
    import sys
    if "--async" in sys.argv:
        asyncio.run(test_async())