#!/usr/bin/env python # -*- coding: utf-8 -*- """ File: ModelTool.py Date: 2025-05-15 Author: oyyz Description: 模型工具类 """ from openai import OpenAI import httpx import time import logging # 导入 logging 模块 from langchain.embeddings.base import Embeddings from typing import List, Any import requests import os import logging from rag2_0.tool.APIKeyManager import APIKeyManager class SiliconFlowEmbeddings(Embeddings): """SiliconFlow嵌入模型封装""" def __init__(self, api_key: str, model: str = "bge-m3"): self.api_key = api_key self.model = model self.url = "http://10.1.16.39:9995/v1/embeddings" self.headers = { "Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json" } def _embed(self, input: List[str]) -> List[List[float]]: payload = { "model": self.model, "input": input, "encoding_format": "float" } response = requests.post(self.url, json=payload, headers=self.headers) response.raise_for_status() data = response.json() return [item["embedding"] for item in data["data"]] def embed_documents(self, texts: List[str]) -> List[List[float]]: return self._embed(texts) def embed_query(self, text: str) -> List[float]: return self._embed([text])[0] class SiliconFlowReRankerModel: @staticmethod def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]: """ 使用硅流重排模型对文档进行重新排序 Args: query: 用户查询文本 documents: 需要重新排序的文档列表 top_k: 返回排序后的前k个文档 Returns: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ url = "https://api.siliconflow.cn/v1/rerank" payload = { "model": "BAAI/bge-reranker-v2-m3", "query": query, "documents": documents, "top_n": top_k, "max_chunks_per_doc": 1024, "overlap_tokens": 80, "return_documents": True } api_key = APIKeyManager.get_api_key() headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json" } try: response = requests.post(url, json=payload, headers=headers) response.raise_for_status() results = response.json() return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]] except requests.exceptions.RequestException as e: logging.error(f"重排序请求失败: {str(e)}") return [] class XinferenceReRankerModel: """重排模型封装""" @staticmethod def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]: """ 使用重排序模型对文档进行重新排序 Args: query: 用户查询文本 documents: 需要重新排序的文档列表 top_k: 返回排序后的前k个文档 Returns: List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引 """ url = "http://10.1.16.39:9995/v1/rerank" params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": os.getenv("RERANKER_MODEL_NAME")} headers = { "Authorization": "Bearer ", # 这里需要替换为实际的token "Content-Type": "application/json" } try: response = requests.post(url, json=params, headers=headers) response.raise_for_status() # 检查响应状态 results = response.json() # 返回重排序后的文档列表 return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]] except requests.exceptions.RequestException as e: logging.error(f"XinferenceReRankerModel重排序请求失败: {str(e)}") return [] class OpenAiLLM: def __init__(self, **kwargs): if kwargs.get("api_key") == None or kwargs.get("base_url") == None or kwargs.get("model") == None: raise ValueError("api_key, base_url, model 不能为空") self._api_key = kwargs.get("api_key") self._url = kwargs.get("base_url") self._model = kwargs.get("model") kwargs.pop("api_key") kwargs.pop("base_url") kwargs.pop("model") self._kwargs = kwargs def invoke(self, user_prompt="你是谁?", need_retry=True): # 初始化 OpenAI 客户端 api_key = APIKeyManager.get_api_key() client = OpenAI(api_key=api_key, base_url=self._url) max_retries = 3 retry_count = 0 if need_retry: while retry_count < max_retries: try: # 创建 Completion 请求. 超时120s completion = client.chat.completions.create( model=self._model, messages=[{'role': 'user', 'content': user_prompt}], timeout=httpx.Timeout(300.0), **self._kwargs ) return completion.choices[0].message except Exception as e: retry_count += 1 if retry_count == max_retries: logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}") return "" else: time.sleep(5*retry_count) # 重试前等待1秒 else: # 创建 Completion 请求. 超时120s completion = client.chat.completions.create( model=self._model, messages=[{'role': 'user', 'content': user_prompt}], timeout=httpx.Timeout(300.0), **self._kwargs ) return completion.choices[0].message if __name__ == "__main__": # 测试重排模型 reranker = SiliconFlowReRankerModel() # 测试用例1:简单问题 query = "他想做什么" documents = ["她想去公园跑步", "她想换一个新手机", "明天她想出去旅游"] results = reranker.rerank(query, documents) print(f"测试用例1 - 查询:{query}") for idx, item in enumerate(results): print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") print("-" * 50) # 测试用例2:技术问题 query = "Python如何处理JSON数据" documents = [ "Python中可以使用json模块来处理JSON数据,例如json.loads()将JSON字符串转换为字典", "Java提供了多种处理JSON的库,比如Jackson和Gson", "在Python中,可以使用pandas库来分析CSV数据", "JavaScript可以使用JSON.parse()方法解析JSON字符串" ] results = reranker.rerank(query, documents) print(f"测试用例2 - 查询:{query}") for idx, item in enumerate(results): print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") print("-" * 50) # 测试用例3:医疗问题 query = "高血压的症状有哪些" documents = [ "高血压的常见症状包括头痛、头晕、耳鸣和视力模糊", "糖尿病的症状包括多饮、多尿和体重减轻", "心脏病的症状通常包括胸痛、呼吸急促和疲劳", "高血压患者应该定期监测血压,保持健康的生活方式" ] results = reranker.rerank(query, documents) print(f"测试用例3 - 查询:{query}") for idx, item in enumerate(results): print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") print("-" * 50) # 测试用例4:长文本查询和文档 query = "人工智能在医疗领域的应用及其伦理问题" documents = [ "人工智能在医疗诊断中的应用已经显示出良好的效果,例如通过分析医学影像来检测疾病。然而,这也引发了关于医生角色和责任的伦理问题。", "在教育领域,人工智能可以提供个性化学习体验,适应不同学生的学习进度和风格。", "医疗伦理问题主要包括患者隐私保护、知情同意和医疗资源分配等方面。", "人工智能技术在金融领域的应用主要集中在风险评估、欺诈检测和算法交易等方面。" ] results = reranker.rerank(query, documents) print(f"测试用例4 - 查询:{query}") for idx, item in enumerate(results): print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}") print("-" * 50)