762 lines
30 KiB
Python
Executable File
762 lines
30 KiB
Python
Executable File
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import os
|
||
import sys
|
||
import argparse
|
||
from threading import Lock
|
||
import pandas as pd
|
||
# 使用线程池并发执行
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from tqdm import tqdm
|
||
import json
|
||
from urllib.parse import unquote
|
||
from dotenv import load_dotenv
|
||
from pydantic import BaseModel, Field
|
||
from langchain.output_parsers import PydanticOutputParser
|
||
|
||
sys.path.append(os.getcwd())
|
||
from rag2_0.dify.dify_client import DifyClient
|
||
from rag2_0.dify.dify_tool import NewWorkflowChat, OldWorkFlowChat
|
||
from rag2_0.tool.WikijsTool import WikijsTool
|
||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||
from rag2_0.dify.dify_tool import DifyTool
|
||
|
||
load_dotenv()
|
||
|
||
import logging
|
||
# Configure the root logger once at import time: INFO level, timestamped
# records, emitted through a single stream handler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler()
    ]
)
|
||
|
||
class ContentSource(BaseModel):
    """Structured relevance verdict parsed from the LLM's JSON reply."""

    # Relevance score assigned by the LLM.
    score: int = Field(description="相关性分数")
    # The LLM's justification for the score.
    reason: str = Field(description="评分理由")
|
||
|
||
class DifyComparisonTester:
|
||
"""
|
||
Dify新旧流程对比测试类,用于比较两个不同流程的问答效果并进行评判
|
||
"""
|
||
def __init__(self, excel_path:str, baseurl:str, new_workflow_api_key:str,
             old_workflow_api_key:str=None, output_path:str=None, max_workers:int=1, mode:str="both"):
    """
    Initialise the comparison tester.

    Args:
        excel_path: Path of the Excel file that holds the questions.
        baseurl: Base URL of the Dify API.
        new_workflow_api_key: API key of the new workflow.
        old_workflow_api_key: API key of the old workflow; only used when mode == "both".
        output_path: Path of the result Excel file. Defaults to
            "dify问答_新流程结果.xlsx" next to ``excel_path``.
        max_workers: Maximum number of worker threads used by ``run_comparison``.
        mode: "new_only" tests the new workflow only, "both" tests new and old.
    """
    self.excel_path = excel_path
    self.mode = mode

    # Chat wrappers for the two workflows. The old one is only created when
    # it is both requested and a key was supplied; otherwise it stays None
    # and the tester behaves as in "new_only" mode.
    self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
    if mode == "both" and old_workflow_api_key:
        self.old_chat = OldWorkFlowChat(api_key=old_workflow_api_key, base_url=baseurl)
    else:
        self.old_chat = None

    # Judging-related settings.
    self.output_path = output_path or os.path.join(os.path.dirname(self.excel_path), "dify问答_新流程结果.xlsx")
    self.max_workers = max_workers
    self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
    self.results_lock = Lock()

    # Load the question sheet when it exists; None signals "no wiki sheet".
    if excel_path and os.path.exists(excel_path):
        self.wiki_excel = pd.read_excel(excel_path)
    else:
        self.wiki_excel = None

    self.dify_tool = DifyTool()
|
||
|
||
def get_llm(self, **kwargs):
    """
    Build an ``OpenAiLLM`` client from environment configuration.

    Reads OPENAI_API_KEY, OPENAI_API_BASE and MODEL_NAME from the
    environment; any extra keyword arguments are forwarded unchanged to
    the ``OpenAiLLM`` constructor.
    """
    return OpenAiLLM(
        api_key=os.getenv("OPENAI_API_KEY"),
        base_url=os.getenv("OPENAI_API_BASE"),
        model=os.getenv("MODEL_NAME"),
        **kwargs,
    )
