416 lines
16 KiB
Python
416 lines
16 KiB
Python
from dotenv import load_dotenv
|
||
import os
|
||
import json
|
||
import datetime
|
||
import pandas as pd
|
||
|
||
|
||
import sys
|
||
sys.path.append(os.getcwd())
|
||
from rag2_0.dify.dify_tool import PgSql, DifyTool
|
||
|
||
|
||
class DifyExporter:
    """
    Export Dify conversation and message data to an Excel workbook.

    Messages can be filtered by a date-time range with hour precision,
    via start_date and/or end_date.
    """

    # Offset used to shift the database's UTC+0 timestamps to UTC+8.
    UTC8_OFFSET = datetime.timedelta(hours=8)

    # Query-log files that export() has always loaded in addition to
    # query_log_file (formerly hard-coded there). Kept as the default for
    # backward compatibility.
    DEFAULT_EXTRA_QUERY_LOG_FILES = ("data/query_logs/answer_type_logs_071409.json",)

    # Stored answers are truncated at the first occurrence of this line.
    ANSWER_SEPARATOR = "----------------------------------------"

    def __init__(self, app_id=None, query_log_file=None, start_date=None, end_date=None,
                 extra_query_log_files=None):
        """
        Initialize a DifyExporter.

        Args:
            app_id: Dify application ID. Falls back to a built-in default when None.
            query_log_file: Path of the primary query-log JSON file. Defaults to
                <cwd>/data/query_logs/answer_type_logs.json.
            start_date: Lower bound "YYYY-MM-DD HH" (UTC+8), inclusive.
                None means no lower bound.
            end_date: Upper bound "YYYY-MM-DD HH" (UTC+8). Messages later than
                HH:00:00 of that hour are excluded. None means no upper bound.
            extra_query_log_files: Additional query-log files that export()
                loads after query_log_file. None keeps the legacy default
                (DEFAULT_EXTRA_QUERY_LOG_FILES); pass [] to load none.

        Raises:
            ValueError: if start_date or end_date does not match "%Y-%m-%d %H".

        Note:
            Timestamps in the database are UTC+0 and are shifted to UTC+8 for
            filtering and display, so start_date/end_date must be given in UTC+8.
        """
        self.app_id = app_id or "72d03c7d-8bea-42f9-9e8d-cdfb9480f372"

        # Query-log file locations
        self.query_log_dir = os.path.join(os.getcwd(), "data", "query_logs")
        self.query_log_file = query_log_file or os.path.join(self.query_log_dir, "answer_type_logs.json")
        if extra_query_log_files is None:
            extra_query_log_files = self.DEFAULT_EXTRA_QUERY_LOG_FILES
        self.extra_query_log_files = list(extra_query_log_files)

        # Date filter bounds, parsed to naive datetimes (UTC+8 wall clock)
        self.start_date = self._parse_hour(start_date)
        self.end_date = self._parse_hour(end_date)

        # Data-access helpers
        self.dify_pgsql = PgSql()
        self.dify_tool = DifyTool()

        # Accumulated results
        self.message_info_list = []
        self.query_logs = {}

    @staticmethod
    def _parse_hour(value):
        """Parse a "YYYY-MM-DD HH" string into a naive datetime; a falsy value yields None."""
        return datetime.datetime.strptime(value, "%Y-%m-%d %H") if value else None

    @classmethod
    def _to_utc8(cls, created_at_utc):
        """Shift a UTC+0 datetime read from the database to UTC+8."""
        return created_at_utc + cls.UTC8_OFFSET

    @staticmethod
    def _is_newer(timestamp, existing_timestamp):
        """Return True only if both ISO timestamps are present and `timestamp` is later."""
        if not timestamp or not existing_timestamp:
            return False
        return datetime.datetime.fromisoformat(timestamp) > datetime.datetime.fromisoformat(existing_timestamp)

    def load_query_logs(self, path):
        """
        Merge a query-log JSON file into self.query_logs, keyed by workflow_run_id.

        The file must contain a list of records. Each record should carry a
        'workflow_run_id' and may carry an ISO-format 'timestamp'. When several
        records share a workflow_run_id, the one with the newer timestamp wins.
        If either record lacks a timestamp, the record already stored is kept.
        Records without a workflow_run_id are skipped.

        Args:
            path: Path of the JSON file.

        Returns:
            True if the file was read and merged. False on any error; the error
            is printed, because loading is best-effort.
        """
        try:
            with open(path, 'r', encoding='utf-8') as f:
                query_logs_list = json.load(f)

            for record in query_logs_list:
                workflow_run_id = record.get('workflow_run_id')
                if workflow_run_id is None:
                    # Skip instead of aborting the whole file on one bad record.
                    continue
                existing = self.query_logs.get(workflow_run_id)
                if existing is None or self._is_newer(record.get('timestamp'), existing.get('timestamp')):
                    self.query_logs[workflow_run_id] = record
            return True
        except Exception as e:
            print(f"加载查询日志失败: {e}")
            return False

    def process_message_chain(self, messages):
        """
        Collapse regenerated messages within one conversation.

        Messages are grouped by parent_message_id. Within a group, messages
        with the same query collapse to the one with the latest created_at;
        on a tie, the earlier one in input order is kept.

        Args:
            messages: List of message dicts with 'parent_message_id', 'query'
                and 'created_at' keys.

        Returns:
            The surviving messages. Their order follows the first appearance of
            each parent, then of each query within that parent. No explicit
            sort by time is performed.
        """
        groups = {}
        for message in messages:
            groups.setdefault(message["parent_message_id"], []).append(message)

        collapsed = []
        for group in groups.values():
            latest_by_query = {}
            for msg in group:
                current = latest_by_query.get(msg['query'])
                if current is None or current["created_at"] < msg['created_at']:
                    latest_by_query[msg['query']] = msg
            collapsed.extend(latest_by_query.values())
        return collapsed

    def get_remark(self, msg_debug_info):
        """
        Build a fallback remark for a message that retrieved no knowledge entries.

        Args:
            msg_debug_info: Debug-info dict with a 'workflow_node_executions_info'
                list, which may be None.

        Returns:
            - "" if there is no "意图识别结果解析" (intent-parsing) node, or its
              outputs are None.
            - "使用固定话术" if that node's sub_classification is "固定话术类".
            - "检索工单" if a "检索工单数据" (work-order retrieval) node exists.
            - "" otherwise.
        """
        executions = msg_debug_info['workflow_node_executions_info']

        intent_node = self.get_node_info_by_title(executions, "意图识别结果解析")
        if intent_node is None or intent_node["outputs"] is None:
            return ""

        intent_result = json.loads(intent_node["outputs"])
        if intent_result.get("sub_classification", "") == "固定话术类":
            return "使用固定话术"

        if self.get_node_info_by_title(executions, "检索工单数据") is not None:
            return "检索工单"

        return ""

    def get_node_info_by_title(self, workflow_node_executions_info: list, title: str) -> "dict | None":
        """
        Return the first node execution whose "title" equals `title`.

        Returns None when the list is None or no node matches.
        """
        if workflow_node_executions_info is None:
            return None
        for node_execution in workflow_node_executions_info:
            if node_execution["title"] == title:
                return node_execution
        return None

    @staticmethod
    def _document_basename(knowledge):
        """Return the last '/'-separated part of knowledge['metadata']['document_name']."""
        return knowledge['metadata']['document_name'].split("/")[-1]

    def get_wiki_list(self, msg_debug_info) -> list:
        """
        Collect the names of the knowledge documents retrieved for a message.

        The "提取处理后的知识" node takes priority. For it, the basename of every
        knowledge_list_metadata[*].metadata.document_name is returned.
        Otherwise, for the "软件锁知识" node, the basenames from
        json[0].retrieve_result are returned, followed by two fixed entries.

        Returns:
            The names, possibly containing duplicates (the caller dedupes).
            Returns [] when there are no node executions, neither node exists,
            or the matching node has no outputs.
        """
        executions = msg_debug_info['workflow_node_executions_info']
        if executions is None:
            return []

        knowledge_node = self.get_node_info_by_title(executions, "提取处理后的知识")
        if knowledge_node is not None:
            if knowledge_node["outputs"] is None:
                return []
            outputs = json.loads(knowledge_node["outputs"])
            return [self._document_basename(k) for k in outputs["knowledge_list_metadata"]]

        lock_node = self.get_node_info_by_title(executions, "软件锁知识")
        if lock_node is not None:
            if lock_node["outputs"] is None:
                return []
            # Tolerate an empty/missing 'json' list or 'retrieve_result' key.
            lock_json = json.loads(lock_node["outputs"]).get('json') or []
            retrieved = (lock_json[0].get('retrieve_result') or []) if lock_json else []
            wiki_list = [self._document_basename(k) for k in retrieved]
            wiki_list.append("锁信息查询")
            wiki_list.append("软件锁注册、激活、查锁、试用锁延期")
            return wiki_list

        return []

    def extract_message_info(self, message):
        """
        Build one export row from a message record.

        Looks up the message's debug info through dify_tool, its rating through
        dify_pgsql, and its query type from self.query_logs (by workflow_run_id).

        Args:
            message: Message dict with 'id', 'inputs', 'query', 'answer',
                'created_at' (naive UTC+0 datetime) and 'workflow_run_id'.

        Returns:
            A dict keyed by the Excel column names, or None when no debug info
            is available for the message. 'created_at' is shown in UTC+8.
        """
        msg_id = message["id"]
        msg_inputs = message["inputs"] or {}
        user_name = msg_inputs.get("user_name", "")
        current_softname = msg_inputs.get("current_softname", "")
        msg_query = message["query"]
        # Keep only the part of the answer before the separator line.
        msg_answer = (message["answer"] or "").split(self.ANSWER_SEPARATOR)[0]
        created_at = self._to_utc8(message['created_at']).strftime("%Y-%m-%d %H:%M")

        msg_debug_info = self.dify_tool.get_message_debug_info_by_id(msg_id)
        if not msg_debug_info:
            return None

        wiki_list = self.get_wiki_list(msg_debug_info)
        if wiki_list:
            # dict.fromkeys dedupes while keeping first-seen order, so the cell
            # content is stable across runs (set() order depends on the hash seed).
            wiki_list_str = "\n".join(dict.fromkeys(wiki_list))
        else:
            wiki_list_str = self.get_remark(msg_debug_info)

        rating = self.dify_pgsql.get_message_rating(msg_id)
        query_type = self.query_logs.get(message.get('workflow_run_id'), {}).get('query_type', "")

        return {
            "msg_id": msg_id,
            "提问": msg_query,
            "当前软件": current_softname,
            "回答": msg_answer,
            "提问人": user_name,
            "提问时间": created_at,
            "评价": rating,
            "问题分类": query_type,
            "检索到的词条": wiki_list_str,
        }

    def _in_date_range(self, created_at_utc8):
        """Return True if created_at_utc8 lies within [start_date, end_date]; unset bounds are open."""
        if self.start_date and created_at_utc8 < self.start_date:
            return False
        if self.end_date and created_at_utc8 > self.end_date:
            return False
        return True

    def process_conversations(self):
        """
        Collect export rows for all conversations of self.app_id.

        Regenerated messages are collapsed by process_message_chain(). Messages
        whose UTC+8 creation time falls outside the configured date range are
        skipped.

        Returns:
            self.message_info_list. It is rebuilt from scratch on each call, so
            repeated calls (e.g. export() run twice) do not duplicate rows.
        """
        self.message_info_list = []
        conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id)
        for conversation in conversations:
            conversation_id = conversation['conversation_id']
            messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation_id)
            message_chain_new = self.process_message_chain(messages)
            if len(message_chain_new) != len(messages):
                print(f"过滤了{len(messages) - len(message_chain_new)}条消息,会话ID:{conversation_id}")

            for message in message_chain_new:
                if not self._in_date_range(self._to_utc8(message['created_at'])):
                    continue
                message_info = self.extract_message_info(message)
                if message_info:
                    self.message_info_list.append(message_info)

        return self.message_info_list

    @staticmethod
    def _column_letter(index):
        """Convert a 0-based column index to an Excel column name (0 -> "A", 26 -> "AA")."""
        letters = ""
        index += 1
        while index > 0:
            index, remainder = divmod(index - 1, 26)
            letters = chr(65 + remainder) + letters
        return letters

    def save_to_excel(self, message_info_list, output_file):
        """
        Write the message rows to an .xlsx file with fixed column order and widths.

        Args:
            message_info_list: List of row dicts as produced by extract_message_info().
            output_file: Destination path. Its parent directory is created if needed.

        Returns:
            output_file.
        """
        df = pd.DataFrame(message_info_list)

        columns_order = [
            "msg_id", "当前软件", "提问", "回答", "提问人", "提问时间",
            "评价", "问题分类", "检索到的词条"
        ]
        # Add any missing columns as empty so the reorder below never raises.
        for col in columns_order:
            if col not in df.columns:
                df[col] = None
        df = df[columns_order]

        # os.path.dirname() is "" for a bare filename, and os.makedirs("") raises.
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        with pd.ExcelWriter(output_file, engine='openpyxl') as writer:
            df.to_excel(writer, index=False, sheet_name='Dify对话记录')
            worksheet = writer.sheets['Dify对话记录']

            # Row height of 20 points for every row, header included.
            for row in worksheet.iter_rows():
                worksheet.row_dimensions[row[0].row].height = 20

            column_widths = {
                "msg_id": 15,
                "当前软件": 15,
                "提问": 40,
                "回答": 60,
                "提问人": 15,
                "提问时间": 15,
                "评价": 10,
                "问题分类": 20,
                "检索到的词条": 40,
                "备注": 40
            }
            for i, column in enumerate(columns_order):
                worksheet.column_dimensions[self._column_letter(i)].width = column_widths[column]

        print(f"结果已保存到 {output_file}")
        return output_file

    def _default_output_file(self):
        """Build <cwd>/data/excel/dify_export[_from_<start>][_to_<end>]_<now>.xlsx."""
        timestamp = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
        date_suffix = ""
        if self.start_date:
            date_suffix += f"_from_{self.start_date.strftime('%Y-%m-%d_%H')}"
        if self.end_date:
            date_suffix += f"_to_{self.end_date.strftime('%Y-%m-%d_%H')}"
        return os.path.join(os.getcwd(), "data", "excel", f"dify_export{date_suffix}_{timestamp}.xlsx")

    def export(self, output_file=None):
        """
        Run the full export: load query logs, collect rows, write Excel.

        Args:
            output_file: Destination .xlsx path. When empty/None, a timestamped
                name under <cwd>/data/excel/ is generated, but only if there are
                rows to write.

        Returns:
            The list of exported row dicts.

        Note:
            Only messages within start_date/end_date (if set) are exported.
            Database times are UTC+0 and are shifted to UTC+8 for filtering and
            display.
        """
        self.load_query_logs(self.query_log_file)
        for extra_path in self.extra_query_log_files:
            self.load_query_logs(extra_path)

        self.process_conversations()

        # Write when a destination was given explicitly, or when there is data.
        if output_file or self.message_info_list:
            if not output_file:
                output_file = self._default_output_file()
            self.save_to_excel(self.message_info_list, output_file)

        return self.message_info_list
