优化DifyCompareTest类,添加DifyExporter实例以支持词条检索,更新DifyQueryRetrieval_api.py中的topk参数,增强DifyExporter类以从HTTP服务获取查询类型和点踩原因,简化构造函数,移除不必要的查询日志加载逻辑。

This commit is contained in:
2025-07-30 17:30:24 +08:00
parent 57369059eb
commit 728262cc65
3 changed files with 95 additions and 54 deletions
+23 -6
View File
@@ -18,7 +18,8 @@ from langchain_core.output_parsers import JsonOutputParser
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.dify.dify_client import ChatClient from rag2_0.dify.dify_client import ChatClient
from rag2_0.tool.ModelTool import OpenAiLLM from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.dify.dify_tool import PgSql, DifyTool
from rag2_0.dify.export_new_dify import DifyExporter
load_dotenv() load_dotenv()
# 创建日志目录 # 创建日志目录
log_dir = 'data/logs' log_dir = 'data/logs'
@@ -45,6 +46,7 @@ class DifyCompareTest:
# 词条与工单同时检索 # 词条与工单同时检索
self.both_wiki_worker_client = ChatClient(api_key=os.getenv("DIFY_APP_KEY"), base_url=os.getenv("DIFY_BSAE_URL")) self.both_wiki_worker_client = ChatClient(api_key=os.getenv("DIFY_APP_KEY"), base_url=os.getenv("DIFY_BSAE_URL"))
self.llm = OpenAiLLM(base_url=os.getenv("OPENAI_API_BASE"), model=os.getenv("MODEL_NAME")) self.llm = OpenAiLLM(base_url=os.getenv("OPENAI_API_BASE"), model=os.getenv("MODEL_NAME"))
self.exporter = DifyExporter()
def llm_judge_answer(self, old_answer: str, now_answer: str): def llm_judge_answer(self, old_answer: str, now_answer: str):
user_prompt = f""" user_prompt = f"""
@@ -100,10 +102,11 @@ class DifyCompareTest:
answer = result.get('answer', "") answer = result.get('answer', "")
if len(answer) == 0: if len(answer) == 0:
raise Exception(f"回答为空: {result}") raise Exception(f"回答为空: {result}")
if old_answer: # if old_answer:
judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer) # judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer)
else: # else:
judge_result="" # judge_result=""
judge_result=""
# 只取回答的前半部分 # 只取回答的前半部分
answer = answer.split("----------------------------------------")[0].strip() answer = answer.split("----------------------------------------")[0].strip()
message_id = result.get('message_id', "") message_id = result.get('message_id', "")
@@ -117,6 +120,18 @@ class DifyCompareTest:
import time import time
time.sleep(10) # 等待1秒后重试 time.sleep(10) # 等待1秒后重试
def get_wiki_list_by_msgid(self,msg_id):
if msg_id is None or pd.isna(msg_id):
return ""
msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id)
if not msg_debug_info:
return ""
wiki_list = self.exporter.get_wiki_list(msg_debug_info)
if len(wiki_list) == 0:
return ""
else:
return "\n".join(list(set(wiki_list)))
def process_single_row(self, index, row): def process_single_row(self, index, row):
"""处理单行数据的方法""" """处理单行数据的方法"""
try: try:
@@ -145,6 +160,7 @@ class DifyCompareTest:
result_row["message_id"] = message_id result_row["message_id"] = message_id
result_row["回答"] = answer result_row["回答"] = answer
# result_row["词条与工单同时回答对比"] = judge_result # result_row["词条与工单同时回答对比"] = judge_result
result_row["检索到的词条"] = self.get_wiki_list_by_msgid(message_id)
logging.info(f"成功处理第 {index + 1} 行数据") logging.info(f"成功处理第 {index + 1} 行数据")
return index, result_row return index, result_row
@@ -152,6 +168,7 @@ class DifyCompareTest:
logging.error(f"处理第 {index + 1} 行数据时出错: {e}") logging.error(f"处理第 {index + 1} 行数据时出错: {e}")
result_row = row.copy() result_row = row.copy()
result_row["回答"] = '' result_row["回答"] = ''
result_row["检索到的词条"] = ''
result_row["message_id"] = '' result_row["message_id"] = ''
return index, result_row return index, result_row
@@ -230,7 +247,7 @@ if __name__ == "__main__":
# 处理第一个文件 # 处理第一个文件
excel_files = [ excel_files = [
# ("data/excel/5月.xlsx", "data/excel/5月问答对比.xlsx"), # ("data/excel/5月.xlsx", "data/excel/5月问答对比.xlsx"),
("data/excel/第四轮问题-Part2.xlsx", "data/excel/第四轮问题-Part2-问答测试.xlsx") ("data/excel/7.30数据导出.xlsx", "data/excel/7.30数据导出_问答测试.xlsx")
] ]
for excel_path, save_path in excel_files: for excel_path, save_path in excel_files:
+3 -2
View File
@@ -5,7 +5,7 @@ import os
from fastapi import FastAPI, HTTPException, Request from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field from pydantic import BaseModel, Field, ConfigDict
from typing import Dict, List, Any, Optional from typing import Dict, List, Any, Optional
import asyncio import asyncio
@@ -43,6 +43,7 @@ class RetrieveRequest(BaseModel):
query_list: str query_list: str
data_set_list: str data_set_list: str
query_expand_dict: dict | str = Field(default="{}") query_expand_dict: dict | str = Field(default="{}")
topk: int = Field(default=4)
# 创建FastAPI应用 # 创建FastAPI应用
app = FastAPI( app = FastAPI(
@@ -102,7 +103,7 @@ async def retrieve(request: RetrieveRequest):
query_list, query_list,
data_set_list, data_set_list,
query_expand_dict=query_expand_dict, query_expand_dict=query_expand_dict,
top_k=4 top_k=request.topk
) )
end_time = time.time() end_time = time.time()
+69 -46
View File
@@ -8,6 +8,7 @@ import pandas as pd
import sys import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.dify.dify_tool import PgSql, DifyTool from rag2_0.dify.dify_tool import PgSql, DifyTool
import requests
class DifyExporter: class DifyExporter:
@@ -16,13 +17,12 @@ class DifyExporter:
支持按日期范围过滤消息,可以指定开始日期和结束日期 支持按日期范围过滤消息,可以指定开始日期和结束日期
""" """
def __init__(self, app_id=None, query_log_file=None, start_date=None, end_date=None): def __init__(self, app_id=None, start_date=None, end_date=None):
""" """
初始化DifyExporter实例 初始化DifyExporter实例
Args: Args:
app_id: Dify应用ID,默认为None app_id: Dify应用ID,默认为None
query_log_file: 查询日志文件路径,默认为None
start_date: 开始日期时间,格式为YYYY-MM-DD HH,默认为None(不限制开始日期) start_date: 开始日期时间,格式为YYYY-MM-DD HH,默认为None(不限制开始日期)
end_date: 结束日期时间,格式为YYYY-MM-DD HH,默认为None(不限制结束日期) end_date: 结束日期时间,格式为YYYY-MM-DD HH,默认为None(不限制结束日期)
@@ -33,10 +33,6 @@ class DifyExporter:
# 设置默认值 # 设置默认值
self.app_id = app_id or "72d03c7d-8bea-42f9-9e8d-cdfb9480f372" self.app_id = app_id or "72d03c7d-8bea-42f9-9e8d-cdfb9480f372"
# 设置查询日志文件路径
self.query_log_dir = os.path.join(os.getcwd(), "data", "query_logs")
self.query_log_file = query_log_file or os.path.join(self.query_log_dir, "answer_type_logs.json")
# 设置日期过滤,转换为datetime对象 # 设置日期过滤,转换为datetime对象
self.start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d %H") if start_date else None self.start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d %H") if start_date else None
self.end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d %H") if end_date else None self.end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d %H") if end_date else None
@@ -47,28 +43,9 @@ class DifyExporter:
# 初始化数据存储 # 初始化数据存储
self.message_info_list = [] self.message_info_list = []
self.query_logs = {}
# 设置AnswerType服务地址
def load_query_logs(self,path): self.answer_type_url = f"http://10.1.16.39:8003"
"""
从文件加载查询日志
"""
try:
with open(path, 'r', encoding='utf-8') as f:
query_logs_list = json.load(f)
# 创建字典来存储每个查询的最新记录workflow_run_id
for record in query_logs_list:
workflow_run_id = record['workflow_run_id']
timestamp = record.get('timestamp')
# 如果查询不在字典中或者当前记录的时间戳更新,则更新字典
if workflow_run_id not in self.query_logs or (timestamp and self.query_logs.get(workflow_run_id, {}).get('timestamp') and
datetime.datetime.fromisoformat(timestamp) >
datetime.datetime.fromisoformat(self.query_logs[workflow_run_id]['timestamp'])):
self.query_logs[workflow_run_id] = record
return True
except Exception as e:
print(f"加载查询日志失败: {e}")
return False
def process_message_chain(self, messages): def process_message_chain(self, messages):
""" """
@@ -150,18 +127,25 @@ class DifyExporter:
if node_execution is not None: if node_execution is not None:
if node_execution["outputs"] is None: if node_execution["outputs"] is None:
return [] return []
source_kno = json.loads(node_execution["outputs"])["source_kno"] outputs = json.loads(node_execution["outputs"])
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"] source_kno = outputs["source_kno"]
knowledge_list_metadata = outputs["knowledge_list_metadata"]
for knowledge in knowledge_list_metadata: for knowledge in knowledge_list_metadata:
document_name = knowledge['metadata']['document_name'] document_name = knowledge['metadata']['document_name']
wiki_list.append(document_name.split("/")[-1]) doc_metadata = knowledge['metadata']['doc_metadata']
if doc_metadata is None or doc_metadata.get("workorder_time", None) is not None:
wiki_list.append(document_name.split("/")[-1])
else:
dataset_name = knowledge['metadata']['dataset_name']
wiki_list.append(f"{dataset_name} - {document_name.split('/')[-1]}")
return wiki_list return wiki_list
lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识") lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识")
if lock_node_execution is not None: if lock_node_execution is not None:
if lock_node_execution["outputs"] is None: if lock_node_execution["outputs"] is None:
return [] return []
source_kno = json.loads(lock_node_execution["outputs"])['json'][0]['retrieve_result'] outputs = json.loads(lock_node_execution["outputs"])
source_kno = outputs['json'][0]['retrieve_result']
for knowledge in source_kno: for knowledge in source_kno:
document_name = knowledge['metadata']['document_name'] document_name = knowledge['metadata']['document_name']
wiki_list.append(document_name.split("/")[-1]) wiki_list.append(document_name.split("/")[-1])
@@ -172,6 +156,50 @@ class DifyExporter:
return [] return []
def get_query_type_from_service(self, workflow_run_id):
"""
从HTTP服务获取查询类型
Args:
workflow_run_id: 工作流运行ID
Returns:
查询类型字符串,如果获取失败则返回空字符串
"""
try:
url = f"{self.answer_type_url}/query_by_workflow_id?workflow_run_id={workflow_run_id}"
response = requests.get(url, timeout=2)
if response.status_code == 200:
data = response.json()
if data.get("data") and len(data["data"]) > 0:
return data["data"][0]["query_type"]
return ""
except Exception as e:
print(f"获取查询类型时出错: {e}")
return ""
def get_dislike_reason_from_service(self, workflow_run_id):
"""
从HTTP服务获取查询类型
Args:
workflow_run_id: 工作流运行ID
Returns:
查询类型字符串,如果获取失败则返回空字符串
"""
try:
url = f"{self.answer_type_url}/dislike_by_workflow_id?workflow_run_id={workflow_run_id}"
response = requests.get(url, timeout=2)
if response.status_code == 200:
data = response.json()
if data.get("data") and len(data["data"]) > 0:
return data["data"][0]["dislike_reason"]
return ""
except Exception as e:
print(f"获取查询类型时出错: {e}")
return ""
def extract_message_info(self, message): def extract_message_info(self, message):
""" """
@@ -210,10 +238,11 @@ class DifyExporter:
wiki_list = list(set(wiki_list)) wiki_list = list(set(wiki_list))
wiki_list_str = "\n".join(wiki_list) wiki_list_str = "\n".join(wiki_list)
rating = self.dify_pgsql.get_message_rating(msg_id) rating = self.dify_pgsql.get_message_rating(msg_id)
# 直接通过字典键获取query_type
# 从HTTP服务获取query_type
workflow_run_id = message['workflow_run_id'] workflow_run_id = message['workflow_run_id']
query_type = self.query_logs.get(workflow_run_id, {}).get('query_type', "") query_type = self.get_query_type_from_service(workflow_run_id)
dislike_reason = self.get_dislike_reason_from_service(workflow_run_id)
return { return {
"msg_id": msg_id, "msg_id": msg_id,
"提问": msg_query, "提问": msg_query,
@@ -224,6 +253,7 @@ class DifyExporter:
"评价": rating, "评价": rating,
"问题分类": query_type, "问题分类": query_type,
"检索到的词条": wiki_list_str, "检索到的词条": wiki_list_str,
"点踩原因": dislike_reason
} }
def process_conversations(self): def process_conversations(self):
@@ -277,7 +307,7 @@ class DifyExporter:
# 设置列的顺序 # 设置列的顺序
columns_order = [ columns_order = [
"msg_id","当前软件", "提问", "回答", "提问人", "提问时间", "msg_id","当前软件", "提问", "回答", "提问人", "提问时间",
"评价", "问题分类", "检索到的词条" "评价", "问题分类", "检索到的词条", "点踩原因"
] ]
# 确保所有列都存在,如果不存在则添加空列 # 确保所有列都存在,如果不存在则添加空列
@@ -315,7 +345,7 @@ class DifyExporter:
"评价": 10, "评价": 10,
"问题分类": 20, "问题分类": 20,
"检索到的词条": 40, "检索到的词条": 40,
"备注": 40 "点踩原因": 20
} }
# 应用列宽设置 # 应用列宽设置
@@ -343,10 +373,6 @@ class DifyExporter:
如果在初始化时指定了start_date或end_date,则只会导出指定日期时间范围内的消息 如果在初始化时指定了start_date或end_date,则只会导出指定日期时间范围内的消息
数据库中的时间是UTC+0时区,会自动转换为UTC+8时区进行过滤和显示 数据库中的时间是UTC+0时区,会自动转换为UTC+8时区进行过滤和显示
""" """
# 加载查询日志
self.load_query_logs(self.query_log_file)
self.load_query_logs("data/query_logs/answer_type_logs_071409.json")
# 处理会话数据 # 处理会话数据
self.process_conversations() self.process_conversations()
@@ -381,11 +407,9 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Dify数据导出工具') parser = argparse.ArgumentParser(description='Dify数据导出工具')
parser.add_argument('--output', '-o', type=str, default="data/excel/dify_export.xlsx", parser.add_argument('--output', '-o', type=str, default="data/excel/dify_export.xlsx",
help='输出Excel文件路径') help='输出Excel文件路径')
parser.add_argument('--app_id', '-a', type=str, default=None, parser.add_argument('--app_id', '-a', type=str, default="6218c4fd-bba3-4f5b-9fb5-61585d8eee51",
help='Dify应用ID') help='Dify应用ID')
parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json", parser.add_argument('--start_date', '-s', type=str, default="2025-07-30 00",
help='查询日志文件路径')
parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)') help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
parser.add_argument('--end_date', '-e', type=str, default=None, parser.add_argument('--end_date', '-e', type=str, default=None,
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)') help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
@@ -403,7 +427,6 @@ if __name__ == "__main__":
# 创建导出器实例 # 创建导出器实例
exporter = DifyExporter( exporter = DifyExporter(
app_id=args.app_id, app_id=args.app_id,
query_log_file=args.query_log_file,
start_date=args.start_date, start_date=args.start_date,
end_date=args.end_date end_date=args.end_date
) )