添加硅流重排模型类,更新APIKeyManager导入路径,优化文档重排序逻辑,增强代码结构和可读性。
This commit is contained in:
@@ -16,7 +16,7 @@ from typing import List, Any
|
|||||||
import requests
|
import requests
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
from .APIKeyManager import APIKeyManager
|
from rag2_0.tool.APIKeyManager import APIKeyManager
|
||||||
|
|
||||||
class SiliconFlowEmbeddings(Embeddings):
|
class SiliconFlowEmbeddings(Embeddings):
|
||||||
"""SiliconFlow嵌入模型封装"""
|
"""SiliconFlow嵌入模型封装"""
|
||||||
@@ -46,6 +46,44 @@ class SiliconFlowEmbeddings(Embeddings):
|
|||||||
def embed_query(self, text: str) -> List[float]:
|
def embed_query(self, text: str) -> List[float]:
|
||||||
return self._embed([text])[0]
|
return self._embed([text])[0]
|
||||||
|
|
||||||
|
class SiliconFlowReRankerModel:
|
||||||
|
@staticmethod
|
||||||
|
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
||||||
|
"""
|
||||||
|
使用硅流重排模型对文档进行重新排序
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户查询文本
|
||||||
|
documents: 需要重新排序的文档列表
|
||||||
|
top_k: 返回排序后的前k个文档
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
||||||
|
"""
|
||||||
|
url = "https://api.siliconflow.cn/v1/rerank"
|
||||||
|
payload = {
|
||||||
|
"model": "BAAI/bge-reranker-v2-m3",
|
||||||
|
"query": query,
|
||||||
|
"documents": documents,
|
||||||
|
"top_n": top_k,
|
||||||
|
"max_chunks_per_doc": 1024,
|
||||||
|
"overlap_tokens": 80,
|
||||||
|
"return_documents": True
|
||||||
|
}
|
||||||
|
api_key = APIKeyManager.get_api_key()
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
response = requests.post(url, json=payload, headers=headers)
|
||||||
|
response.raise_for_status()
|
||||||
|
results = response.json()
|
||||||
|
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
||||||
|
except requests.exceptions.RequestException as e:
|
||||||
|
logging.error(f"重排序请求失败: {str(e)}")
|
||||||
|
return []
|
||||||
|
|
||||||
class XinferenceReRankerModel:
|
class XinferenceReRankerModel:
|
||||||
"""重排模型封装"""
|
"""重排模型封装"""
|
||||||
|
|
||||||
@@ -83,6 +121,8 @@ class XinferenceReRankerModel:
|
|||||||
logging.error(f"重排序请求失败: {str(e)}")
|
logging.error(f"重排序请求失败: {str(e)}")
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAiLLM:
|
class OpenAiLLM:
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@@ -136,7 +176,7 @@ class OpenAiLLM:
|
|||||||
return completion.choices[0].message
|
return completion.choices[0].message
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
reranker = XinferenceReRankerModel()
|
reranker = SiliconFlowReRankerModel()
|
||||||
query = "什么是AI"
|
query = "什么是AI"
|
||||||
documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"]
|
documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"]
|
||||||
results = reranker.rerank(query, documents)
|
results = reranker.rerank(query, documents)
|
||||||
|
|||||||
Reference in New Issue
Block a user