refactor(embedding/reranker): 重构模型工具类使用环境变量配置

更新.gitignore文件,添加新的数据库文件
在.env中添加EMBEDDING_MODEL_NAME和XINFERENCE_URL配置
重构SiliconFlowEmbeddings和XinferenceReRankerModel使用环境变量
优化DifyCompareTest异常处理和输入验证
修改测试文件路径和并发工作数
This commit is contained in:
2025-08-15 10:34:30 +08:00
parent 1cde82cc86
commit 8b9ea73b3b
4 changed files with 43 additions and 101 deletions
+14 -87
View File
@@ -23,10 +23,11 @@ from rag2_0.tool.APIKeyManager import APIKeyManager
class SiliconFlowEmbeddings(Embeddings):
"""SiliconFlow嵌入模型封装"""
def __init__(self, api_key: str, model: str = "bge-m3"):
def __init__(self, api_key: str, model: str = os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")):
self.api_key = api_key
self.model = model
self.url = "http://10.1.16.39:9995/v1/embeddings"
base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
self.url = urljoin(base_url.rstrip('/') + '/', 'v1/embeddings')
self.headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
@@ -70,83 +71,7 @@ class SiliconFlowEmbeddings(Embeddings):
"""异步嵌入单个查询"""
result = await self._embed_async([text])
return result[0]
class SiliconFlowReRankerModel:
@staticmethod
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
"""
使用硅流重排模型对文档进行重新排序
Args:
query: 用户查询文本
documents: 需要重新排序的文档列表
top_k: 返回排序后的前k个文档
Returns:
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
"""
url = "https://api.siliconflow.cn/v1/rerank"
payload = {
"model": "BAAI/bge-reranker-v2-m3",
"query": query,
"documents": documents,
"top_n": top_k,
"max_chunks_per_doc": 1024,
"overlap_tokens": 80,
"return_documents": True
}
api_key = APIKeyManager.get_api_key()
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
try:
response = requests.post(url, json=payload, headers=headers, timeout=300)
response.raise_for_status()
results = response.json()
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
except requests.exceptions.RequestException as e:
logging.error(f"重排序请求失败: {str(e)}", exc_info=True)
return []
@staticmethod
async def rerank_async(query: str, documents: List[str], top_k: int = 10) -> List[str]:
"""
使用硅流重排模型对文档进行异步重新排序
Args:
query: 用户查询文本
documents: 需要重新排序的文档列表
top_k: 返回排序后的前k个文档
Returns:
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
"""
url = "https://api.siliconflow.cn/v1/rerank"
payload = {
"model": "BAAI/bge-reranker-v2-m3",
"query": query,
"documents": documents,
"top_n": top_k,
"max_chunks_per_doc": 1024,
"overlap_tokens": 80,
"return_documents": True
}
api_key = APIKeyManager.get_api_key()
headers = {
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json"
}
try:
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json=payload, headers=headers)
response.raise_for_status()
results = response.json()
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
except httpx.RequestError as e:
logging.error(f"异步重排序请求失败: {str(e)}", exc_info=True)
return []
class XinferenceReRankerModel:
"""重排模型封装"""
@@ -163,17 +88,18 @@ class XinferenceReRankerModel:
Returns:
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
"""
url = "http://10.1.16.39:9995/v1/rerank"
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3")
rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank')
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name}
headers = {
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
"Content-Type": "application/json"
}
try:
response = requests.post(url, json=params, headers=headers, timeout=300)
response = requests.post(rerank_url, json=params, headers=headers, timeout=300)
response.raise_for_status() # 检查响应状态
results = response.json()
@@ -197,9 +123,10 @@ class XinferenceReRankerModel:
Returns:
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
"""
url = "http://10.1.16.39:9995/v1/rerank"
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank')
model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3")
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name}
headers = {
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
"Content-Type": "application/json"
@@ -207,7 +134,7 @@ class XinferenceReRankerModel:
try:
async with httpx.AsyncClient(timeout=300) as client:
response = await client.post(url, json=params, headers=headers)
response = await client.post(rerank_url, json=params, headers=headers)
response.raise_for_status() # 检查响应状态
results = response.json()