调整PGSQL相关代码逻辑

This commit is contained in:
2025-08-29 11:49:07 +08:00
parent 78dc1673aa
commit d69fe2b8d1
3 changed files with 261 additions and 260 deletions
+5 -5
View File
@@ -18,7 +18,7 @@ from langchain_core.output_parsers import JsonOutputParser
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.dify.dify_client import ChatClient from rag2_0.dify.dify_client import ChatClient
from rag2_0.tool.ModelTool import OpenAiLLM from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.dify.dify_tool import PgSql, DifyTool from rag2_0.dify.dify_tool import DifyTool
from rag2_0.dify.export_new_dify import DifyExporter from rag2_0.dify.export_new_dify import DifyExporter
load_dotenv() load_dotenv()
# 创建日志目录 # 创建日志目录
@@ -104,7 +104,7 @@ class DifyCompareTest:
answer = answer.split("----------------------------------------")[0].strip() answer = answer.split("----------------------------------------")[0].strip()
if len(answer) == 0: if len(answer) == 0:
raise Exception(f"回答为空: {result}") raise Exception(f"回答为空: {result}")
if old_answer: if isinstance(old_answer, str) and len(old_answer) > 0:
judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer) judge_result = self.llm_judge_answer(old_answer=old_answer, now_answer=answer)
else: else:
judge_result="" judge_result=""
@@ -149,8 +149,8 @@ class DifyCompareTest:
result_row["检索到的词条"] = '' result_row["检索到的词条"] = ''
return index, result_row return index, result_row
if "回答" in row: if "参考回答" in row:
old_answer = row["回答"] old_answer = row["参考回答"]
else: else:
old_answer = "" old_answer = ""
@@ -259,7 +259,7 @@ if __name__ == "__main__":
# 处理第一个文件 # 处理第一个文件
excel_files = [ excel_files = [
("data/excel/第一轮的专业问题.xlsx", "data/excel/第一轮的专业问题_dify.xlsx"), ("data/excel/第5轮-软件问题-点踩结果.xlsx", "data/excel/第5轮-软件问题-点踩结果_dify.xlsx"),
# ("data/excel/有知识的.xlsx", "data/excel/有知识的_问答测试.xlsx") # ("data/excel/有知识的.xlsx", "data/excel/有知识的_问答测试.xlsx")
] ]
+4 -2
View File
@@ -17,7 +17,10 @@ class ContentSource(BaseModel):
reason: str = Field(description="评分理由") reason: str = Field(description="评分理由")
class PgSql:
class DifyTool:
class PgSql:
""" """
用于连接和操作 PostgreSQL 数据库的类。 用于连接和操作 PostgreSQL 数据库的类。
@@ -268,7 +271,6 @@ class PgSql:
raise Exception(f"Error while getting conversation_messages: {error}") raise Exception(f"Error while getting conversation_messages: {error}")
return None return None
class DifyTool:
""" """
提供用于获取 Dify 应用调试信息的工具类。 提供用于获取 Dify 应用调试信息的工具类。
+4 -5
View File
@@ -7,7 +7,7 @@ import pandas as pd
import sys import sys
sys.path.append(os.getcwd()) sys.path.append(os.getcwd())
from rag2_0.dify.dify_tool import PgSql, DifyTool from rag2_0.dify.dify_tool import DifyTool
import requests import requests
@@ -38,7 +38,6 @@ class DifyExporter:
self.end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d %H") if end_date else None self.end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d %H") if end_date else None
# 初始化工具类 # 初始化工具类
self.dify_pgsql = PgSql()
self.dify_tool = DifyTool() self.dify_tool = DifyTool()
# 初始化数据存储 # 初始化数据存储
@@ -237,7 +236,7 @@ class DifyExporter:
else: else:
wiki_list = list(set(wiki_list)) wiki_list = list(set(wiki_list))
wiki_list_str = "\n".join(wiki_list) wiki_list_str = "\n".join(wiki_list)
rating = self.dify_pgsql.get_message_rating(msg_id) rating = self.dify_tool.get_message_rating(msg_id)
# 从HTTP服务获取query_type # 从HTTP服务获取query_type
workflow_run_id = message['workflow_run_id'] workflow_run_id = message['workflow_run_id']
@@ -266,9 +265,9 @@ class DifyExporter:
Returns: Returns:
处理后的消息信息列表 处理后的消息信息列表
""" """
conversations = self.dify_pgsql.get_app_conversations(appid=self.app_id) conversations = self.dify_tool.get_app_conversations(appid=self.app_id)
for conversation in conversations: for conversation in conversations:
messages = self.dify_pgsql.get_conversation_messages(conversation_id=conversation['conversation_id']) messages = self.dify_tool.get_conversation_messages(conversation_id=conversation['conversation_id'])
message_chain_new = self.process_message_chain(messages) message_chain_new = self.process_message_chain(messages)
if len(message_chain_new) != len(messages): if len(message_chain_new) != len(messages):
print(f"过滤了{len(messages) - len(message_chain_new)}条消息,会话ID{conversation['conversation_id']}") print(f"过滤了{len(messages) - len(message_chain_new)}条消息,会话ID{conversation['conversation_id']}")