|
||
|
||
|
||
# Example usage
if __name__ == "__main__":
    import argparse

    # Parse command-line arguments
    parser = argparse.ArgumentParser(description='Dify数据导出工具')
    parser.add_argument('--output', '-o', type=str, default="data/excel/dify_export.xlsx",
                        help='输出Excel文件路径')
    parser.add_argument('--app_id', '-a', type=str, default=None,
                        help='Dify应用ID')
    parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json",
                        help='查询日志文件路径')
    parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
                        help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
    parser.add_argument('--end_date', '-e', type=str, default=None,
                        help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')

    args = parser.parse_args()

    load_dotenv()
    # Fallback database connection settings. setdefault lets values from the
    # real environment or from .env (loaded just above) take precedence. The
    # previous unconditional assignments silently overrode them, which made
    # load_dotenv() pointless.
    # NOTE(review): a plaintext DB password is committed here; move it into .env.
    for env_key, env_value in (
        ("DIFY_PG_HOST", "10.1.16.39"),
        ("DIFY_PG_PORT", "5432"),
        ("DIFY_PG_USER", "postgres"),
        ("DIFY_PG_PASSWORD", "difyai123456"),
        ("DIFY_PG_DATABASE", "dify"),
    ):
        os.environ.setdefault(env_key, env_value)

    # Create the exporter
    exporter = DifyExporter(
        app_id=args.app_id,
        query_log_file=args.query_log_file,
        start_date=args.start_date,
        end_date=args.end_date
    )

    # Run the export
    results = exporter.export(output_file=args.output)

    # Report the result
    print(f"导出了 {len(results)} 条消息信息")
|
||
|