refactor(embedding/reranker): 重构模型工具类使用环境变量配置
更新.gitignore文件,添加新的数据库文件;在.env中添加EMBEDDING_MODEL_NAME和XINFERENCE_URL配置;重构SiliconFlowEmbeddings和XinferenceReRankerModel,改为使用环境变量;优化DifyCompareTest的异常处理和输入验证;修改测试文件路径和并发工作数。
This commit is contained in:
@@ -3,6 +3,9 @@ OPENAI_API_BASE=https://api.siliconflow.cn/v1/
|
|||||||
MODEL_NAME=deepseek-ai/DeepSeek-V3
|
MODEL_NAME=deepseek-ai/DeepSeek-V3
|
||||||
MINI_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct-128K
|
MINI_MODEL_NAME=Qwen/Qwen2.5-72B-Instruct-128K
|
||||||
RERANKER_MODEL_NAME=bge-reranker-v2-m3
|
RERANKER_MODEL_NAME=bge-reranker-v2-m3
|
||||||
|
EMBEDDING_MODEL_NAME=bge-m3
|
||||||
|
|
||||||
|
XINFERENCE_URL=http://10.1.16.39:9995
|
||||||
|
|
||||||
DIFY_BSAE_URL=http://10.1.16.39/v1
|
DIFY_BSAE_URL=http://10.1.16.39/v1
|
||||||
DIFY_APP_KEY=app-CPoOMaGDsLRPAe9TW7Xjhszy
|
DIFY_APP_KEY=app-CPoOMaGDsLRPAe9TW7Xjhszy
|
||||||
|
|||||||
+1
-1
@@ -8,10 +8,10 @@ data/excel/*
|
|||||||
rag2_0/demo/Test*
|
rag2_0/demo/Test*
|
||||||
rag2_0/demo/test*
|
rag2_0/demo/test*
|
||||||
data/excel/*.xlsx
|
data/excel/*.xlsx
|
||||||
rag2_0/demo/ProfessionalTermAnalyzer.py
|
|
||||||
data/logs/*
|
data/logs/*
|
||||||
rag2_0/dify/Test.py
|
rag2_0/dify/Test.py
|
||||||
data/query_logs/*
|
data/query_logs/*
|
||||||
data/conversations/*
|
data/conversations/*
|
||||||
data/test*
|
data/test*
|
||||||
data/temp*
|
data/temp*
|
||||||
|
data/db/answer_logs.db
|
||||||
|
|||||||
@@ -121,6 +121,7 @@ class DifyCompareTest:
|
|||||||
time.sleep(10) # 等待10秒后重试
|
time.sleep(10) # 等待10秒后重试
|
||||||
|
|
||||||
def get_wiki_list_by_msgid(self,msg_id):
|
def get_wiki_list_by_msgid(self,msg_id):
|
||||||
|
try:
|
||||||
if msg_id is None or pd.isna(msg_id):
|
if msg_id is None or pd.isna(msg_id):
|
||||||
return ""
|
return ""
|
||||||
msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id)
|
msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id)
|
||||||
@@ -131,16 +132,27 @@ class DifyCompareTest:
|
|||||||
return ""
|
return ""
|
||||||
else:
|
else:
|
||||||
return "\n".join(list(set(wiki_list)))
|
return "\n".join(list(set(wiki_list)))
|
||||||
|
except Exception as e:
|
||||||
|
logging.error(f"获取词条列表失败: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
def process_single_row(self, index, row):
|
def process_single_row(self, index, row):
|
||||||
"""处理单行数据的方法"""
|
"""处理单行数据的方法"""
|
||||||
try:
|
try:
|
||||||
query = row["提问"]
|
query = row["提问"]
|
||||||
|
current_software = row["当前软件"]
|
||||||
|
if pd.isna(query) or len(query) == 0 or pd.isna(current_software) or len(current_software) == 0:
|
||||||
|
result_row = row.copy()
|
||||||
|
result_row["message_id"] = ''
|
||||||
|
result_row["本次回答"] = ''
|
||||||
|
result_row["回答对比"] = ''
|
||||||
|
result_row["检索到的词条"] = ''
|
||||||
|
return index, result_row
|
||||||
|
|
||||||
if "回答" in row:
|
if "回答" in row:
|
||||||
old_answer = row["回答"]
|
old_answer = row["回答"]
|
||||||
else:
|
else:
|
||||||
old_answer = ""
|
old_answer = ""
|
||||||
current_software = row["当前软件"]
|
|
||||||
|
|
||||||
inputs = {
|
inputs = {
|
||||||
"current_softname": current_software,
|
"current_softname": current_software,
|
||||||
@@ -247,14 +259,14 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# 处理第一个文件
|
# 处理第一个文件
|
||||||
excel_files = [
|
excel_files = [
|
||||||
("data/excel/300专业提问.xlsx", "data/excel/300专业提问_问答测试.xlsx"),
|
("data/excel/第一轮的专业问题.xlsx", "data/excel/第一轮的专业问题_dify.xlsx"),
|
||||||
# ("data/excel/有知识的.xlsx", "data/excel/有知识的_问答测试.xlsx")
|
# ("data/excel/有知识的.xlsx", "data/excel/有知识的_问答测试.xlsx")
|
||||||
]
|
]
|
||||||
|
|
||||||
for excel_path, save_path in excel_files:
|
for excel_path, save_path in excel_files:
|
||||||
logging.info(f"开始处理文件: {excel_path}")
|
logging.info(f"开始处理文件: {excel_path}")
|
||||||
try:
|
try:
|
||||||
dify_compare_test.run(excel_path=excel_path, save_path=save_path, max_workers=10)
|
dify_compare_test.run(excel_path=excel_path, save_path=save_path, max_workers=5)
|
||||||
logging.info(f"文件处理完成: {excel_path}")
|
logging.info(f"文件处理完成: {excel_path}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"处理文件 {excel_path} 时出错: {e}")
|
logging.error(f"处理文件 {excel_path} 时出错: {e}")
|
||||||
|
|||||||
+13
-86
@@ -23,10 +23,11 @@ from rag2_0.tool.APIKeyManager import APIKeyManager
|
|||||||
|
|
||||||
class SiliconFlowEmbeddings(Embeddings):
|
class SiliconFlowEmbeddings(Embeddings):
|
||||||
"""SiliconFlow嵌入模型封装"""
|
"""SiliconFlow嵌入模型封装"""
|
||||||
def __init__(self, api_key: str, model: str = "bge-m3"):
|
def __init__(self, api_key: str, model: str = os.getenv("EMBEDDING_MODEL_NAME", "bge-m3")):
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
self.model = model
|
self.model = model
|
||||||
self.url = "http://10.1.16.39:9995/v1/embeddings"
|
base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
|
||||||
|
self.url = urljoin(base_url.rstrip('/') + '/', 'v1/embeddings')
|
||||||
self.headers = {
|
self.headers = {
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
@@ -71,82 +72,6 @@ class SiliconFlowEmbeddings(Embeddings):
|
|||||||
result = await self._embed_async([text])
|
result = await self._embed_async([text])
|
||||||
return result[0]
|
return result[0]
|
||||||
|
|
||||||
class SiliconFlowReRankerModel:
|
|
||||||
@staticmethod
|
|
||||||
def rerank(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
|
||||||
"""
|
|
||||||
使用硅流重排模型对文档进行重新排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询文本
|
|
||||||
documents: 需要重新排序的文档列表
|
|
||||||
top_k: 返回排序后的前k个文档
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
|
||||||
"""
|
|
||||||
url = "https://api.siliconflow.cn/v1/rerank"
|
|
||||||
payload = {
|
|
||||||
"model": "BAAI/bge-reranker-v2-m3",
|
|
||||||
"query": query,
|
|
||||||
"documents": documents,
|
|
||||||
"top_n": top_k,
|
|
||||||
"max_chunks_per_doc": 1024,
|
|
||||||
"overlap_tokens": 80,
|
|
||||||
"return_documents": True
|
|
||||||
}
|
|
||||||
api_key = APIKeyManager.get_api_key()
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
response = requests.post(url, json=payload, headers=headers, timeout=300)
|
|
||||||
response.raise_for_status()
|
|
||||||
results = response.json()
|
|
||||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
|
||||||
except requests.exceptions.RequestException as e:
|
|
||||||
logging.error(f"重排序请求失败: {str(e)}", exc_info=True)
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def rerank_async(query: str, documents: List[str], top_k: int = 10) -> List[str]:
|
|
||||||
"""
|
|
||||||
使用硅流重排模型对文档进行异步重新排序
|
|
||||||
|
|
||||||
Args:
|
|
||||||
query: 用户查询文本
|
|
||||||
documents: 需要重新排序的文档列表
|
|
||||||
top_k: 返回排序后的前k个文档
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
|
||||||
"""
|
|
||||||
url = "https://api.siliconflow.cn/v1/rerank"
|
|
||||||
payload = {
|
|
||||||
"model": "BAAI/bge-reranker-v2-m3",
|
|
||||||
"query": query,
|
|
||||||
"documents": documents,
|
|
||||||
"top_n": top_k,
|
|
||||||
"max_chunks_per_doc": 1024,
|
|
||||||
"overlap_tokens": 80,
|
|
||||||
"return_documents": True
|
|
||||||
}
|
|
||||||
api_key = APIKeyManager.get_api_key()
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=300) as client:
|
|
||||||
response = await client.post(url, json=payload, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
results = response.json()
|
|
||||||
return [{"document": item["document"]["text"], "score": item["relevance_score"], "index": item["index"]} for item in results["results"]]
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logging.error(f"异步重排序请求失败: {str(e)}", exc_info=True)
|
|
||||||
return []
|
|
||||||
|
|
||||||
class XinferenceReRankerModel:
|
class XinferenceReRankerModel:
|
||||||
"""重排模型封装"""
|
"""重排模型封装"""
|
||||||
|
|
||||||
@@ -163,17 +88,18 @@ class XinferenceReRankerModel:
|
|||||||
Returns:
|
Returns:
|
||||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
||||||
"""
|
"""
|
||||||
url = "http://10.1.16.39:9995/v1/rerank"
|
|
||||||
|
|
||||||
|
base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
|
||||||
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
|
model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3")
|
||||||
|
rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank')
|
||||||
|
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = requests.post(url, json=params, headers=headers, timeout=300)
|
response = requests.post(rerank_url, json=params, headers=headers, timeout=300)
|
||||||
response.raise_for_status() # 检查响应状态
|
response.raise_for_status() # 检查响应状态
|
||||||
results = response.json()
|
results = response.json()
|
||||||
|
|
||||||
@@ -197,9 +123,10 @@ class XinferenceReRankerModel:
|
|||||||
Returns:
|
Returns:
|
||||||
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
List[dict]: 重排序后的文档列表,每个元素包含document内容、相关性分数和原始索引
|
||||||
"""
|
"""
|
||||||
url = "http://10.1.16.39:9995/v1/rerank"
|
base_url = os.getenv("XINFERENCE_URL", "http://10.1.16.39:9995")
|
||||||
|
rerank_url = urljoin(base_url.rstrip('/') + '/', 'v1/rerank')
|
||||||
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": "bge-reranker-v2-m3"}
|
model_name = os.getenv("RERANKER_MODEL_NAME", "bge-reranker-v2-m3")
|
||||||
|
params = {"documents": documents, "query": query, "top_n": top_k, "return_documents": True, "model": model_name}
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
"Authorization": "Bearer <token>", # 这里需要替换为实际的token
|
||||||
"Content-Type": "application/json"
|
"Content-Type": "application/json"
|
||||||
@@ -207,7 +134,7 @@ class XinferenceReRankerModel:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=300) as client:
|
async with httpx.AsyncClient(timeout=300) as client:
|
||||||
response = await client.post(url, json=params, headers=headers)
|
response = await client.post(rerank_url, json=params, headers=headers)
|
||||||
response.raise_for_status() # 检查响应状态
|
response.raise_for_status() # 检查响应状态
|
||||||
results = response.json()
|
results = response.json()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user