|
||
|
||
def find_wiki_link(self, row) -> str | None:
|
||
"""
|
||
根据查询找出对应的词条链接
|
||
|
||
Args:
|
||
query (str): 查询内容
|
||
|
||
Returns:
|
||
str: 对应的词条链接,如果没有找到则返回None
|
||
"""
|
||
if self.wiki_excel is None:
|
||
return None
|
||
|
||
if "词条链接" in row:
|
||
return row["词条链接"]
|
||
return None
|
||
|
||
def get_wiki_content(self, link) -> str:
    """
    Fetch the wiki page behind ``link`` and return it as Markdown.

    Args:
        link: URL of the wiki entry.

    Returns:
        The page as Markdown, prefixed with a level-1 heading made of the
        last path segment. Returns "链接为空或无效" for an empty/NaN link and
        "获取内容失败" when the page has no content.

    Raises:
        RuntimeError: Wraps any exception raised while resolving or
            converting the page.
    """
    try:
        if not link or pd.isna(link):
            return "链接为空或无效"

        # Drop scheme and host: "http://host/a/b" -> "a/b".
        path = link.split('/', 3)[-1]
        path_parts = unquote(path).split('/')
        # The first path segment is skipped when building the document path
        # (presumably a locale prefix such as "zh" — TODO confirm).
        doc_path = "/".join(path_parts[1:])

        wiki_doc = WikijsTool.get_all_doc_by_path(path=doc_path, path_is_dir=False)
        html_content = WikijsTool.query_doc_info(wiki_doc[0]["id"]).get('content')
        if not html_content:
            return "获取内容失败"

        options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}

        # Demote every heading by one level so the title inserted below is
        # the only top-level heading. Highest level first, otherwise an
        # already-shifted tag would be shifted again.
        shifted = html_content
        for level in range(6, 0, -1):
            shifted = shifted.replace(f"h{level}>", f"h{level + 1}>")

        # Convert the HTML to Markdown and prepend the entry title.
        markdown_content = convert_html_to_md(shifted, "", **options)
        return f"# {path_parts[-1]}\n\n{markdown_content}"

    except Exception as e:
        raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||
