Files
QueryRewrite/rag2_0/dify/DifyCompareTest.py
T

267 lines
9.9 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import sys
import pandas as pd
# 使用线程池并发执行
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
import json
import re
from dotenv import load_dotenv
import logging
from datetime import datetime
import os
from langchain_core.output_parsers import JsonOutputParser
sys.path.append(os.getcwd())
from rag2_0.dify.dify_client import ChatClient
from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.dify.dify_tool import PgSql, DifyTool
from rag2_0.dify.export_new_dify import DifyExporter
load_dotenv()
# Create the log directory; exist_ok avoids the check-then-create race of exists()+makedirs().
log_dir = 'data/logs'
os.makedirs(log_dir, exist_ok=True)
# One log file per day (date only, no time): runs on the same day append to the same file.
log_file = os.path.join(log_dir, f'dify_compare_{datetime.now().strftime("%Y%m%d")}.log')
import logging  # already imported at the top of the file; re-import is harmless
# Configure the root logger to write both to the console and to the daily log file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),  # console output
        logging.FileHandler(log_file, encoding='utf-8'),  # file output
    ],
)
class DifyCompareTest:
    """Replay questions from an Excel sheet against a Dify chat app and save the answers.

    Each row's question (column "提问") is sent, together with the row's
    "当前软件" value, to the Dify app configured through environment variables.
    The answer, the Dify message id and the wiki entries retrieved for that
    message are written to a new Excel file.
    """

    # The app's answer may contain this separator; only the text before it is kept.
    ANSWER_SEPARATOR = "----------------------------------------"

    def __init__(self):
        """Create the Dify client, the judge LLM and the exporter from environment variables."""
        # NOTE(review): "DIFY_BSAE_URL" looks like a typo of "DIFY_BASE_URL". It is still
        # read first so existing .env files keep working; the correctly spelled name is
        # only used when the misspelled one is not set at all.
        base_url = os.getenv("DIFY_BSAE_URL")
        if base_url is None:
            base_url = os.getenv("DIFY_BASE_URL")
        # Client for the Dify app that searches wiki entries and work orders together.
        self.both_wiki_worker_client = ChatClient(api_key=os.getenv("DIFY_APP_KEY"), base_url=base_url)
        # LLM used by llm_judge_answer to compare two answers.
        self.llm = OpenAiLLM(base_url=os.getenv("OPENAI_API_BASE"), model=os.getenv("MODEL_NAME"))
        # Used for message debug-info lookup and wiki-list extraction (see get_wiki_list_by_msgid).
        self.exporter = DifyExporter()

    @staticmethod
    def _cell_or_empty(value):
        """Return ``value`` unchanged, or "" when it is an empty Excel cell (None/NaN)."""
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return value

    @staticmethod
    def _is_same_flag(result):
        """Read the "is_same" field of the LLM's JSON reply as a bool.

        A string such as "false" is truthy in Python, so string values are
        compared against "true" explicitly instead of being passed to bool().
        """
        is_same = result.get("is_same", False)
        if isinstance(is_same, str):
            return is_same.strip().lower() == "true"
        return bool(is_same)

    @staticmethod
    def _empty_result_row(row):
        """Copy ``row`` with the output columns set to "" (for skipped or failed rows)."""
        result_row = row.copy()
        result_row["message_id"] = ''
        result_row["回答"] = ''
        result_row["检索到的词条"] = ''
        return result_row

    def llm_judge_answer(self, old_answer: str, now_answer: str):
        """Ask the LLM whether two answers describe roughly the same content.

        Args:
            old_answer: Reference answer (text 1 in the prompt).
            now_answer: Newly produced answer (text 2 in the prompt).

        Returns:
            "回答基本相同" when the LLM reports ``is_same`` as true,
            "回答基本不相同" otherwise, or "" if all attempts failed.
        """
        import time

        user_prompt = f"""
请判断以下两个文本描述内容是否大致相同(内容主体等)
文本1
<text_one>
{old_answer}
</text_one>
=================
文本2
<text_two>
{now_answer}
</text_two>
输出格式(json格式输出):
{{
"is_same": true or false,
"reason": "文本1和文本2大致相同"
}}
"""
        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                response = self.llm.invoke(user_prompt=user_prompt, need_retry=False, response_format={"type": "json_object"})
                # Remove any <think>...</think> section before parsing the JSON reply.
                clean_output = re.sub(r'<think>.*?</think>', '', response.content.strip(), flags=re.DOTALL)
                result = JsonOutputParser().parse(clean_output)
                return "回答基本相同" if self._is_same_flag(result) else "回答基本不相同"
            except Exception as e:
                if attempt >= max_retries:
                    logging.error(f"LLM判断过程在尝试 {max_retries} 次后仍然出错: {e}")
                    return ""
                # Short pause so the retry is not fired immediately.
                time.sleep(1)

    def process_workflow(self, client, inputs, query, old_answer):
        """Send one question to a Dify chat app in blocking mode, retrying on failure.

        Args:
            client: Chat client whose ``create_chat_message`` is called.
            inputs: App input variables sent with the message.
            query: The question text.
            old_answer: Previous answer for the question. Currently unused because
                the LLM comparison (``llm_judge_answer``) is disabled.

        Returns:
            ``(answer, judge_result, message_id)``. ``answer`` is the stripped part
            of the reply before ``ANSWER_SEPARATOR``. ``judge_result`` is always ""
            while the comparison is disabled. ``message_id`` comes from the response.
            All three are "" when every attempt failed.
        """
        import time

        max_retries = 3
        retry_delay_seconds = 10
        for attempt in range(1, max_retries + 1):
            try:
                response = client.create_chat_message(
                    inputs=inputs, query=query, user="AutoCodeRun", response_mode="blocking"
                )
                result = response.json()
                answer = result.get('answer', "")
                if not answer:
                    # An empty (or missing) answer counts as a failure so the call is retried.
                    raise ValueError(f"回答为空: {result}")
                # The LLM comparison is disabled; to re-enable it, call
                # self.llm_judge_answer(old_answer=old_answer, now_answer=answer) when old_answer is set.
                judge_result = ""
                # Keep only the part of the answer before the separator.
                answer = answer.split(self.ANSWER_SEPARATOR)[0].strip()
                message_id = result.get('message_id', "")
                return answer, judge_result, message_id
            except Exception as e:
                if attempt >= max_retries:
                    logging.error(f"词条与工单同时检索调用失败 (尝试 {max_retries} 次后): {e}")
                    return '', '', ''
                # Wait 10 seconds before retrying.
                time.sleep(retry_delay_seconds)

    def get_wiki_list_by_msgid(self, msg_id):
        """Return the wiki entries retrieved for a Dify message, one per line.

        Returns "" when ``msg_id`` is missing (None, NaN, empty or blank string),
        when no debug info is found, or when no wiki entry was retrieved.
        Duplicates are removed, keeping first-seen order so the output is stable.
        """
        # process_workflow returns "" as the message id on failure: nothing to look up.
        if msg_id is None or (isinstance(msg_id, str) and not msg_id.strip()) or pd.isna(msg_id):
            return ""
        msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id)
        if not msg_debug_info:
            return ""
        wiki_list = self.exporter.get_wiki_list(msg_debug_info)
        if not wiki_list:
            return ""
        # dict.fromkeys de-duplicates while preserving order (set() order is arbitrary).
        return "\n".join(dict.fromkeys(wiki_list))

    def process_single_row(self, index, row):
        """Run one Excel row through the Dify app and build the corresponding output row.

        Args:
            index: Row label from ``DataFrame.iterrows``. It is returned unchanged
                and shown as ``index + 1`` in log messages.
            row: Row Series with "提问" and "当前软件" and, optionally, "回答".

        Returns:
            ``(index, result_row)``, where ``result_row`` is a copy of ``row`` with
            "message_id", "回答" and "检索到的词条" set. Those three are "" when the
            question is empty or an error occurs.
        """
        try:
            query = self._cell_or_empty(row["提问"])
            old_answer = self._cell_or_empty(row["回答"]) if "回答" in row else ""
            current_software = self._cell_or_empty(row["当前软件"])
            if not str(query).strip():
                # Nothing to ask: skip the call instead of spending every retry on it.
                logging.warning(f"第 {index + 1} 行提问为空, 跳过")
                return index, self._empty_result_row(row)
            # Empty cells were normalized to "" above; NaN is not valid JSON.
            inputs = {
                "current_softname": current_software,
                "user_name": "AutoCodeRun",
            }
            # Call the app that searches wiki entries and work orders together.
            answer, judge_result, message_id = self.process_workflow(
                self.both_wiki_worker_client,
                inputs,
                query,
                old_answer,
            )
            result_row = row.copy()
            result_row["message_id"] = message_id
            result_row["回答"] = answer
            result_row["检索到的词条"] = self.get_wiki_list_by_msgid(message_id)
            logging.info(f"成功处理第 {index + 1} 行数据")
            return index, result_row
        except Exception as e:
            logging.error(f"处理第 {index + 1} 行数据时出错: {e}")
            return index, self._empty_result_row(row)

    def run(self, excel_path, save_path, max_workers=3):
        """Process every row of an Excel file concurrently and save the results.

        Args:
            excel_path: Input Excel file; must contain the columns "提问" and "当前软件".
            save_path: Output Excel file; its parent directory is created if needed.
            max_workers: Maximum number of worker threads (default 3).

        If the input file or a required column is missing, the problem is logged
        and nothing is written. Any other error is logged and re-raised.
        """
        try:
            if not os.path.exists(excel_path):
                logging.error(f"Excel文件不存在: {excel_path}")
                return
            df = pd.read_excel(excel_path)
            logging.info(f"成功读取Excel文件: {excel_path}, 共 {len(df)} 行数据")

            required_columns = ["提问", "当前软件"]
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                logging.error(f"Excel文件缺少必要的列: {missing_columns}")
                return

            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            results = {}
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_index = {
                    executor.submit(self.process_single_row, index, row): index
                    for index, row in df.iterrows()
                }
                with tqdm(total=len(future_to_index), desc="处理进度") as pbar:
                    for future in as_completed(future_to_index):
                        original_index = future_to_index[future]
                        try:
                            index, result_row = future.result()
                        except Exception as e:
                            logging.error(f"线程执行失败 (行{original_index + 1}): {e}")
                            # Keep the unprocessed input row. The key is an index *label*
                            # from iterrows, so use .loc rather than positional .iloc.
                            index, result_row = original_index, df.loc[original_index].copy()
                        results[index] = result_row
                        pbar.update(1)

            # Rebuild the rows in the input file's order.
            rows_info = [results[label] for label in df.index]
            # With no rows, keep the input columns so the output still has a header.
            result_df = pd.DataFrame(rows_info) if rows_info else df.copy()
            result_df.to_excel(save_path, index=False)
            logging.info(f"结果已保存到: {save_path}")
        except Exception as e:
            logging.error(f"运行过程中出现错误: {e}")
            raise
if __name__ == "__main__":
    try:
        dify_compare_test = DifyCompareTest()
        # (input Excel, output Excel) pairs, processed in order.
        excel_files = [
            # ("data/excel/5月.xlsx", "data/excel/5月问答对比.xlsx"),
            ("data/excel/7.30数据导出.xlsx", "data/excel/7.30数据导出_问答测试.xlsx"),
        ]
        for excel_path, save_path in excel_files:
            logging.info(f"开始处理文件: {excel_path}")
            try:
                dify_compare_test.run(excel_path=excel_path, save_path=save_path, max_workers=3)
            except Exception as e:
                # Continue with the remaining files; logging.exception also records the traceback.
                logging.exception(f"处理文件 {excel_path} 时出错: {e}")
            else:
                logging.info(f"文件处理完成: {excel_path}")
        logging.info("所有文件处理完成")
    except Exception as e:
        # Startup failure (e.g. while constructing the clients): log with traceback and exit non-zero.
        logging.exception(f"程序执行出错: {e}")
        sys.exit(1)