#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: ModelTool.py
Date: 2025-05-15
Author: oyyz
Description: Model utility classes -- clients for the Xinference embedding and
reranking endpoints, plus a small OpenAI chat-completions wrapper.
"""
import asyncio
import logging
import os
import time
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import urljoin

import httpx
import requests
from langchain.embeddings.base import Embeddings
from openai import AsyncOpenAI
from openai import OpenAI

from rag2_0.tool.APIKeyManager import APIKeyManager

logger = logging.getLogger(__name__)

# Fallback Xinference server used when the XINFERENCE_URL env var is not set.
_DEFAULT_XINFERENCE_URL = "http://10.1.16.39:9995"
# Per-request timeout (seconds) for Xinference HTTP calls.
_XINFERENCE_TIMEOUT = 300


def _xinference_endpoint(path: str) -> str:
    """Return ``path`` joined onto the Xinference base URL from ``XINFERENCE_URL``.

    The base URL is normalised to end with exactly one ``/`` so that ``urljoin``
    appends ``path`` instead of replacing the base URL's last path segment.
    """
    base_url = os.getenv("XINFERENCE_URL", _DEFAULT_XINFERENCE_URL)
    return urljoin(base_url.rstrip("/") + "/", path)


class XinferenceEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter for the Xinference ``/v1/embeddings`` endpoint."""

    def __init__(self, api_key: str, model: Optional[str] = None):
        """
        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            model: Embedding model name. When omitted, the ``EMBEDDING_MODEL_NAME``
                env var is read at construction time (falling back to ``"bge-m3"``).
        """
        self.api_key = api_key
        # Resolved here rather than in the signature: a default argument is
        # evaluated once at import time, which would ignore env vars set later.
        self.model = model if model is not None else os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")
        self.url = _xinference_endpoint("v1/embeddings")
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _payload(self, texts: List[str]) -> Dict[str, Any]:
        """Build the JSON request body for embedding ``texts``."""
        return {"model": self.model, "input": texts, "encoding_format": "float"}

    @staticmethod
    def _parse(data: Dict[str, Any]) -> List[List[float]]:
        """Extract embedding vectors from a response body, in input order."""
        # OpenAI-compatible responses tag each item with the position of its
        # input; sort on it so vectors line up with inputs. Items without an
        # "index" all sort equal, and the stable sort keeps server order.
        items = sorted(data["data"], key=lambda item: item.get("index", 0))
        return [item["embedding"] for item in items]

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` synchronously; an empty list short-circuits to ``[]``.

        Raises:
            requests.HTTPError: if the server answers with an error status.
        """
        if not texts:
            return []
        response = requests.post(
            self.url,
            json=self._payload(texts),
            headers=self.headers,
            timeout=_XINFERENCE_TIMEOUT,
        )
        response.raise_for_status()
        return self._parse(response.json())

    async def _embed_async(self, texts: List[str]) -> List[List[float]]:
        """Async counterpart of :meth:`_embed` using ``httpx.AsyncClient``.

        Raises:
            httpx.HTTPStatusError: if the server answers with an error status.
        """
        if not texts:
            return []
        async with httpx.AsyncClient(timeout=_XINFERENCE_TIMEOUT) as client:
            response = await client.post(self.url, json=self._payload(texts), headers=self.headers)
            response.raise_for_status()
            return self._parse(response.json())

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed several documents; returns one vector per input text."""
        return self._embed(texts)

    async def embed_documents_async(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously embed several documents; returns one vector per input text."""
        return await self._embed_async(texts)

    def embed_query(self, text: str) -> List[float]:
        """Embed a single query string."""
        return self._embed([text])[0]

    async def embed_query_async(self, text: str) -> List[float]:
        """Asynchronously embed a single query string."""
        result = await self._embed_async([text])
        return result[0]


class XinferenceReRankerModel:
    """Client for the Xinference ``/v1/rerank`` endpoint.

    Both entry points return ``[]`` (and log the error) when the request fails.
    """

    @staticmethod
    def _build_request(
        query: str, documents: List[str], top_k: int, api_key: Optional[str]
    ) -> Tuple[str, Dict[str, Any], Dict[str, str]]:
        """Return ``(url, json_body, headers)`` for a rerank request."""
        rerank_url = _xinference_endpoint("v1/rerank")
        params = {
            "documents": documents,
            "query": query,
            "top_n": top_k,
            "return_documents": True,
            "model": os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3"),
        }
        headers = {
            # Empty token unless the caller supplies one.
            "Authorization": f"Bearer {api_key or ''}",
            "Content-Type": "application/json",
        }
        return rerank_url, params, headers

    @staticmethod
    def _parse_results(results: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Flatten the server response into ``{"document", "score", "index"}`` dicts."""
        return [
            {
                "document": item["document"]["text"],
                "score": item["relevance_score"],
                "index": item["index"],
            }
            for item in results["results"]
        ]

    @staticmethod
    def rerank(
        query: str, documents: List[str], top_k: int = 10, api_key: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Rerank ``documents`` by relevance to ``query``.

        Args:
            query: User query text.
            documents: Candidate documents to reorder.
            top_k: Maximum number of documents to return (sent as ``top_n``).
            api_key: Optional bearer token for the Xinference server.

        Returns:
            Up to ``top_k`` dicts with keys ``document`` (text), ``score``
            (relevance score) and ``index`` (position in ``documents``).
            Returns ``[]`` when ``documents`` is empty or the request fails.
        """
        if not documents:
            return []
        url, params, headers = XinferenceReRankerModel._build_request(query, documents, top_k, api_key)
        try:
            response = requests.post(url, json=params, headers=headers, timeout=_XINFERENCE_TIMEOUT)
            response.raise_for_status()
            results = response.json()
        except requests.exceptions.RequestException as e:
            logger.error("XinferenceReRankerModel重排序请求失败: %s", e)
            return []
        return XinferenceReRankerModel._parse_results(results)

    @staticmethod
    async def rerank_async(
        query: str, documents: List[str], top_k: int = 10, api_key: Optional[str] = None
    ) -> List[Dict[str, Any]]:
        """Async counterpart of :meth:`rerank`; same arguments and return value."""
        if not documents:
            return []
        url, params, headers = XinferenceReRankerModel._build_request(query, documents, top_k, api_key)
        try:
            async with httpx.AsyncClient(timeout=_XINFERENCE_TIMEOUT) as client:
                response = await client.post(url, json=params, headers=headers)
                response.raise_for_status()
                results = response.json()
        # httpx.HTTPError covers both transport errors (RequestError) and the
        # HTTPStatusError raised by raise_for_status(); ValueError covers an
        # invalid JSON body. This mirrors the sync version, which catches all
        # of these via requests' RequestException.
        except (httpx.HTTPError, ValueError) as e:
            logger.error("XinferenceReRankerModel异步重排序请求失败: %s", e)
            return []
        return XinferenceReRankerModel._parse_results(results)


class OpenAiLLM:
    """Chat-completions wrapper around the OpenAI SDK with simple linear-backoff retries.

    Constructor keyword arguments:
        api_key: Optional default key. When neither this nor a per-call key is
            given, keys come from ``APIKeyManager.get_api_key()``.
        base_url: API base URL; defaults to the ``OPENAI_API_BASE`` env var.
        model: Model name; defaults to the ``MODEL_NAME`` env var.
        Any other keyword is forwarded to ``client.chat.completions.create``
        on every call.
    """

    # Total attempts per call when need_retry is True.
    _MAX_RETRIES = 3
    # Request timeout (seconds) used when no "timeout" kwarg is supplied.
    _DEFAULT_TIMEOUT = 300.0

    def __init__(self, **kwargs):
        # Always defined so both invoke methods can fall back to it.
        self._api_key: Optional[str] = kwargs.pop("api_key", None)
        self._url = kwargs.pop("base_url", os.getenv("OPENAI_API_BASE"))
        self._model = kwargs.pop("model", os.getenv("MODEL_NAME"))
        self._kwargs = kwargs

    def _request_kwargs(self, extra_kwargs: Dict[str, Any]) -> Dict[str, Any]:
        """Merge per-call kwargs over the constructor kwargs without mutating either."""
        kwargs = {**self._kwargs, **extra_kwargs}
        kwargs.setdefault("timeout", httpx.Timeout(self._DEFAULT_TIMEOUT))
        return kwargs

    @staticmethod
    def _mask_key(api_key: Optional[str]) -> str:
        """Return a loggable form of ``api_key`` that exposes at most its last 4 characters."""
        if not api_key:
            return str(api_key)
        if len(api_key) <= 8:
            return "****"
        return "****" + api_key[-4:]

    def invoke(self, user_prompt="你是谁?", need_retry=True, api_key: Optional[str] = None, **extra_kwargs):
        """Send ``user_prompt`` as a single user message and return the reply.

        Args:
            user_prompt: Content of the user message.
            need_retry: If true, make up to ``_MAX_RETRIES`` attempts, sleeping
                5s, 10s, ... between them; otherwise make a single attempt.
            api_key: Key for this call; falls back to the constructor key, then
                to ``APIKeyManager.get_api_key()`` (fetched once per call).
            **extra_kwargs: Forwarded to ``chat.completions.create``; they override
                constructor kwargs for this call only.

        Returns:
            ``completion.choices[0].message`` from the SDK response.

        Raises:
            RuntimeError: after the final failed attempt (the cause is chained).
        """
        kwargs = self._request_kwargs(extra_kwargs)
        if api_key is None:
            api_key = self._api_key if self._api_key is not None else APIKeyManager.get_api_key()
        attempts = self._MAX_RETRIES if need_retry else 1

        for attempt in range(1, attempts + 1):
            try:
                # The context manager closes the client's HTTP resources.
                with OpenAI(api_key=api_key, base_url=self._url) as client:
                    completion = client.chat.completions.create(
                        model=self._model,
                        messages=[{'role': 'user', 'content': user_prompt}],
                        **kwargs,
                    )
                    return completion.choices[0].message
            except Exception as e:
                if attempt == attempts:
                    raise RuntimeError(
                        f"OpenAiLLM:invoke:error:{e}.api_key:{self._mask_key(api_key)}"
                    ) from e
                time.sleep(5 * attempt)  # Linear backoff: 5s, 10s, ...

    async def invoke_async(self, user_prompt="你是谁?", need_retry=True, api_key: Optional[str] = None, **extra_kwargs):
        """Async counterpart of :meth:`invoke`; same arguments, return value and errors.

        Difference in key handling: when neither ``api_key`` nor a constructor
        key is set, ``APIKeyManager.get_api_key()`` is called on every attempt,
        so a retry may use a different key if the manager rotates keys.
        """
        kwargs = self._request_kwargs(extra_kwargs)
        fixed_key = api_key if api_key is not None else self._api_key
        attempts = self._MAX_RETRIES if need_retry else 1

        for attempt in range(1, attempts + 1):
            # Bound before the try so the error message below never hits an
            # unbound name when get_api_key() itself raises.
            current_key = fixed_key
            try:
                if current_key is None:
                    current_key = APIKeyManager.get_api_key()
                async with AsyncOpenAI(api_key=current_key, base_url=self._url) as client:
                    completion = await client.chat.completions.create(
                        model=self._model,
                        messages=[{'role': 'user', 'content': user_prompt}],
                        **kwargs,
                    )
                    return completion.choices[0].message
            except Exception as e:
                if attempt == attempts:
                    raise RuntimeError(
                        f"OpenAiLLM:invoke_async:error:{e}.api_key:{self._mask_key(current_key)}"
                    ) from e
                await asyncio.sleep(5 * attempt)  # Non-blocking linear backoff.


if __name__ == "__main__":
    # Smoke test for the reranker.
    reranker = XinferenceReRankerModel()

    # Test case 1: simple question.
    query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?"
    documents = []
    results = reranker.rerank(query, documents)
    print(f"测试用例1 - 查询:{query}")
    for idx, item in enumerate(results, start=1):
        print(f"{idx}. 文档: {item['document']}, 分数: {item['score']}")
    print("-" * 50)

    async def test_async():
        """Exercise the async embedding, reranking and LLM entry points."""
        api_key = APIKeyManager.get_api_key()
        embeddings = XinferenceEmbeddings(api_key=api_key)
        query_embedding = await embeddings.embed_query_async("测试查询")
        print(f"异步嵌入向量维度: {len(query_embedding)}")

        results = await XinferenceReRankerModel.rerank_async(query, documents)
        print(f"异步重排序结果数量: {len(results)}")

        llm = OpenAiLLM()
        response = await llm.invoke_async("你好,请简单介绍一下自己")
        print(f"异步LLM响应: {response.content}")

    # To run the async smoke test, uncomment the line below
    # (asyncio is already imported at module level).
    # asyncio.run(test_async())