上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: ModelTool.py
|
||||
Date: 2025-05-15
|
||||
Author: oyyz
|
||||
Description: 模型工具类
|
||||
"""
|
||||
|
||||
from openai import OpenAI
|
||||
import httpx
|
||||
import time
|
||||
import logging # 导入 logging 模块
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from typing import List, Any
|
||||
import requests
|
||||
import os
|
||||
import logging
|
||||
from .APIKeyManager import APIKeyManager
|
||||
|
||||
class SiliconFlowEmbeddings(Embeddings):
    """Embedding model wrapper that calls an HTTP embeddings endpoint.

    Each call POSTs a JSON payload containing ``model``, ``input`` and
    ``encoding_format`` to ``self.url`` and reads the vectors from the
    ``data[*].embedding`` fields of the JSON response.
    """

    # Endpoint used when the caller does not pass ``url`` explicitly.
    DEFAULT_URL = "http://10.1.16.39:9995/v1/embeddings"

    def __init__(self, api_key: str, model: str = "bge-m3",
                 url: str = DEFAULT_URL, timeout: float = 60.0):
        """Store the connection settings and pre-build the request headers.

        Args:
            api_key: Token sent as ``Authorization: Bearer <api_key>``.
            model: Model name placed in the request payload.
            url: Embeddings endpoint; defaults to ``DEFAULT_URL``.
            timeout: Seconds passed to ``requests.post`` so a stalled
                server cannot block the caller indefinitely.
        """
        self.api_key = api_key
        self.model = model
        self.url = url
        self.timeout = timeout
        self.headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

    def _embed(self, texts: List[str]) -> List[List[float]]:
        """Request embeddings for ``texts`` and return one vector per item.

        Returns an empty list without contacting the server when ``texts``
        is empty.

        Raises:
            requests.exceptions.RequestException: On connection errors,
                timeouts, or a non-2xx status (via ``raise_for_status``).
            KeyError: If the response lacks ``data`` / ``embedding`` fields.
        """
        if not texts:
            # Nothing to embed: skip the round trip entirely.
            return []

        payload = {
            "model": self.model,
            "input": texts,
            "encoding_format": "float",
        }
        response = requests.post(
            self.url,
            json=payload,
            headers=self.headers,
            timeout=self.timeout,
        )
        response.raise_for_status()
        data = response.json()
        return [item["embedding"] for item in data["data"]]

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return the embedding vectors for a list of documents."""
        return self._embed(texts)

    def embed_query(self, text: str) -> List[float]:
        """Return the embedding vector for a single query string."""
        return self._embed([text])[0]
|
||||
|
||||
class XinferenceReRankerModel:
    """Rerank model wrapper that calls an HTTP rerank endpoint."""

    @staticmethod
    def rerank(query: str, documents: List[str], top_k: int = 10,
               timeout: float = 60.0) -> List[dict]:
        """Re-order ``documents`` by relevance to ``query``.

        Args:
            query: The user query text.
            documents: Candidate documents to be re-ordered.
            top_k: Sent to the service as ``top_n`` (number of results
                requested).
            timeout: Seconds passed to ``requests.post`` so a stalled
                server cannot block the caller indefinitely.

        Returns:
            A list of dicts with keys ``document`` (the document text),
            ``score`` (the service's ``relevance_score``) and ``index``
            (the ``index`` field reported by the service). An empty list
            is returned when ``documents`` is empty, when the request
            fails, or when the response cannot be parsed.
        """
        if not documents:
            # Nothing to rank: avoid a pointless request and error log.
            return []

        url = "http://10.1.16.39:9995/v1/rerank"
        params = {
            "documents": documents,
            "query": query,
            "top_n": top_k,
            "return_documents": True,
            "model": os.getenv("RERANKER_MODEL_NAME"),
        }
        headers = {
            # NOTE(review): "<token>" is a literal placeholder carried over
            # from the original code; replace with a real credential if the
            # service enforces authentication.
            "Authorization": "Bearer <token>",
            "Content-Type": "application/json",
        }

        try:
            response = requests.post(
                url, json=params, headers=headers, timeout=timeout
            )
            response.raise_for_status()
            results = response.json()
            return [
                {
                    "document": item["document"]["text"],
                    "score": item["relevance_score"],
                    "index": item["index"],
                }
                for item in results["results"]
            ]
        except requests.exceptions.RequestException as e:
            logging.error(f"重排序请求失败: {str(e)}")
        except (KeyError, TypeError, ValueError) as e:
            # A 2xx response with an unexpected body should degrade the same
            # way as a failed request instead of crashing the caller.
            logging.error(f"重排序响应解析失败: {e!r}")
        return []
|
||||
|
||||
class OpenAiLLM:
    """Wrapper that sends a single user prompt to a chat-completions API."""

    def __init__(self, **kwargs):
        """Validate and store the model configuration.

        Args:
            **kwargs: Must contain non-None ``api_key``, ``base_url`` and
                ``model``. Every other keyword is kept and forwarded
                unchanged to ``client.chat.completions.create``.

        Raises:
            ValueError: If any of the three required keys is missing or None.
        """
        required = ("api_key", "base_url", "model")
        if any(kwargs.get(name) is None for name in required):
            raise ValueError("api_key, base_url, model 不能为空")

        # NOTE(review): invoke() obtains its key from APIKeyManager rather
        # than from self._api_key — confirm that this is intended.
        self._api_key = kwargs.pop("api_key")
        self._url = kwargs.pop("base_url")
        self._model = kwargs.pop("model")
        # Whatever remains is passed through as extra request options.
        self._kwargs = kwargs

    def _create_message(self, client, user_prompt):
        """Issue one chat-completion request and return the first message."""
        completion = client.chat.completions.create(
            model=self._model,
            messages=[{'role': 'user', 'content': user_prompt}],
            timeout=httpx.Timeout(300.0),  # 300-second request timeout
            **self._kwargs
        )
        return completion.choices[0].message

    def invoke(self, user_prompt="你是谁?", need_retry=True):
        """Send ``user_prompt`` to the model and return the reply message.

        Args:
            user_prompt: Text sent as the single ``user`` message.
            need_retry: When True, any exception triggers a retry, up to
                3 attempts in total, sleeping 5s then 10s between
                attempts. When False, one attempt is made and any
                exception propagates to the caller.

        Returns:
            The first choice's message object on success. With
            ``need_retry=True``, the empty string ``""`` after the final
            failed attempt.
        """
        api_key = APIKeyManager.get_api_key()
        client = OpenAI(api_key=api_key, base_url=self._url)

        if not need_retry:
            return self._create_message(client, user_prompt)

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                return self._create_message(client, user_prompt)
            except Exception as e:
                if attempt == max_retries:
                    logging.error(f"LLM 重试{max_retries}次后仍然失败: {e}")
                    return ""
                # Linear backoff: wait 5s after the first failure, 10s
                # after the second.
                time.sleep(5 * attempt)
|
||||
|
||||
if __name__ == "__main__":
    # Manual smoke test: rerank three sample documents against one query
    # and print whatever the rerank service returns.
    sample_query = "什么是AI"
    sample_documents = ["AI是人工智能", "AI是机器学习", "AI是深度学习"]
    print(XinferenceReRankerModel.rerank(sample_query, sample_documents))
|
||||
Reference in New Issue
Block a user