198 lines
11 KiB
Python
198 lines
11 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: ModelTool.py
|
|
Date: 2025-05-15
|
|
Author: oyyz
|
|
Description: 模型工具类
|
|
"""
|
|
|
|
from openai import OpenAI
|
|
import httpx
|
|
import time
|
|
import logging # 导入 logging 模块
|
|
from langchain.embeddings.base import Embeddings
|
|
from typing import List, Any
|
|
import requests
|
|
import os
|
|
import logging
|
|
from rag2_0.tool.APIKeyManager import APIKeyManager
|
|
|
|
class SiliconFlowEmbeddings(Embeddings):
|
|
"""SiliconFlow嵌入模型封装"""
|
|
def __init__(self, api_key: str, model: str = "bge-m3"):
|
|
self.api_key = api_key
|
|
self.model = model
|
|
self.url = "http://10.1.16.39:9995/v1/embeddings"
|
|
self.headers = {
|
|
"Authorization": f"Bearer {self.api_key}",
|
|
"Content-Type": "application/json"
|
|
}
|
|
|
|
def _embed(self, input: List[str]) -> List[List[float]]:
|
|
payload = {
|
|
"model": self.model,
|
|
"input": input,
|
|
"encoding_format": "float"
|
|
}
|
|
response = requests.post(self.url, json=payload, headers=self.headers)
|
|
response.raise_for_status()
|
|
data = response.json()
|
|
return [item["embedding"] for item in data["data"]]
|
|
|
|
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
|
return self._embed(texts)
|
|
|
|
def embed_query(self, text: str) -> List[float]:
|
|
return self._embed([text])[0]
|
|
|
|
class SiliconFlowReRankerModel:
|
|
@staticmethod
|
|
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
|
"""
|
|
使用硅流重排模型对文档进行重新排序
|
|
|
|
Args:
|
|
query: 用户查询文本
|
|
documents: 需要重新排序的文档列表
|
|
top_k: 返回排序后的前k个文档
|
|
|
|
Returns:
|
|
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
|
"""
|
|
url = "https://api.siliconflow.cn/v1/rerank"
|
|
payload = {
|
|
"model": "BAAI/bge-reranker-v2-m3",
|
|
"query": query,
|
|
"documents": documents,
|
|
"top_n": top_k,
|
|
"max_chunks_per_doc": 1024,
|
|
"overlap_tokens": 80,
|
|
"return_documents": True
|
|
}
|
|
api_key = APIKeyManager.get_api_key()
|
|
headers = {
|
|
"Authorization": f"Bearer {api_key}",
|
|
"Content-Type": "application/json"
|
|
}
|
|
try:
|
|
response = requests.post(url, json=payload, headers=headers)
|
|
response.raise_for_status()
|
|
results = response.json()
|
|
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
|
except requests.exceptions.RequestException as e:
|
|
logging.error(f"重排序请求失败: {str(e)}")
|
|
return []
|
|
|
|
class XinferenceReRankerModel:
|
|
"""重排模型封装"""
|
|
|
|
@staticmethod
|
|
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
|
"""
|
|
使用重排序模型对文档进行重新排序
|
|
|
|
Args:
|
|
query: 用户查询文本
|
|
documents: 需要重新排序的文档列表
|
|
top_k: 返回排序后的前k个文档
|
|
|
|
Returns:
|
|
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
|
"""
|
|
url = "http://172.20.0.145:9995/v1/rerank"
|
|
|
|
|
|
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
|
|
headers = {
|
|
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
|
"Content-Type": "application/json"
|
|
}
|
|
|
|
try:
|
|
response = requests.post(url, json=params, headers=headers)
|
|
response.raise_for_status() # 检查响应状态
|
|
results = response.json()
|
|
|
|
# 返回重排序后的文档列表
|
|
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
|
|
|
except requests.exceptions.RequestException as e:
|
|
logging.error(f"XinferenceReRankerModel重排序请求失败: {str(e)}")
|
|
return []
|
|
|
|
|
|
|
|
class OpenAiLLM:
|
|
|
|
def __init__(self, **kwargs):
|
|
if kwargs.get("api_key") == None or kwargs.get("base_url") == None or kwargs.get("model") == None:
|
|
raise ValueError("api_key, base_url, model 不能为空")
|
|
|
|
self._api_key = kwargs.get("api_key")
|
|
self._url = kwargs.get("base_url")
|
|
self._model = kwargs.get("model")
|
|
|
|
kwargs.pop("api_key")
|
|
kwargs.pop("base_url")
|
|
kwargs.pop("model")
|
|
self._kwargs = kwargs
|
|
|
|
def invoke(self, user_prompt="你是谁?", need_retry=True):
|
|
# 初始化 OpenAI 客户端
|
|
|
|
|
|
max_retries = 3
|
|
retry_count = 0
|
|
|
|
if need_retry:
|
|
while retry_count < max_retries:
|
|
try:
|
|
api_key = APIKeyManager.get_api_key()
|
|
client = OpenAI(api_key=api_key, base_url=self._url)
|
|
# 创建 Completion 请求. 超时120s
|
|
completion = client.chat.completions.create(
|
|
model=self._model,
|
|
messages=[{'role': 'user', 'content': user_prompt}],
|
|
timeout=httpx.Timeout(300.0),
|
|
**self._kwargs
|
|
)
|
|
return completion.choices[0].message
|
|
|
|
except Exception as e:
|
|
retry_count += 1
|
|
if retry_count == max_retries:
|
|
logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}")
|
|
raise e
|
|
else:
|
|
time.sleep(5*retry_count) # 重试前等待1秒
|
|
else:
|
|
# 创建 Completion 请求. 超时120s
|
|
api_key = APIKeyManager.get_api_key()
|
|
client = OpenAI(api_key=api_key, base_url=self._url)
|
|
completion = client.chat.completions.create(
|
|
model=self._model,
|
|
messages=[{'role': 'user', 'content': user_prompt}],
|
|
timeout=httpx.Timeout(300.0),
|
|
**self._kwargs
|
|
)
|
|
return completion.choices[0].message
|
|
|
|
if __name__ == "__main__":
|
|
# 测试重排模型
|
|
reranker = SiliconFlowReRankerModel()
|
|
|
|
# 测试用例1:简单问题
|
|
query = "如何通过【电力经济评价软件】的【打开】功能加载工程文件?"
|
|
documents = ["\n# (电力建设计价通软件) (概预算工程)工程备份管理\n## 操作步骤\n**方法一:** \n\n1、查找工程:输入工程文件名称的关键字,点击“查找”按钮,可以快速定位需查找的工程;\n\n\n\n2、根据时间点找备份工程:选中对应工程文件,在右侧选中“备份时间”的备份记录,点击“还原工程”或者“另存为工程”;\n\n **还原工程:** 将工程还原保存在原路径下;\n\n **另存为工程:** 另存为一个新工程,可选择保存路径,保存后,可点击文件——打开,浏览到另存的新工程打开。\n\n注:不确定备份是否是需要时,优先建议另存为工程。\n\n\n\n **方法二:** \n\n1、点击桌面软件快捷图标,右键属性—打开文件位置,直接定位软件安装根目录。 \n\n\n\n2、在软件安装根目录,点击“数据备份”文件夹,进入到文件夹内,根据修改日期找到对应工程,右键复制粘贴至桌面。\n\n\n\n\n\n3、定位桌面复制粘贴出来的数据工程,右键\"重命名\",将bak修改成相应的文件后缀(概预算工程及施工图预算工程后缀为zwzj,招标工程及投标工程后缀为zwqd),然后点击“确定”,再通过软件的“文件”——“打开”按钮去浏览工程打开。\n",
|
|
"\n# (配网计价通D3)插件管理/全国版和专版切换\n## 使用场景\n1.打开软件提示“当前工程文件为全国版文件,请使用全国版软件打开!”,该如何打开这个工程呢?\n\n\n\n2.打开软件提示“当前工程文件为辽宁版文件,请确认是否要在全国版软件中打开?”,这是什么意思?点击“确定”又可以打开工程?\n\n\n## 知识原理\n\n## 费用去向\n\n",
|
|
"\n(电力建设计价通软件) 云造价--停用\n# 工程文件管理\n\n## 【主页】中点击“云端工程管理”,进入博微服务大厅;\n\n## 工程文件管理界面中显示云端备份的工程列表,可支持\n\n## 高级设置:可对历史版本数量进行设置,默认数量为10,可设置(5-15);\n\n## 历史版本:勾选单个工程,点击“历史版本”可查看该工程保存的不同时间节点的历史工程;\n\n## 在线查阅:可查看工程数据,仅为只读模式不支持任何编辑;\n\n## 下载;选择需要的工程点击“下载”,可下载软件版本工程;\n\n",
|
|
"\n(配网D3软件)打开工程\n\n# (配网D3软件)打开工程\n\n## 功能入口\n各界面点击“文件”按钮——“打开”按钮 \n\n\n## 操作步骤\n**打开工程:** \n\n点击“打开”按钮,浏览到工程存放位置,选中工程文件,点击“打开”即可。"]
|
|
results = reranker.rerank(query, documents)
|
|
print(f"测试用例1 - 查询:{query}")
|
|
for idx, item in enumerate(results):
|
|
print(f"{idx+1}. 文档: {item['document']}, 分数: {item['score']}")
|
|
print("-" * 50)
|
|
|
|
|