280 lines
11 KiB
Python
Executable File
280 lines
11 KiB
Python
Executable File
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import sys
|
||
import pandas as pd
|
||
# 使用线程池并发执行
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from tqdm import tqdm
|
||
import json
|
||
import re
|
||
from dotenv import load_dotenv
|
||
import logging
|
||
from datetime import datetime
|
||
import os
|
||
from langchain_core.output_parsers import JsonOutputParser
|
||
|
||
sys.path.append(os.getcwd())
|
||
from rag2_0.dify.dify_client import ChatClient
|
||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||
from rag2_0.dify.dify_tool import DifyTool
|
||
from rag2_0.dify.export_new_dify import DifyExporter
|
||
# Load environment variables (API keys, base URLs, model name) from a .env file.
load_dotenv()

# Create the log directory; exist_ok avoids the check-then-create race and
# the error raised when the directory already exists.
log_dir = 'data/logs'
os.makedirs(log_dir, exist_ok=True)

# One log file per day, e.g. data/logs/dify_compare_20240101.log
log_file = os.path.join(log_dir, f'dify_compare_{datetime.now().strftime("%Y%m%d")}.log')

# Configure the root logger to write to both the console and the daily file.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),  # console output
        logging.FileHandler(log_file, encoding='utf-8'),  # file output
    ],
)
|
||
|
||
class DifyCompareTest:
    """Re-run questions from an Excel sheet through a Dify chat app and compare answers.

    For every row of the input sheet (required columns "提问" = question and
    "当前软件" = current software, optional "参考回答" = reference answer), the
    question is sent to the Dify app. The new answer is compared with the
    reference answer by an LLM, and the entries retrieved for the resulting
    message are looked up. The augmented rows are written to a new Excel file.
    """

    # Columns added to every output row, in the order they are filled in.
    _RESULT_COLUMNS = ("message_id", "本次回答", "回答对比", "检索到的词条")
    # Only the part of a Dify answer before this separator is kept.
    _ANSWER_SEPARATOR = "----------------------------------------"
    # Strip any <think>...</think> block the model emits before parsing the JSON.
    _THINK_PATTERN = re.compile(r'<think>.*?</think>', flags=re.DOTALL)

    _LLM_MAX_RETRIES = 3
    _LLM_RETRY_DELAY_SECONDS = 1
    _WORKFLOW_MAX_RETRIES = 3
    _WORKFLOW_RETRY_DELAY_SECONDS = 10

    def __init__(self):
        # Dify chat app that searches wiki entries and work orders together
        # ("词条与工单同时检索").
        # NOTE(review): "DIFY_BSAE_URL" looks like a misspelling of
        # "DIFY_BASE_URL"; kept as-is because existing .env files may use it.
        self.both_wiki_worker_client = ChatClient(api_key=os.getenv("DIFY_APP_KEY"), base_url=os.getenv("DIFY_BSAE_URL"))
        # LLM used to judge whether two answers are roughly the same.
        self.llm = OpenAiLLM(base_url=os.getenv("OPENAI_API_BASE"), model=os.getenv("MODEL_NAME"))
        # Used to fetch message debug info and extract the retrieved entries.
        self.exporter = DifyExporter()

    @staticmethod
    def _clean_cell(value) -> str:
        """Return *value* as a stripped string, or "" for None / NaN-like cells."""
        if value is None:
            return ""
        try:
            if pd.isna(value):
                return ""
        except (TypeError, ValueError):
            # pd.isna() returns an array for list-like input; treat it as present.
            pass
        return str(value).strip()

    @staticmethod
    def _parse_is_same(value) -> bool:
        """Interpret the LLM's "is_same" field, accepting JSON booleans or strings.

        A plain truthiness test would treat the string "false" as True, so
        string values are matched explicitly.
        """
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1", "是")
        return bool(value)

    @classmethod
    def _empty_result_row(cls, row):
        """Return a copy of *row* with every result column set to ''."""
        result_row = row.copy()
        for column in cls._RESULT_COLUMNS:
            result_row[column] = ''
        return result_row

    def llm_judge_answer(self, old_answer: str, now_answer: str):
        """Ask the LLM whether two answers describe roughly the same content.

        Args:
            old_answer: Reference answer (text 1).
            now_answer: Newly generated answer (text 2).

        Returns:
            "回答基本相同" when the LLM says they match; "回答基本不相同" when it
            says they differ or either answer is empty/None/NaN; "" when the
            LLM call still fails after all retries (the error is logged).
        """
        import time

        # Nothing to compare: an empty side is reported as "different".
        if not self._clean_cell(old_answer) or not self._clean_cell(now_answer):
            return "回答基本不相同"

        user_prompt = f"""
请判断以下两个文本描述内容是否大致相同(内容主体等)
文本1:
<text_one>
{old_answer}
</text_one>
=================
文本2:
<text_two>
{now_answer}
</text_two>
输出格式(json格式输出):
{{
"is_same": true or false,
"reason": "文本1和文本2大致相同"
}}
"""

        max_retries = self._LLM_MAX_RETRIES
        for attempt in range(1, max_retries + 1):
            try:
                response = self.llm.invoke(user_prompt=user_prompt, response_format={"type": "json_object"})
                clean_output = self._THINK_PATTERN.sub('', response.content.strip())
                result = JsonOutputParser().parse(clean_output)
                if not isinstance(result, dict):
                    raise ValueError(f"LLM返回的不是JSON对象: {result!r}")
                return "回答基本相同" if self._parse_is_same(result.get("is_same", False)) else "回答基本不相同"
            except Exception as e:
                if attempt >= max_retries:
                    logging.error(f"LLM判断过程在尝试 {max_retries} 次后仍然出错: {e}")
                    return ""
                logging.warning(f"LLM判断第 {attempt} 次尝试失败, 即将重试: {e}")
                # Brief pause so we don't hammer the LLM endpoint immediately.
                time.sleep(self._LLM_RETRY_DELAY_SECONDS)
        return ""

    def process_workflow(self, client, inputs, query, old_answer):
        """Send *query* to the Dify chat app and optionally judge the answer.

        The Dify call is retried up to ``_WORKFLOW_MAX_RETRIES`` times, waiting
        ``_WORKFLOW_RETRY_DELAY_SECONDS`` between attempts. An empty answer
        counts as a failure.

        Args:
            client: Dify ChatClient used for the request.
            inputs: Dify app input variables.
            query: The question text.
            old_answer: Reference answer; the LLM comparison runs only when it
                is a non-blank string.

        Returns:
            ``(answer, judge_result, message_id)``; ``('', '', '')`` when every
            attempt fails.
        """
        import time

        max_retries = self._WORKFLOW_MAX_RETRIES
        answer, message_id = "", ""
        for attempt in range(1, max_retries + 1):
            try:
                response = client.create_chat_message(
                    inputs=inputs, query=query, user="AutoCodeRun", response_mode="blocking"
                )
                result = response.json()
                # Keep only the part of the answer before the separator.
                answer = (result.get('answer') or "").split(self._ANSWER_SEPARATOR)[0].strip()
                if not answer:
                    raise ValueError(f"回答为空: {result}")
                message_id = result.get('message_id', "")
                break
            except Exception as e:
                if attempt >= max_retries:
                    logging.error(f"词条与工单同时检索调用失败 (尝试 {max_retries} 次后): {e}")
                    return '', '', ''
                time.sleep(self._WORKFLOW_RETRY_DELAY_SECONDS)

        # The judge runs once, outside the retry loop, so a judge problem never
        # triggers another Dify request.
        if isinstance(old_answer, str) and old_answer.strip():
            judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer)
        else:
            judge_result = ""
        return answer, judge_result, message_id

    def get_wiki_list_by_msgid(self, msg_id):
        """Return the entries retrieved for a Dify message, one per line.

        Duplicates are removed while keeping first-seen order, so the output is
        stable between runs. Returns "" for a None/NaN/empty id, when no debug
        info or entries are found, or on any error (which is logged).
        """
        try:
            # process_workflow returns '' on failure; don't query the API for it.
            if not self._clean_cell(msg_id):
                return ""
            msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id)
            if not msg_debug_info:
                return ""
            wiki_list = self.exporter.get_wiki_list(msg_debug_info)
            if not wiki_list:
                return ""
            # dict.fromkeys de-duplicates like set() but keeps insertion order.
            return "\n".join(dict.fromkeys(wiki_list))
        except Exception as e:
            logging.error(f"获取词条列表失败: {e}")
            return ""

    def process_single_row(self, index, row):
        """Process one spreadsheet row and return ``(index, result_row)``.

        *result_row* is a copy of *row* with the result columns added. Rows
        whose question or current-software cell is empty, and rows that raise,
        get '' in every result column.
        """
        try:
            # Normalize cells: NaN/None -> "", numbers -> str (len() on a
            # numeric cell used to raise and blank the row).
            query = self._clean_cell(row["提问"])
            current_software = self._clean_cell(row["当前软件"])
            if not query or not current_software:
                return index, self._empty_result_row(row)

            # The reference answer column is optional.
            old_answer = row.get("参考回答", "")

            inputs = {
                "current_softname": current_software,
                "user_name": "AutoCodeRun"
            }

            # Call the combined wiki + work-order retrieval app.
            answer, judge_result, message_id = self.process_workflow(
                self.both_wiki_worker_client,
                inputs,
                query,
                old_answer
            )

            result_row = row.copy()
            result_row["message_id"] = message_id
            result_row["本次回答"] = answer
            result_row["回答对比"] = judge_result
            result_row["检索到的词条"] = self.get_wiki_list_by_msgid(message_id)
            logging.info(f"成功处理第 {index + 1} 行数据")
            return index, result_row

        except Exception as e:
            logging.error(f"处理第 {index + 1} 行数据时出错: {e}")
            return index, self._empty_result_row(row)

    def run(self, excel_path, save_path, max_workers=3):
        """Run the comparison over an Excel file and save the augmented sheet.

        Args:
            excel_path: Path of the input Excel file.
            save_path: Path of the output Excel file; its directory is created
                if needed.
            max_workers: Maximum number of concurrent worker threads (default 3).

        Raises:
            Any unexpected error is logged and re-raised. A missing input file
            or missing required columns are only logged.
        """
        try:
            if not os.path.exists(excel_path):
                logging.error(f"Excel文件不存在: {excel_path}")
                return

            df = pd.read_excel(excel_path)
            logging.info(f"成功读取Excel文件: {excel_path}, 共 {len(df)} 行数据")

            required_columns = ["提问", "当前软件"]
            missing_columns = [col for col in required_columns if col not in df.columns]
            if missing_columns:
                logging.error(f"Excel文件缺少必要的列: {missing_columns}")
                return

            save_dir = os.path.dirname(save_path)
            if save_dir:
                os.makedirs(save_dir, exist_ok=True)

            # Results are keyed by row *position*. Output order and the failure
            # fallback therefore don't depend on the DataFrame's index labels
            # (the old code passed a label to df.iloc).
            results = {}
            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                future_to_pos = {
                    executor.submit(self.process_single_row, index, row): pos
                    for pos, (index, row) in enumerate(df.iterrows())
                }

                with tqdm(total=len(future_to_pos), desc="处理进度") as pbar:
                    for future in as_completed(future_to_pos):
                        pos = future_to_pos[future]
                        try:
                            _, result_row = future.result()
                        except Exception as e:
                            logging.error(f"线程执行失败 (行{pos + 1}): {e}")
                            # Keep the original row, with blank result columns.
                            result_row = self._empty_result_row(df.iloc[pos])
                        results[pos] = result_row
                        pbar.update(1)

            # Restore the original row order.
            rows_info = [results[pos] for pos in range(len(df))]

            result_df = pd.DataFrame(rows_info)
            result_df.to_excel(save_path, index=False)
            logging.info(f"结果已保存到: {save_path}")

        except Exception as e:
            logging.error(f"运行过程中出现错误: {e}")
            raise
|
||
|
||
if __name__ == "__main__":
    try:
        dify_compare_test = DifyCompareTest()

        # (input workbook, output workbook) pairs, processed in order.
        excel_files = [
            ("data/excel/第5轮-软件问题-点踩结果.xlsx", "data/excel/第5轮-软件问题-点踩结果_dify.xlsx"),
            # ("data/excel/有知识的.xlsx", "data/excel/有知识的_问答测试.xlsx")
        ]

        for excel_path, save_path in excel_files:
            logging.info(f"开始处理文件: {excel_path}")
            try:
                dify_compare_test.run(excel_path=excel_path, save_path=save_path, max_workers=5)
                logging.info(f"文件处理完成: {excel_path}")
            except Exception:
                # logging.exception keeps the traceback; continue with the next file.
                logging.exception(f"处理文件 {excel_path} 时出错")
                continue

        logging.info("所有文件处理完成")

    except Exception:
        # Fatal setup error (e.g. client construction): log with traceback and exit non-zero.
        logging.exception("程序执行出错")
        sys.exit(1)
|