Files
QueryRewrite/rag2_0/dify/test_dify_chatapi.py
T

736 lines
29 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
from rag2_0.dify.dify_client import DifyClient
from rag2_0.dify.dify_tool import NewWorkflowChat, OldWorkFlowChat
import pandas as pd
# 使用线程池并发执行
from concurrent.futures import ThreadPoolExecutor, as_completed
from tqdm import tqdm
from rag2_0.dify.dify_tool import DifyTool
import json
from urllib.parse import unquote
from rag2_0.tool.WikijsTool import WikijsTool
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from threading import Lock
import sys
import argparse
load_dotenv()
class ContentSource(BaseModel):
    """Structured LLM verdict on how relevant a retrieved content is to a query."""

    # Relevance score reported by the LLM.
    score: int = Field(description="相关性分数")
    # Short justification for the score.
    reason: str = Field(description="评分理由")
class DifyComparisonTester:
"""
Dify新旧流程对比测试类,用于比较两个不同流程的问答效果并进行评判
"""
def __init__(self, excel_path:str, baseurl:str, new_workflow_api_key:str,
old_workflow_api_key:str=None, wiki_excel_path:str=None,
output_path:str=None, max_workers:int=1, mode:str="both"):
"""
初始化对比测试器
Args:
excel_path: 包含问题的Excel文件路径
baseurl: Dify API的基础URL
new_workflow_api_key: 新流程的API密钥
old_workflow_api_key: 旧流程的API密钥,仅在mode="both"时需要
wiki_excel_path: Wiki Excel文件路径,用于获取标准答案
output_path: 输出Excel文件路径
max_workers: 最大工作线程数
mode: 测试模式,"new_only"表示仅测试新对话,"both"表示测试新老对话
"""
self.excel_path = excel_path
self.mode = mode
# 使用NewWorkflowChat和OldWorkFlowChat代替ChatClient
self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
if mode == "both" and old_workflow_api_key:
self.old_chat = OldWorkFlowChat(api_key=old_workflow_api_key, base_url=baseurl)
else:
self.old_chat = None
# 评判相关参数
self.output_path = output_path or os.path.join(os.path.dirname(self.excel_path), "dify问答_新流程结果.xlsx")
self.max_workers = max_workers
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
self.results_lock = Lock()
# 读取Wiki Excel文件
if wiki_excel_path and os.path.exists(wiki_excel_path):
self.wiki_excel = pd.read_excel(wiki_excel_path)
else:
self.wiki_excel = None
self.dify_tool = DifyTool()
def __del__(self):
"""
析构函数,在对象被销毁时自动关闭数据库连接。
确保在对象生命周期结束时释放数据库资源。
"""
self.dify_tool.close_connection()
def get_llm(self):
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model = os.getenv("LLM_MODEL_NAME")
return OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
def find_wiki_link(self, query) -> str | None:
"""
根据查询找出对应的词条链接
Args:
query (str): 查询内容
Returns:
str: 对应的词条链接,如果没有找到则返回None
"""
# 确保query不为空
if not query or pd.isna(query):
return None
if self.wiki_excel is None:
return None
# 在"新提问"列中查找匹配的行
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
# 如果找到了匹配的行,返回对应的词条链接
if not matched_rows.empty:
return matched_rows.iloc[0]['对应词条链接']
# 如果没有完全匹配,尝试部分匹配
# 去除软件名称部分(如果有)
query_parts = query.split(',', 1)
if len(query_parts) > 1:
clean_query = query_parts[1].strip()
# 在"提问"列中查找包含清理后查询的行
for idx, row in self.wiki_excel.iterrows():
if pd.notna(row['提问']) and clean_query in row['提问']:
return row['对应词条链接']
return None
def get_wiki_content(self, link) -> str:
"""
获取词条链接的内容
Args:
link (str): 词条链接
Returns:
str: 链接内容,如果获取失败则返回错误信息
"""
try:
if not link or pd.isna(link):
return "链接为空或无效"
# 移除域名部分,只保留路径
path = link.split('/', 3)[-1]
decoded_path = unquote(path)
path_parts = decoded_path.split('/')
doc_path = "/".join(path_parts[1:])
wiki_doc = WikijsTool.get_all_doc_by_path(path=doc_path, path_is_dir=False)
html_content = WikijsTool.query_doc_info(wiki_doc[0]["id"]).get('content')
if not html_content:
return "获取内容失败"
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
new_content = (html_content.replace("h6>", "h7>")
.replace("h5>", "h6>")
.replace("h4>", "h5>")
.replace("h3>", "h4>")
.replace("h2>", "h3>")
.replace("h1>", "h2>"))
# 将HTML内容转换为Markdown
markdown_content = convert_html_to_md(new_content, "", **options)
markdown_content = f"# {path_parts[-1]}\n\n{markdown_content}"
return markdown_content
except Exception as e:
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
def get_wiki_title(self, link) -> str | None:
"""
获取词条标题
Args:
link (str): 词条链接
Returns:
str: 词条标题,如果获取失败则返回None
"""
try:
if not link or pd.isna(link):
return None
# 移除域名部分,只保留路径
path = link.split('/', 3)[-1]
decoded_path = unquote(path)
path_parts = decoded_path.split('/')
return path_parts[-1]
except Exception as e:
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
def create_correctness_prompt(self, standard_answer: str, answer_to_check: str) -> str:
"""
创建用于评判答案正确性的prompt
Args:
standard_answer (str): 标准答案
answer_to_check (str): 需要检查的答案
Returns:
str: 格式化的prompt
"""
return f"""请作为一个专业的答案评判专家,评估以下回答与标准答案的匹配程度。
标准答案:
{standard_answer}
待评估的回答:
{answer_to_check}
请仔细分析两个答案的内容,并给出你的判断。只需要回答"正确"或"错误",不需要其他解释。
如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
如果待评估的回答存在明显的错误信息或重要信息缺失,应判定为"错误"。
请严格按以下格式输出:【正确】或【错误】:"""
def judge_answer(self, standard_answer: str, answer: str) -> bool | None:
"""
调用LLM判断回答是否正确
Args:
standard_answer (str): 标准答案(来自Wiki
answer (str): 需评判的回答
Returns:
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
"""
prompt = self.create_correctness_prompt(standard_answer, answer)
llm = self.get_llm()
try:
response = llm.invoke(user_prompt=prompt, need_retry=True)
return "正确" in response.content
except Exception as e:
return None
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
"""
综合判断新旧回答的正确性
Args:
standard_answer (str): 标准答案(来自Wiki
old_answer (str): 旧流程的回答
new_answer (str): 新流程的回答
Returns:
str | None: 包含新旧回答判断结果的字符串,None表示判断失败
"""
old_result = self.judge_answer(standard_answer, old_answer)
new_result = self.judge_answer(standard_answer, new_answer)
if old_result is None or new_result is None:
return None
if new_result and old_result:
return "新旧答案均正确"
elif new_result and not old_result:
return "新答案正确"
elif not new_result and old_result:
return "旧答案正确"
else:
return "新旧答案均错误"
def judge_answer_diff(self, old_answer: str, new_answer: str) -> str | None:
"""
判断新旧回答是否存在较大差异
Args:
old_answer (str): 旧流程的回答
new_answer (str): 新流程的回答
Returns:
str | None: 差异判断结果,None表示判断失败
"""
prompt = f"""请判断以下两个回答是否存在较大差异:
旧回答: {old_answer}
新回答: {new_answer}
主要是主要步骤、主要信息、或者主要主体的差异
请仅回答"存在较大差异"或"差异较小"。"""
llm = self.get_llm()
try:
response = llm.invoke(user_prompt=prompt, need_retry=True)
return "缺乏标准答案无法判断准确性,但答案基本相同" if "差异较小" in response.content else "缺乏标准答案无法判断准确性,但答案差异较大"
except Exception as e:
return None
def calculate_score(self, query:str, content:str) -> int:
"""
使用LLM判断query与content之间的相关性分数
Args:
query (str): 用户问题
content (str): 检索内容
Returns:
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
"""
try:
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突
【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息
【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}
【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
现在评估:
query: "{query}"
content: "{content}"
"""
llm = self.get_llm()
response = llm.invoke(user_prompt=prompt, need_retry=True)
# 解析JSON响应
try:
parsed_output = self.content_source_parser.parse(response.content)
return parsed_output.score
except Exception as e:
return -1
except Exception as e:
return -1
def get_retrieve_info(self, query:str, outputs:dict) -> tuple:
"""
获取检索信息并计算分数
Args:
query (str): 用户问题
outputs (dict): 检索输出结果
Returns:
tuple: (检索内容列表, 最高分, 最低分, 平均分)
"""
max_score = 0
min_score = 10
total_score = 0
valid_scores = 0
retrieve_content = []
# 使用线程池并发计算分数
with ThreadPoolExecutor() as executor:
# 创建任务列表
future_to_content = {}
for result in outputs["result"]:
content = result["content"].strip()
future = executor.submit(self.calculate_score, query=query, content=content)
future_to_content[future] = content
# 收集结果
for future in as_completed(future_to_content):
content = future_to_content[future]
score = future.result()
content_title = content.split("\n")[0]
if score != -1:
max_score = max(max_score, score)
min_score = min(min_score, score)
total_score += score
valid_scores += 1
if content_title:
retrieve_content.append(content_title + f"--得分({score}分)")
avg_score = total_score / valid_scores if valid_scores > 0 else 0
return retrieve_content, max_score, min_score, avg_score
def get_new_workflow_info(self, query:str, new_message_id:str) -> dict:
"""
获取新流程的问题分类
Args:
query (str): 用户问题
new_message_id (str): 新流程的消息ID
Returns:
dict: 包含问题分类结果的字典
"""
try:
# 使用DifyTool直接获取消息信息
new_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=new_message_id)
# 初始化变量
retrieve_title = []
retrieve_content = []
rewrite_query = ""
vertical_classification = ""
sub_classification = ""
slot_info = ""
# 解析工作流节点信息
for workflow_node in new_message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
json_result = json.loads(llm_result_json)
vertical_classification = json_result['vertical_classification']
sub_classification = json_result['sub_classification']
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
except Exception as e:
return None
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
"问题分类": f"{vertical_classification} - {sub_classification}",
"槽点信息": slot_info
}
def get_old_workflow_info(self, query:str, old_message_id:str) -> dict:
"""
获取旧流程的问题分类
Args:
query (str): 用户问题
old_message_id (str): 旧的流程的消息ID
Returns:
dict: 包含问题分类结果的字典
"""
try:
# 使用DifyTool直接获取消息信息
old_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=old_message_id)
# 初始化变量
retrieve_title = []
retrieve_content = []
rewrite_query = ""
# 解析工作流节点信息
for workflow_node in old_message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
except Exception as e:
return None
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
}
def get_retrieve_title_similarity(self, old_retrieve_content:list[dict], new_retrieve_content:list[dict]) -> str:
old_retrieve_content_list=[content["content"] for content in old_retrieve_content]
new_retrieve_content_list=[content["content"] for content in new_retrieve_content]
# 计算两个列表的交集
intersection = set(old_retrieve_content_list).intersection(set(new_retrieve_content_list))
# 准备详细的比较结果
intersection_count = len(intersection)
old_count = len(old_retrieve_content_list)
new_count = len(new_retrieve_content_list)
# 计算相似度 (Jaccard相似系数)
if old_count == 0 and new_count == 0:
similarity = 1.0 # 都为空时,认为完全相似
elif old_count == 0 or new_count == 0:
similarity = 0.0 # 一个为空时,认为完全不相似
else:
# 交集大小除以并集大小
union_count = len(set(old_retrieve_content_list).union(set(new_retrieve_content_list)))
similarity = intersection_count / union_count
similarity_percentage = round(similarity * 100, 2)
result = f"{similarity_percentage}%"
return result
def process_question(self, q:str) -> tuple:
"""
处理单个问题,获取新旧流程的回答
Args:
q: 问题内容
Returns:
tuple: (old_result, new_result) 包含旧流程和新流程的回答信息
"""
try:
# 如果是仅测试新流程模式
if self.mode == "new_only" or self.old_chat is None:
new_result = self.new_chat.process_question(q)
return None, new_result
else:
# 使用ThreadPoolExecutor并发执行新旧流程
with ThreadPoolExecutor(max_workers=2) as executor:
# 并发提交新旧流程的任务
future_new = executor.submit(self.new_chat.process_question, q)
future_old = executor.submit(self.old_chat.process_question, q)
# 获取结果
new_result = future_new.result()
old_result = future_old.result()
return old_result, new_result
except Exception as e:
print(f"处理问题 '{q}' 时发生错误: {str(e)}")
return None, None
def process_question_with_judge(self, q:str):
"""
处理单个问题,获取新旧流程的回答并进行评判
Args:
q: 问题内容
Returns:
dict: 包含问题、回答和评判结果的字典
"""
# 获取基本的问题和回答
future_old, future_new = self.process_question(q)
if future_new is None:
return None
# 如果是仅测试新流程模式
if self.mode == "new_only" or future_old is None:
query = future_new["问题"]
new_answer = future_new["新流程答案"]
# 获取词条链接和标准答案
wiki_url = self.find_wiki_link(query)
standard_answer = ""
answer_title = ""
try:
if wiki_url and not pd.isna(wiki_url):
standard_answer = self.get_wiki_content(wiki_url)
answer_title = self.get_wiki_title(wiki_url)
except Exception as e:
print(f"处理问题 '{query}' 获取标准答案时发生错误: {str(e)}")
# 判断答案正确性
judge_result = ""
if standard_answer:
# 调用LLM判断新答案是否正确
new_result = self.judge_answer(standard_answer, new_answer)
if new_result is not None:
judge_result = "正确" if new_result else "错误"
# 返回结果
return {
"问题": query,
"问题改写": future_new["新问题改写"],
"问题分类": future_new["新问题分类"],
"槽点信息": future_new["槽点信息"],
"新流程答案": new_answer,
"回答判断": judge_result,
"答案词条": answer_title if answer_title else "",
"检索词条": future_new["新检索词条"],
}
# 如果是测试新老流程模式
if future_old is None:
return None
query = future_old["问题"]
old_answer = future_old["旧流程答案"]
new_answer = future_new["新流程答案"]
# 获取词条链接和标准答案
wiki_url = self.find_wiki_link(query)
standard_answer = ""
answer_title = ""
try:
if wiki_url and not pd.isna(wiki_url):
standard_answer = self.get_wiki_content(wiki_url)
answer_title = self.get_wiki_title(wiki_url)
except Exception as e:
print(f"处理问题 '{query}' 获取标准答案时发生错误: {str(e)}")
# 判断答案正确性
if standard_answer:
judge_result = self.judge_by_standard_answer(standard_answer, old_answer, new_answer)
else:
judge_result = self.judge_answer_diff(old_answer, new_answer)
if judge_result is None:
judge_result = ""
# 返回结果
return {
"问题": query,
"新问题改写": future_new["新问题改写"],
"旧问题改写": future_old["旧问题改写"],
"新问题分类": future_new["新问题分类"],
"槽点信息": future_new["槽点信息"],
"新流程答案": new_answer,
"旧流程答案": old_answer,
"回答判断": judge_result,
# "词条检索相似度": retrieve_title_score,
"答案词条": answer_title if answer_title else "",
"新检索词条": future_new["新检索词条"],
"旧检索词条": future_old["旧检索词条"],
}
def run_comparison(self, with_judge=False):
"""
运行对比测试,处理所有问题并生成结果Excel
Args:
with_judge: 是否进行答案评判
Returns:
str: 输出Excel文件的路径
"""
# 读取Excel文件中的问题
df = pd.read_excel(self.excel_path)
questions=[]
for idx, row in df.iterrows():
if "回答中的软件名称" in row and "提问中的软件名称" in row:
if row['回答中的软件名称'] == "未知" and row['提问中的软件名称'] == "未知":
continue
if row['提问中的软件名称'] != "未知":
questions.append(row['提问'])
else:
questions.append(f"{row['回答中的软件名称']}, {row['提问']}")
else:
questions.append(row['提问'])
results = []
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
if not is_debug:
# 使用多线程并发处理问题
print("并发数量: ", self.max_workers)
print("问题数量: ", len(questions))
with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
# 创建进度条
with tqdm(total=len(questions), desc="处理问题进度") as pbar:
# 提交所有任务
futures = []
for q in questions:
future = executor.submit(self.process_question_with_judge, q)
futures.append(future)
# 处理结果
for future in as_completed(futures):
result = future.result()
if result is not None:
with self.results_lock:
results.append(result)
pbar.update(1)
else:
for q in questions:
result = self.process_question_with_judge(q)
print(json.dumps(result,ensure_ascii=False,indent=2))
if result is not None:
results.append(result)
# 生成输出Excel文件
out_path = self.output_path
df_results = pd.DataFrame(results)
# 使用ExcelWriter设置格式
with pd.ExcelWriter(out_path, engine='xlsxwriter') as writer:
df_results.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作簿和工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 设置列宽
for col_idx, col_name in enumerate(df_results.columns):
max_len = max(df_results[col_name].astype(str).map(len).max(), len(col_name))
worksheet.set_column(col_idx, col_idx, min(max_len + 2, 70))
return out_path
def build_arg_parser() -> argparse.ArgumentParser:
    """Build the command line parser of the Dify chat test tool."""
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..")
    default_excel_path = os.path.join(data_dir, "data/excel/历史提问数据(like)_提问明确.xlsx")
    default_wiki_excel_path = os.path.join(data_dir, "data/excel/部分提问_软件名称明确.xlsx")

    parser = argparse.ArgumentParser(description='Dify对话测试工具')
    parser.add_argument('--mode', type=str, choices=['new_only', 'both'], default='new_only',
                        help='测试模式: new_only表示仅测试新对话, both表示测试新老对话')
    parser.add_argument('--excel_path', type=str,
                        default=default_excel_path,
                        help='包含问题的Excel文件路径')
    parser.add_argument('--baseurl', type=str, default="http://172.20.0.145/v1",
                        help='Dify API的基础URL')
    # SECURITY NOTE(review): API keys are hard-coded as fallbacks below. They can
    # be overridden through DIFY_NEW_API_KEY / DIFY_OLD_API_KEY (the environment
    # or the .env file); consider rotating them and dropping the literals.
    parser.add_argument('--new_api_key', type=str,
                        default=os.getenv("DIFY_NEW_API_KEY", "app-qxsSybCs7ABiKlC1JabTYVn6"),
                        help='新流程的API密钥')
    parser.add_argument('--old_api_key', type=str,
                        default=os.getenv("DIFY_OLD_API_KEY", "app-wUdkWJx5zeOvmvBUZizMoSw3"),
                        help='旧流程的API密钥')
    parser.add_argument('--wiki_excel_path', type=str,
                        default=default_wiki_excel_path,
                        help='Wiki Excel文件路径,用于获取标准答案')
    parser.add_argument('--output_path', type=str, default=None,
                        help='输出Excel文件路径')
    parser.add_argument('--max_workers', type=int, default=5,
                        help='最大工作线程数')
    return parser


def main() -> None:
    """Parse the command line, run the comparison and report the output file."""
    args = build_arg_parser().parse_args()

    # Fail early with a clear message when the question sheet is missing.
    if not os.path.exists(args.excel_path):
        print(f"错误:Excel文件不存在: {args.excel_path}")
        # sys.exit is always available; the bare exit() helper is injected by
        # the site module and is not guaranteed to exist.
        sys.exit(1)

    tester = DifyComparisonTester(
        excel_path=args.excel_path,
        baseurl=args.baseurl,
        new_workflow_api_key=args.new_api_key,
        # The old workflow key is only handed over when both workflows are tested.
        old_workflow_api_key=args.old_api_key if args.mode == "both" else None,
        wiki_excel_path=args.wiki_excel_path,
        output_path=args.output_path,
        max_workers=args.max_workers,
        mode=args.mode
    )
    output_file = tester.run_comparison(with_judge=True)
    print(f"测试结果已保存至: {output_file}")


if __name__ == "__main__":
    main()