|
||
def get_wiki_title(self, link) -> str | None:
|
||
"""
|
||
获取词条标题
|
||
|
||
Args:
|
||
link (str): 词条链接
|
||
|
||
Returns:
|
||
str: 词条标题,如果获取失败则返回None
|
||
"""
|
||
try:
|
||
if not link or pd.isna(link):
|
||
return None
|
||
# 移除域名部分,只保留路径
|
||
path = link.split('/', 3)[-1]
|
||
decoded_path = unquote(path)
|
||
path_parts = decoded_path.split('/')
|
||
return path_parts[-1]
|
||
|
||
except Exception as e:
|
||
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||
|
||
def create_correctness_prompt(self, standard_answer: str, answer_to_check: str) -> str:
    """
    Build the prompt used to judge an answer against the reference answer.

    The prompt asks for a JSON object with a boolean ``result`` and a
    ``reason`` string. The boolean is spelled with JSON literals
    (``true``/``false``) because the reply is parsed with ``json.loads``,
    which rejects Python-style ``True``/``False``.

    Args:
        standard_answer: Reference answer.
        answer_to_check: Answer to be evaluated.

    Returns:
        The formatted prompt text.
    """
    return f"""请作为一个电力造价行业的专家,评估以下回答与标准答案的匹配程度。

标准答案:
{standard_answer}

待评估的回答:
{answer_to_check}

要求
1、分析待评估的回答与标准答案的匹配程度(包括内容、步骤、主体等)
2、如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
3、如果待评估的回答存在明显的错误信息,应判定为"错误"。
4、请严格按json格式输出:
{{
"result": true 或 false,
"reason": "简明扼要的理由(中文)"
}}
字段说明:
result: JSON布尔值 true 或 false(小写),待评估的回答是否正确
reason: 简明扼要的理由(中文)
"""
|
||
|
||
def judge_answer(self, standard_answer: str, answer: str) -> bool | None:
|
||
"""
|
||
调用LLM判断回答是否正确
|
||
|
||
Args:
|
||
standard_answer (str): 标准答案(来自Wiki)
|
||
answer (str): 需评判的回答
|
||
|
||
Returns:
|
||
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||
"""
|
||
|
||
prompt = self.create_correctness_prompt(standard_answer, answer)
|
||
llm = self.get_llm(response_format={"type": "json_object"})
|
||
|
||
max_retries = 3
|
||
retry_count = 0
|
||
|
||
while retry_count < max_retries:
|
||
try:
|
||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||
response_json = json.loads(response.content)
|
||
return response_json["result"]
|
||
except Exception as e:
|
||
retry_count += 1
|
||
if retry_count >= max_retries:
|
||
logging.error(f"判断答案失败,已重试{max_retries}次: {str(e)}")
|
||
return False
|
||
# 指数退避策略,每次重试等待时间增加
|
||
import time
|
||
time.sleep(1 * (2 ** (retry_count - 1))) # 1秒, 2秒, 4秒...
|
||
|
||
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
|
||
"""
|
||
综合判断新旧回答的正确性
|
||
|
||
Args:
|
||
standard_answer (str): 标准答案(来自Wiki)
|
||
old_answer (str): 旧流程的回答
|
||
new_answer (str): 新流程的回答
|
||
|
||
Returns:
|
||
str | None: 包含新旧回答判断结果的字符串,None表示判断失败
|
||
"""
|
||
old_result = self.judge_answer(standard_answer, old_answer)
|
||
new_result = self.judge_answer(standard_answer, new_answer)
|
||
if old_result is None or new_result is None:
|
||
return None
|
||
if new_result and old_result:
|
||
return "新旧答案均正确"
|
||
elif new_result and not old_result:
|
||
return "新答案正确"
|
||
elif not new_result and old_result:
|
||
return "旧答案正确"
|
||
else:
|
||
return "新旧答案均错误"
|
||
|
||
def judge_answer_diff(self, old_answer: str, new_answer: str) -> str | None:
|
||
"""
|
||
判断新旧回答是否存在较大差异
|
||
|
||
Args:
|
||
old_answer (str): 旧流程的回答
|
||
new_answer (str): 新流程的回答
|
||
|
||
Returns:
|
||
str | None: 差异判断结果,None表示判断失败
|
||
"""
|
||
|
||
prompt = f"""请判断以下两个回答是否存在较大差异:
|
||
|
||
旧回答: {old_answer}
|
||
|
||
新回答: {new_answer}
|
||
|
||
主要是主要步骤、主要信息、或者主要主体的差异
|
||
请仅回答"存在较大差异"或"差异较小"。"""
|
||
llm = self.get_llm()
|
||
try:
|
||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||
return "缺乏标准答案无法判断准确性,但答案基本相同" if "差异较小" in response.content else "缺乏标准答案无法判断准确性,但答案差异较大"
|
||
except Exception as e:
|
||
return None
|
||
|
||
def calculate_score(self, query:str, content:str) -> int:
    """
    Ask the LLM to rate how relevant ``content`` is to ``query``.

    Args:
        query: The user's question.
        content: A retrieved passage.

    Returns:
        The score parsed from the LLM reply (the prompt asks for 1-10,
        10 = fully relevant), or -1 when the call or the parsing fails.
        Failures are logged instead of being swallowed silently.
    """
    try:
        prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。

【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突

【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息

【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}

【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}

