优化DifyCompareTest类,添加DifyExporter实例以支持词条检索,更新DifyQueryRetrieval_api.py中的topk参数,增强DifyExporter类以从HTTP服务获取查询类型和点踩原因,简化构造函数,移除不必要的查询日志加载逻辑。
This commit is contained in:
@@ -18,7 +18,8 @@ from langchain_core.output_parsers import JsonOutputParser
|
|||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.dify.dify_client import ChatClient
|
from rag2_0.dify.dify_client import ChatClient
|
||||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||||
|
from rag2_0.dify.dify_tool import PgSql, DifyTool
|
||||||
|
from rag2_0.dify.export_new_dify import DifyExporter
|
||||||
load_dotenv()
|
load_dotenv()
|
||||||
# 创建日志目录
|
# 创建日志目录
|
||||||
log_dir = 'data/logs'
|
log_dir = 'data/logs'
|
||||||
@@ -45,6 +46,7 @@ class DifyCompareTest:
|
|||||||
# 词条与工单同时检索
|
# 词条与工单同时检索
|
||||||
self.both_wiki_worker_client = ChatClient(api_key=os.getenv("DIFY_APP_KEY"), base_url=os.getenv("DIFY_BSAE_URL"))
|
self.both_wiki_worker_client = ChatClient(api_key=os.getenv("DIFY_APP_KEY"), base_url=os.getenv("DIFY_BSAE_URL"))
|
||||||
self.llm = OpenAiLLM(base_url=os.getenv("OPENAI_API_BASE"), model=os.getenv("MODEL_NAME"))
|
self.llm = OpenAiLLM(base_url=os.getenv("OPENAI_API_BASE"), model=os.getenv("MODEL_NAME"))
|
||||||
|
self.exporter = DifyExporter()
|
||||||
|
|
||||||
def llm_judge_answer(self, old_answer: str, now_answer: str):
|
def llm_judge_answer(self, old_answer: str, now_answer: str):
|
||||||
user_prompt = f"""
|
user_prompt = f"""
|
||||||
@@ -100,10 +102,11 @@ class DifyCompareTest:
|
|||||||
answer = result.get('answer', "")
|
answer = result.get('answer', "")
|
||||||
if len(answer) == 0:
|
if len(answer) == 0:
|
||||||
raise Exception(f"回答为空: {result}")
|
raise Exception(f"回答为空: {result}")
|
||||||
if old_answer:
|
# if old_answer:
|
||||||
judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer)
|
# judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer)
|
||||||
else:
|
# else:
|
||||||
judge_result=""
|
# judge_result=""
|
||||||
|
judge_result=""
|
||||||
# 只取回答的前半部分
|
# 只取回答的前半部分
|
||||||
answer = answer.split("----------------------------------------")[0].strip()
|
answer = answer.split("----------------------------------------")[0].strip()
|
||||||
message_id = result.get('message_id', "")
|
message_id = result.get('message_id', "")
|
||||||
@@ -117,6 +120,18 @@ class DifyCompareTest:
|
|||||||
import time
|
import time
|
||||||
time.sleep(10) # 等待1秒后重试
|
time.sleep(10) # 等待1秒后重试
|
||||||
|
|
||||||
|
def get_wiki_list_by_msgid(self,msg_id):
|
||||||
|
if msg_id is None or pd.isna(msg_id):
|
||||||
|
return ""
|
||||||
|
msg_debug_info = self.exporter.dify_tool.get_message_debug_info_by_id(msg_id)
|
||||||
|
if not msg_debug_info:
|
||||||
|
return ""
|
||||||
|
wiki_list = self.exporter.get_wiki_list(msg_debug_info)
|
||||||
|
if len(wiki_list) == 0:
|
||||||
|
return ""
|
||||||
|
else:
|
||||||
|
return "\n".join(list(set(wiki_list)))
|
||||||
|
|
||||||
def process_single_row(self, index, row):
|
def process_single_row(self, index, row):
|
||||||
"""处理单行数据的方法"""
|
"""处理单行数据的方法"""
|
||||||
try:
|
try:
|
||||||
@@ -145,6 +160,7 @@ class DifyCompareTest:
|
|||||||
result_row["message_id"] = message_id
|
result_row["message_id"] = message_id
|
||||||
result_row["回答"] = answer
|
result_row["回答"] = answer
|
||||||
# result_row["词条与工单同时回答对比"] = judge_result
|
# result_row["词条与工单同时回答对比"] = judge_result
|
||||||
|
result_row["检索到的词条"] = self.get_wiki_list_by_msgid(message_id)
|
||||||
logging.info(f"成功处理第 {index + 1} 行数据")
|
logging.info(f"成功处理第 {index + 1} 行数据")
|
||||||
return index, result_row
|
return index, result_row
|
||||||
|
|
||||||
@@ -152,6 +168,7 @@ class DifyCompareTest:
|
|||||||
logging.error(f"处理第 {index + 1} 行数据时出错: {e}")
|
logging.error(f"处理第 {index + 1} 行数据时出错: {e}")
|
||||||
result_row = row.copy()
|
result_row = row.copy()
|
||||||
result_row["回答"] = ''
|
result_row["回答"] = ''
|
||||||
|
result_row["检索到的词条"] = ''
|
||||||
result_row["message_id"] = ''
|
result_row["message_id"] = ''
|
||||||
return index, result_row
|
return index, result_row
|
||||||
|
|
||||||
@@ -230,7 +247,7 @@ if __name__ == "__main__":
|
|||||||
# 处理第一个文件
|
# 处理第一个文件
|
||||||
excel_files = [
|
excel_files = [
|
||||||
# ("data/excel/5月.xlsx", "data/excel/5月问答对比.xlsx"),
|
# ("data/excel/5月.xlsx", "data/excel/5月问答对比.xlsx"),
|
||||||
("data/excel/第四轮问题-Part2.xlsx", "data/excel/第四轮问题-Part2-问答测试.xlsx")
|
("data/excel/7.30数据导出.xlsx", "data/excel/7.30数据导出_问答测试.xlsx")
|
||||||
]
|
]
|
||||||
|
|
||||||
for excel_path, save_path in excel_files:
|
for excel_path, save_path in excel_files:
|
||||||
|
|||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field, ConfigDict
|
||||||
from typing import Dict, List, Any, Optional
|
from typing import Dict, List, Any, Optional
|
||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
@@ -43,6 +43,7 @@ class RetrieveRequest(BaseModel):
|
|||||||
query_list: str
|
query_list: str
|
||||||
data_set_list: str
|
data_set_list: str
|
||||||
query_expand_dict: dict | str = Field(default="{}")
|
query_expand_dict: dict | str = Field(default="{}")
|
||||||
|
topk: int = Field(default=4)
|
||||||
|
|
||||||
# 创建FastAPI应用
|
# 创建FastAPI应用
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
@@ -102,7 +103,7 @@ async def retrieve(request: RetrieveRequest):
|
|||||||
query_list,
|
query_list,
|
||||||
data_set_list,
|
data_set_list,
|
||||||
query_expand_dict=query_expand_dict,
|
query_expand_dict=query_expand_dict,
|
||||||
top_k=4
|
top_k=request.topk
|
||||||
)
|
)
|
||||||
end_time = time.time()
|
end_time = time.time()
|
||||||
|
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import pandas as pd
|
|||||||
import sys
|
import sys
|
||||||
sys.path.append(os.getcwd())
|
sys.path.append(os.getcwd())
|
||||||
from rag2_0.dify.dify_tool import PgSql, DifyTool
|
from rag2_0.dify.dify_tool import PgSql, DifyTool
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
class DifyExporter:
|
class DifyExporter:
|
||||||
@@ -16,13 +17,12 @@ class DifyExporter:
|
|||||||
|
|
||||||
支持按日期范围过滤消息,可以指定开始日期和结束日期
|
支持按日期范围过滤消息,可以指定开始日期和结束日期
|
||||||
"""
|
"""
|
||||||
def __init__(self, app_id=None, query_log_file=None, start_date=None, end_date=None):
|
def __init__(self, app_id=None, start_date=None, end_date=None):
|
||||||
"""
|
"""
|
||||||
初始化DifyExporter实例
|
初始化DifyExporter实例
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app_id: Dify应用ID,默认为None
|
app_id: Dify应用ID,默认为None
|
||||||
query_log_file: 查询日志文件路径,默认为None
|
|
||||||
start_date: 开始日期时间,格式为YYYY-MM-DD HH,默认为None(不限制开始日期)
|
start_date: 开始日期时间,格式为YYYY-MM-DD HH,默认为None(不限制开始日期)
|
||||||
end_date: 结束日期时间,格式为YYYY-MM-DD HH,默认为None(不限制结束日期)
|
end_date: 结束日期时间,格式为YYYY-MM-DD HH,默认为None(不限制结束日期)
|
||||||
|
|
||||||
@@ -33,10 +33,6 @@ class DifyExporter:
|
|||||||
# 设置默认值
|
# 设置默认值
|
||||||
self.app_id = app_id or "72d03c7d-8bea-42f9-9e8d-cdfb9480f372"
|
self.app_id = app_id or "72d03c7d-8bea-42f9-9e8d-cdfb9480f372"
|
||||||
|
|
||||||
# 设置查询日志文件路径
|
|
||||||
self.query_log_dir = os.path.join(os.getcwd(), "data", "query_logs")
|
|
||||||
self.query_log_file = query_log_file or os.path.join(self.query_log_dir, "answer_type_logs.json")
|
|
||||||
|
|
||||||
# 设置日期过滤,转换为datetime对象
|
# 设置日期过滤,转换为datetime对象
|
||||||
self.start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d %H") if start_date else None
|
self.start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d %H") if start_date else None
|
||||||
self.end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d %H") if end_date else None
|
self.end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d %H") if end_date else None
|
||||||
@@ -47,28 +43,9 @@ class DifyExporter:
|
|||||||
|
|
||||||
# 初始化数据存储
|
# 初始化数据存储
|
||||||
self.message_info_list = []
|
self.message_info_list = []
|
||||||
self.query_logs = {}
|
|
||||||
|
# 设置AnswerType服务地址
|
||||||
def load_query_logs(self,path):
|
self.answer_type_url = f"http://10.1.16.39:8003"
|
||||||
"""
|
|
||||||
从文件加载查询日志
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
with open(path, 'r', encoding='utf-8') as f:
|
|
||||||
query_logs_list = json.load(f)
|
|
||||||
# 创建字典来存储每个查询的最新记录workflow_run_id
|
|
||||||
for record in query_logs_list:
|
|
||||||
workflow_run_id = record['workflow_run_id']
|
|
||||||
timestamp = record.get('timestamp')
|
|
||||||
# 如果查询不在字典中或者当前记录的时间戳更新,则更新字典
|
|
||||||
if workflow_run_id not in self.query_logs or (timestamp and self.query_logs.get(workflow_run_id, {}).get('timestamp') and
|
|
||||||
datetime.datetime.fromisoformat(timestamp) >
|
|
||||||
datetime.datetime.fromisoformat(self.query_logs[workflow_run_id]['timestamp'])):
|
|
||||||
self.query_logs[workflow_run_id] = record
|
|
||||||
return True
|
|
||||||
except Exception as e:
|
|
||||||
print(f"加载查询日志失败: {e}")
|
|
||||||
return False
|
|
||||||
|
|
||||||
def process_message_chain(self, messages):
|
def process_message_chain(self, messages):
|
||||||
"""
|
"""
|
||||||
@@ -150,18 +127,25 @@ class DifyExporter:
|
|||||||
if node_execution is not None:
|
if node_execution is not None:
|
||||||
if node_execution["outputs"] is None:
|
if node_execution["outputs"] is None:
|
||||||
return []
|
return []
|
||||||
source_kno = json.loads(node_execution["outputs"])["source_kno"]
|
outputs = json.loads(node_execution["outputs"])
|
||||||
knowledge_list_metadata = json.loads(node_execution["outputs"])["knowledge_list_metadata"]
|
source_kno = outputs["source_kno"]
|
||||||
|
knowledge_list_metadata = outputs["knowledge_list_metadata"]
|
||||||
for knowledge in knowledge_list_metadata:
|
for knowledge in knowledge_list_metadata:
|
||||||
document_name = knowledge['metadata']['document_name']
|
document_name = knowledge['metadata']['document_name']
|
||||||
wiki_list.append(document_name.split("/")[-1])
|
doc_metadata = knowledge['metadata']['doc_metadata']
|
||||||
|
if doc_metadata is None or doc_metadata.get("workorder_time", None) is not None:
|
||||||
|
wiki_list.append(document_name.split("/")[-1])
|
||||||
|
else:
|
||||||
|
dataset_name = knowledge['metadata']['dataset_name']
|
||||||
|
wiki_list.append(f"{dataset_name} - {document_name.split('/')[-1]}")
|
||||||
return wiki_list
|
return wiki_list
|
||||||
|
|
||||||
lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识")
|
lock_node_execution = self.get_node_info_by_title(msg_debug_info['workflow_node_executions_info'], "软件锁知识")
|
||||||
if lock_node_execution is not None:
|
if lock_node_execution is not None:
|
||||||
if lock_node_execution["outputs"] is None:
|
if lock_node_execution["outputs"] is None:
|
||||||
return []
|
return []
|
||||||
source_kno = json.loads(lock_node_execution["outputs"])['json'][0]['retrieve_result']
|
outputs = json.loads(lock_node_execution["outputs"])
|
||||||
|
source_kno = outputs['json'][0]['retrieve_result']
|
||||||
for knowledge in source_kno:
|
for knowledge in source_kno:
|
||||||
document_name = knowledge['metadata']['document_name']
|
document_name = knowledge['metadata']['document_name']
|
||||||
wiki_list.append(document_name.split("/")[-1])
|
wiki_list.append(document_name.split("/")[-1])
|
||||||
@@ -172,6 +156,50 @@ class DifyExporter:
|
|||||||
|
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
def get_query_type_from_service(self, workflow_run_id):
|
||||||
|
"""
|
||||||
|
从HTTP服务获取查询类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: 工作流运行ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
查询类型字符串,如果获取失败则返回空字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = f"{self.answer_type_url}/query_by_workflow_id?workflow_run_id={workflow_run_id}"
|
||||||
|
response = requests.get(url, timeout=2)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if data.get("data") and len(data["data"]) > 0:
|
||||||
|
return data["data"][0]["query_type"]
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取查询类型时出错: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def get_dislike_reason_from_service(self, workflow_run_id):
|
||||||
|
"""
|
||||||
|
从HTTP服务获取查询类型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
workflow_run_id: 工作流运行ID
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
查询类型字符串,如果获取失败则返回空字符串
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
url = f"{self.answer_type_url}/dislike_by_workflow_id?workflow_run_id={workflow_run_id}"
|
||||||
|
response = requests.get(url, timeout=2)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
if data.get("data") and len(data["data"]) > 0:
|
||||||
|
return data["data"][0]["dislike_reason"]
|
||||||
|
return ""
|
||||||
|
except Exception as e:
|
||||||
|
print(f"获取查询类型时出错: {e}")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
def extract_message_info(self, message):
|
def extract_message_info(self, message):
|
||||||
"""
|
"""
|
||||||
@@ -210,10 +238,11 @@ class DifyExporter:
|
|||||||
wiki_list = list(set(wiki_list))
|
wiki_list = list(set(wiki_list))
|
||||||
wiki_list_str = "\n".join(wiki_list)
|
wiki_list_str = "\n".join(wiki_list)
|
||||||
rating = self.dify_pgsql.get_message_rating(msg_id)
|
rating = self.dify_pgsql.get_message_rating(msg_id)
|
||||||
# 直接通过字典键获取query_type
|
|
||||||
|
# 从HTTP服务获取query_type
|
||||||
workflow_run_id = message['workflow_run_id']
|
workflow_run_id = message['workflow_run_id']
|
||||||
query_type = self.query_logs.get(workflow_run_id, {}).get('query_type', "")
|
query_type = self.get_query_type_from_service(workflow_run_id)
|
||||||
|
dislike_reason = self.get_dislike_reason_from_service(workflow_run_id)
|
||||||
return {
|
return {
|
||||||
"msg_id": msg_id,
|
"msg_id": msg_id,
|
||||||
"提问": msg_query,
|
"提问": msg_query,
|
||||||
@@ -224,6 +253,7 @@ class DifyExporter:
|
|||||||
"评价": rating,
|
"评价": rating,
|
||||||
"问题分类": query_type,
|
"问题分类": query_type,
|
||||||
"检索到的词条": wiki_list_str,
|
"检索到的词条": wiki_list_str,
|
||||||
|
"点踩原因": dislike_reason
|
||||||
}
|
}
|
||||||
|
|
||||||
def process_conversations(self):
|
def process_conversations(self):
|
||||||
@@ -277,7 +307,7 @@ class DifyExporter:
|
|||||||
# 设置列的顺序
|
# 设置列的顺序
|
||||||
columns_order = [
|
columns_order = [
|
||||||
"msg_id","当前软件", "提问", "回答", "提问人", "提问时间",
|
"msg_id","当前软件", "提问", "回答", "提问人", "提问时间",
|
||||||
"评价", "问题分类", "检索到的词条"
|
"评价", "问题分类", "检索到的词条", "点踩原因"
|
||||||
]
|
]
|
||||||
|
|
||||||
# 确保所有列都存在,如果不存在则添加空列
|
# 确保所有列都存在,如果不存在则添加空列
|
||||||
@@ -315,7 +345,7 @@ class DifyExporter:
|
|||||||
"评价": 10,
|
"评价": 10,
|
||||||
"问题分类": 20,
|
"问题分类": 20,
|
||||||
"检索到的词条": 40,
|
"检索到的词条": 40,
|
||||||
"备注": 40
|
"点踩原因": 20
|
||||||
}
|
}
|
||||||
|
|
||||||
# 应用列宽设置
|
# 应用列宽设置
|
||||||
@@ -343,10 +373,6 @@ class DifyExporter:
|
|||||||
如果在初始化时指定了start_date或end_date,则只会导出指定日期时间范围内的消息
|
如果在初始化时指定了start_date或end_date,则只会导出指定日期时间范围内的消息
|
||||||
数据库中的时间是UTC+0时区,会自动转换为UTC+8时区进行过滤和显示
|
数据库中的时间是UTC+0时区,会自动转换为UTC+8时区进行过滤和显示
|
||||||
"""
|
"""
|
||||||
# 加载查询日志
|
|
||||||
self.load_query_logs(self.query_log_file)
|
|
||||||
self.load_query_logs("data/query_logs/answer_type_logs_071409.json")
|
|
||||||
|
|
||||||
# 处理会话数据
|
# 处理会话数据
|
||||||
self.process_conversations()
|
self.process_conversations()
|
||||||
|
|
||||||
@@ -381,11 +407,9 @@ if __name__ == "__main__":
|
|||||||
parser = argparse.ArgumentParser(description='Dify数据导出工具')
|
parser = argparse.ArgumentParser(description='Dify数据导出工具')
|
||||||
parser.add_argument('--output', '-o', type=str, default="data/excel/dify_export.xlsx",
|
parser.add_argument('--output', '-o', type=str, default="data/excel/dify_export.xlsx",
|
||||||
help='输出Excel文件路径')
|
help='输出Excel文件路径')
|
||||||
parser.add_argument('--app_id', '-a', type=str, default=None,
|
parser.add_argument('--app_id', '-a', type=str, default="6218c4fd-bba3-4f5b-9fb5-61585d8eee51",
|
||||||
help='Dify应用ID')
|
help='Dify应用ID')
|
||||||
parser.add_argument('--query_log_file', '-q', type=str, default="data/query_logs/answer_type_logs.json",
|
parser.add_argument('--start_date', '-s', type=str, default="2025-07-30 00",
|
||||||
help='查询日志文件路径')
|
|
||||||
parser.add_argument('--start_date', '-s', type=str, default="2025-07-14 00",
|
|
||||||
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
help='开始日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 14表示2025年7月8日14时(UTC+8时区)')
|
||||||
parser.add_argument('--end_date', '-e', type=str, default=None,
|
parser.add_argument('--end_date', '-e', type=str, default=None,
|
||||||
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
help='结束日期时间,格式为YYYY-MM-DD HH,例如2025-07-08 18表示2025年7月8日18时(UTC+8时区)')
|
||||||
@@ -403,7 +427,6 @@ if __name__ == "__main__":
|
|||||||
# 创建导出器实例
|
# 创建导出器实例
|
||||||
exporter = DifyExporter(
|
exporter = DifyExporter(
|
||||||
app_id=args.app_id,
|
app_id=args.app_id,
|
||||||
query_log_file=args.query_log_file,
|
|
||||||
start_date=args.start_date,
|
start_date=args.start_date,
|
||||||
end_date=args.end_date
|
end_date=args.end_date
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user