From 05caedc4fab9e52fbc95110904b314258f61c8cf Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Thu, 29 May 2025 17:19:27 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E7=A1=85=E6=B5=81=E9=87=8D?= =?UTF-8?q?=E6=8E=92=E6=A8=A1=E5=9E=8B=E7=B1=BB=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?APIKeyManager=E5=AF=BC=E5=85=A5=E8=B7=AF=E5=BE=84=EF=BC=8C?= =?UTF-8?q?=E4=BC=98=E5=8C=96=E6=96=87=E6=A1=A3=E9=87=8D=E6=8E=92=E5=BA=8F?= =?UTF-8?q?=E9=80=BB=E8=BE=91=EF=BC=8C=E5=A2=9E=E5=BC=BA=E4=BB=A3=E7=A0=81?= =?UTF-8?q?=E7=BB=93=E6=9E=84=E5=92=8C=E5=8F=AF=E8=AF=BB=E6=80=A7=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- rag2_0/tool/ModelTool.py | 46 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/rag2_0/tool/ModelTool.py b/rag2_0/tool/ModelTool.py index 13c00f5..19e6696 100644 --- a/rag2_0/tool/ModelTool.py +++ b/rag2_0/tool/ModelTool.py @@ -16,7 +16,7 @@ from typing import List, Any import requests import os import logging -from .APIKeyManager import APIKeyManager +from rag2_0.tool.APIKeyManager import APIKeyManager class SiliconFlowEmbeddings(Embeddings): """SiliconFlow嵌入模型封装""" @@ -45,7 +45,45 @@ class SiliconFlowEmbeddings(Embeddings): def embed_query(self, text: str) -> List[float]: return self._embed([text])[0] - + +class SiliconFlowReRankerModel: + @staticmethod + def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]: + """ + 使用硅流重排模型对文档进行重新排序 + + Args: + query: 用户查询文本 + documents: 需要重新排序的文档列表 + top_k: 返回排序后的前k个文档 + + Returns: + List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 + """ + url = "https://api.siliconflow.cn/v1/rerank" + payload = { + "model": "BAAI/bge-reranker-v2-m3", + "query": query, + "documents": documents, + "top_n": top_k, + "max_chunks_per_doc": 1024, + "overlap_tokens": 80, + "return_documents": True + } + api_key = APIKeyManager.get_api_key() + headers = { + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json" + } + try: + response = requests.post(url, json=payload, headers=headers) + response.raise_for_status() + results = response.json() + return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]] + except requests.exceptions.RequestException as e: + logging.error(f"重排序请求失败: {str(e)}") + return [] + class XinferenceReRankerModel: """重排模型封装""" @@ -83,6 +121,8 @@ class XinferenceReRankerModel: logging.error(f"重排序请求失败: {str(e)}") return [] + + class OpenAiLLM: def __init__(self, **kwargs): @@ -136,7 +176,7 @@ class OpenAiLLM: return completion.choices[0].message if __name__ == "__main__": - reranker = XinferenceReRankerModel() + reranker = SiliconFlowReRankerModel() query = "什么是AI" documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"] results = reranker.rerank(query, documents)