现在评估:
query: "{query}"
content: "{content}"
"""
        llm = self.get_llm()
        response = llm.invoke(user_prompt=prompt, need_retry=True)
        # Parse the JSON reply into a ContentSource and keep only the score.
        return self.content_source_parser.parse(response.content).score
    except Exception as e:
        logging.warning(f"相关性评分失败: {str(e)}")
        return -1
|
||
|
||
def get_retrieve_info(self, query:str, outputs:dict) -> tuple:
    """
    Score every retrieved passage against the query and summarise the scores.

    Args:
        query: The user's question.
        outputs: Retrieval node output; ``outputs["result"]`` is a list of
            dicts that each carry a "content" string.

    Returns:
        tuple: (labelled titles, max score, min score, average score).
        Titles are the first line of each passage followed by its score and
        are listed in the same order as ``outputs["result"]``. Passages
        whose scoring failed (-1) are still listed but excluded from the
        statistics. When nothing was scored, all three statistics are 0.
    """
    contents = [item["content"].strip() for item in outputs["result"]]

    # Score concurrently, but read the futures back in submission order so
    # the output does not depend on which LLM call happens to finish first.
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(self.calculate_score, query=query, content=content)
            for content in contents
        ]
        scores = [future.result() for future in futures]

    retrieve_content = []
    for content, score in zip(contents, scores):
        content_title = content.split("\n")[0]
        if content_title:
            retrieve_content.append(content_title + f"--得分({score}分)")

    valid_scores = [score for score in scores if score != -1]
    if not valid_scores:
        return retrieve_content, 0, 0, 0

    # The 0 / 10 bounds mirror the original accumulator start values.
    max_score = max([0, *valid_scores])
    min_score = min([10, *valid_scores])
    avg_score = sum(valid_scores) / len(valid_scores)
    return retrieve_content, max_score, min_score, avg_score
|
||
|
||
def get_new_workflow_info(self, query:str, new_message_id:str) -> dict:
|
||
"""
|
||
获取新流程的问题分类
|
||
|
||
Args:
|
||
query (str): 用户问题
|
||
new_message_id (str): 新流程的消息ID
|
||
|
||
Returns:
|
||
dict: 包含问题分类结果的字典
|
||
"""
|
||
try:
|
||
# 使用DifyTool直接获取消息信息
|
||
new_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=new_message_id)
|
||
|
||
# 初始化变量
|
||
retrieve_title = []
|
||
retrieve_content = []
|
||
rewrite_query = ""
|
||
vertical_classification = ""
|
||
sub_classification = ""
|
||
slot_info = ""
|
||
|
||
# 解析工作流节点信息
|
||
for workflow_node in new_message_info["workflow_node_executions_info"]:
|
||
if workflow_node["title"] == "知识检索结果后处理":
|
||
outputs = json.loads(workflow_node["outputs"])
|
||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||
retrieve_content = outputs["result"]
|
||
elif workflow_node["title"] == "问题优化结果解析":
|
||
outputs = json.loads(workflow_node["outputs"])
|
||
rewrite_query = outputs["optimize_query"]
|
||
llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
|
||
json_result = json.loads(llm_result_json)
|
||
vertical_classification = json_result['vertical_classification']
|
||
sub_classification = json_result['sub_classification']
|
||
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
|
||
except Exception as e:
|
||
return None
|
||
|
||
return {
|
||
"问题改写": rewrite_query,
|
||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||
"问题分类": f"{vertical_classification} - {sub_classification}",
|
||
"槽点信息": slot_info
|
||
}
|
||
|
||
def get_old_workflow_info(self, query:str, old_message_id:str) -> dict:
|
||
"""
|
||
获取旧流程的问题分类
|
||
|
||
Args:
|
||
query (str): 用户问题
|
||
old_message_id (str): 旧的流程的消息ID
|
||
|
||
Returns:
|
||
dict: 包含问题分类结果的字典
|
||
"""
|
||
try:
|
||
# 使用DifyTool直接获取消息信息
|
||
old_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=old_message_id)
|
||
|
||
# 初始化变量
|
||
retrieve_title = []
|
||
retrieve_content = []
|
||
rewrite_query = ""
|
||
|
||
# 解析工作流节点信息
|
||
for workflow_node in old_message_info["workflow_node_executions_info"]:
|
||
if workflow_node["title"] == "知识检索结果后处理":
|
||
outputs = json.loads(workflow_node["outputs"])
|
||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||
retrieve_content = outputs["result"]
|
||
elif workflow_node["title"] == "问题优化结果解析":
|
||
outputs = json.loads(workflow_node["outputs"])
|
||
rewrite_query = outputs["optimize_query"]
|
||
except Exception as e:
|
||
return None
|
||
|
||
return {
|
||
"问题改写": rewrite_query,
|
||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||
}
|
||
|
||
def get_retrieve_title_similarity(self, old_retrieve_content:list[dict], new_retrieve_content:list[dict]) -> str:
    """
    Compare the passages retrieved by the old and new workflows.

    Args:
        old_retrieve_content: Passages from the old workflow; each dict
            carries a "content" entry.
        new_retrieve_content: Passages from the new workflow, same shape.

    Returns:
        The Jaccard similarity of the two content sets as a percentage
        string rounded to two decimals (e.g. "33.33%"). Two empty inputs
        count as 100%, exactly one empty input as 0%.
    """
    old_items = [entry["content"] for entry in old_retrieve_content]
    new_items = [entry["content"] for entry in new_retrieve_content]

    if not old_items and not new_items:
        similarity = 1.0  # both empty: treated as identical
    elif not old_items or not new_items:
        similarity = 0.0  # exactly one empty: nothing in common
    else:
        old_set, new_set = set(old_items), set(new_items)
        # Jaccard coefficient: |intersection| / |union|.
        similarity = len(old_set & new_set) / len(old_set | new_set)

    return f"{round(similarity * 100, 2)}%"
|
||
|
||
def process_question(self, q:str) -> tuple:
    """
    Send one question to the workflow(s) and return their raw results.

    Args:
        q: The question text.

    Returns:
        tuple: (old_result, new_result). ``old_result`` is None when only
        the new workflow is tested; both are None when any call fails
        (the failure is logged).
    """
    try:
        # New workflow only: either requested, or no old chat was configured.
        if self.mode == "new_only" or self.old_chat is None:
            return None, self.new_chat.process_question(q)

        # Otherwise query both workflows concurrently.
        with ThreadPoolExecutor(max_workers=2) as executor:
            future_new = executor.submit(self.new_chat.process_question, q)
            future_old = executor.submit(self.old_chat.process_question, q)
            new_result = future_new.result()
            old_result = future_old.result()
        return old_result, new_result
    except Exception as e:
        logging.error(f"处理问题 '{q}' 时发生错误: {str(e)}", exc_info=True)
        return None, None
|
||
|
||
def process_question_with_judge(self, q:str, row):
    """
    Run one question through the workflow(s) and judge the outcome.

    Args:
        q: The question text.
        row: The sheet row the question came from; used to look up the
            reference wiki entry.

    Returns:
        dict: One result row. In new-only mode (or when no old result is
        available) it describes the new workflow only; otherwise it holds
        the old/new comparison. Returns None when the new workflow produced
        no result or an unexpected error occurred (the error is logged).
    """
    try:
        old_info, new_info = self.process_question(q)
        if new_info is None:
            return None

        # A missing old result is handled exactly like new-only mode.
        new_only = self.mode == "new_only" or old_info is None
        query = new_info["问题"] if new_only else old_info["问题"]
        new_answer = new_info["新流程答案"]

        # Look up the reference answer. Failures are logged and the row is
        # then treated as having no reference answer.
        wiki_url = self.find_wiki_link(row)
        standard_answer = ""
        answer_title = ""
        try:
            if wiki_url and not pd.isna(wiki_url):
                standard_answer = self.get_wiki_content(wiki_url)
                answer_title = self.get_wiki_title(wiki_url)
        except Exception as e:
            logging.error(f"处理问题 '{query}' 获取标准答案时发生错误: {str(e)}", exc_info=True)

        if new_only:
            # Answer correctness; stays empty without a reference answer or
            # when the judge could not decide.
            judge_result = ""
            if standard_answer:
                new_result = self.judge_answer(standard_answer, new_answer)
                if new_result is not None:
                    judge_result = "正确" if new_result else "错误"

            # Retrieval correctness: was the reference entry among the hits?
            retrieve_right_str = ""
            if answer_title:
                retrieved = new_info["新检索词条"] or ""
                retrieve_right_str = "正确" if answer_title in retrieved else "错误"

            # Slot completeness. The slot info may be a JSON string, a dict,
            # or empty when the workflow never reached the slot node; an
            # unreadable value yields an empty cell instead of discarding
            # the whole row.
            slot_info = new_info["槽点信息"]
            try:
                slot_data = json.loads(slot_info) if isinstance(slot_info, str) else slot_info
            except json.JSONDecodeError:
                slot_data = None
            if isinstance(slot_data, dict):
                slot_missing_str = "缺失" if slot_data.get("missing_slots") else "完整"
            else:
                slot_missing_str = ""

            return {
                "问题": query,
                "问题改写": new_info["新问题改写"],
                "问题分类": new_info["新问题分类"],
                "槽点信息": new_info["槽点信息"],
                "槽点是否缺失": slot_missing_str,
                "新流程答案": new_answer,
                "回答是否正确": judge_result,
                "检索是否正确": retrieve_right_str,
                "答案词条": answer_title if answer_title else "",
                "检索词条": new_info["新检索词条"],
            }

        # Old-vs-new comparison.
        old_answer = old_info["旧流程答案"]
        if standard_answer:
            judge_result = self.judge_by_standard_answer(standard_answer, old_answer, new_answer)
        else:
            judge_result = self.judge_answer_diff(old_answer, new_answer)
        if judge_result is None:
            judge_result = ""

        return {
            "问题": query,
            "新问题改写": new_info["新问题改写"],
            "旧问题改写": old_info["旧问题改写"],
            "新问题分类": new_info["新问题分类"],
            "槽点信息": new_info["槽点信息"],
            "新流程答案": new_answer,
            "旧流程答案": old_answer,
            "回答判断": judge_result,
            "答案词条": answer_title if answer_title else "",
            "新检索词条": new_info["新检索词条"],
            "旧检索词条": old_info["旧检索词条"],
        }
    except Exception as e:
        logging.error(f"处理问题 '{q}' 时发生错误: {str(e)}", exc_info=True)
        return None
|
||
|
||
def run_comparison(self, with_judge=False):
    """
    Process every question of the input sheet and write the result workbook.

    Args:
        with_judge: Accepted for compatibility; the current implementation
            always goes through ``process_question_with_judge`` and does
            not read this flag.

    Returns:
        str: Path of the Excel file that was written.
    """
    # Build the question list. When the software-name columns exist, rows
    # where both names are "未知" are skipped, and the name found in the
    # answer is prepended when the question itself does not name one.
    df = pd.read_excel(self.excel_path)
    questions = []
    for _, row in df.iterrows():
        if "回答中的软件名称" in row and "提问中的软件名称" in row:
            if row['回答中的软件名称'] == "未知" and row['提问中的软件名称'] == "未知":
                continue
            if row['提问中的软件名称'] != "未知":
                questions.append((row['提问'], row))
            else:
                questions.append((f"{row['回答中的软件名称']}, {row['提问']}", row))
        else:
            questions.append((row['提问'], row))

    # Under a debugger (a trace function is installed) run sequentially so
    # breakpoints behave; otherwise fan out over a thread pool.
    is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
    if not is_debug:
        logging.info(f"并发数量: {self.max_workers}")
        logging.info(f"问题数量: {len(questions)}")
        # One slot per question so the output keeps the input order no
        # matter which task finishes first.
        slots = [None] * len(questions)
        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            with tqdm(total=len(questions), desc="处理问题进度") as pbar:
                future_to_index = {
                    executor.submit(self.process_question_with_judge, q, row): index
                    for index, (q, row) in enumerate(questions)
                }
                for future in as_completed(future_to_index):
                    slots[future_to_index[future]] = future.result()
                    pbar.update(1)
        results = [result for result in slots if result is not None]
    else:
        results = []
        for q, row in questions:
            result = self.process_question_with_judge(q, row)
            logging.info(json.dumps(result,ensure_ascii=False,indent=2))
            if result is not None:
                results.append(result)

    # Write the result workbook.
    out_path = self.output_path
    df_results = pd.DataFrame(results)

    with pd.ExcelWriter(out_path, engine='xlsxwriter') as writer:
        df_results.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # Size each column to its longest cell (or header), capped at 70.
        for col_idx, col_name in enumerate(df_results.columns):
            max_len = max(df_results[col_name].astype(str).map(len).max(), len(col_name))
            worksheet.set_column(col_idx, col_idx, min(max_len + 2, 70))

    return out_path
|
||
|
||
|
||
if __name__ == "__main__":
    # Fallback connection settings. setdefault keeps any value that is
    # already present in the environment (including values loaded from
    # .env by load_dotenv above) instead of overwriting it.
    # NOTE(review): API keys and a database password are embedded in source
    # here; they should be moved to .env / a secret store and rotated.
    os.environ.setdefault("DIFY_BASEURL", "http://10.1.16.39/v1")
    os.environ.setdefault("DIFY_NEW_API_KEY", "app-rv6ie73Ufoa3nRYCMiJx3a8K")
    os.environ.setdefault("DIFY_OLD_API_KEY", "app-wUdkWJx5zeOvmvBUZizMoSw3")

    os.environ.setdefault("DIFY_PG_HOST", "10.1.16.39")
    os.environ.setdefault("DIFY_PG_PORT", "5432")
    os.environ.setdefault("DIFY_PG_USER", "postgres")
    os.environ.setdefault("DIFY_PG_PASSWORD", "difyai123456")
    os.environ.setdefault("DIFY_PG_DATABASE", "dify")

    # Build the command-line parser.
    default_excel_path=os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", ".." ,"data/excel/740条(dislike)_存在标准词条.xlsx")
    parser = argparse.ArgumentParser(description='Dify对话测试工具')
    parser.add_argument('--mode', type=str, choices=['new_only', 'both'], default='new_only',
                        help='测试模式: new_only表示仅测试新对话, both表示测试新老对话')
    parser.add_argument('--excel_path', type=str,
                        default=default_excel_path,
                        help='包含问题的Excel文件路径')
    parser.add_argument('--baseurl', type=str, default=os.getenv("DIFY_BASEURL"),
                        help='Dify API的基础URL')
    parser.add_argument('--new_api_key', type=str, default=os.getenv("DIFY_NEW_API_KEY"),
                        help='新流程的API密钥')
    parser.add_argument('--old_api_key', type=str, default=os.getenv("DIFY_OLD_API_KEY"),
                        help='旧流程的API密钥')
    parser.add_argument('--output_path', type=str, default=None,
                        help='输出Excel文件路径')
    parser.add_argument('--max_workers', type=int, default=5,
                        help='最大工作线程数')

    # Parse the command line.
    args = parser.parse_args()

    # Abort early when the input workbook is missing. No exc_info here:
    # there is no active exception at this point.
    if not os.path.exists(args.excel_path):
        logging.error(f"错误:Excel文件不存在: {args.excel_path}")
        sys.exit(1)

    # Create the tester; the old-workflow key is only passed in "both" mode.
    tester = DifyComparisonTester(
        excel_path=args.excel_path,
        baseurl=args.baseurl,
        new_workflow_api_key=args.new_api_key,
        old_workflow_api_key=args.old_api_key if args.mode == "both" else None,
        output_path=args.output_path,
        max_workers=args.max_workers,
        mode=args.mode
    )

    # Run the comparison (with judging) and report where the result went.
    output_file = tester.run_comparison(with_judge=True)
    logging.info(f"测试结果已保存至: {output_file}")
|