#!/usr/bin/env python
# -*- coding: utf-8 -*-
import psycopg2
import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
import sys
sys.path.append(os.getcwd())
from rag2_0.dify.dify_client import ChatClient
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from threading import Lock


# Prompt used by BaseWorkflowChat.calculate_score. `{query}` and `{content}` are
# filled in with str.format; the doubled braces render as literal JSON braces.
_SCORE_PROMPT_TEMPLATE = """你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。

【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突

【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息

【输出格式】
{{
    "score": 评分,
    "reason": "简明扼要的评分理由(中文)"
}}

【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}

现在评估:
query: "{query}"
content: "{content}"
"""


class ContentSource(BaseModel):
    """Structured relevance verdict the LLM is asked to return in calculate_score."""

    score: int = Field(description="相关性分数")
    reason: str = Field(description="评分理由")


class PgSql:
    """Wrapper around one PostgreSQL connection to a Dify database.

    Opens and closes the connection and runs a set of read-only queries
    against Dify tables (apps, messages, workflow_node_executions,
    message_feedbacks, workflow_runs). Every use of the shared connection is
    serialised through ``pg_sql_lock`` so the instance can be shared by threads.
    """

    def __init__(self):
        """Create the instance and immediately open the database connection."""
        self.pg_sql_lock = Lock()
        self.connection = None
        self.connect_sql()

    def connect_sql(self):
        """Connect using the DIFY_PG_USER/PASSWORD/HOST/PORT/DATABASE env vars.

        The connection is switched to autocommit: every query in this class is
        a plain SELECT, and without autocommit psycopg2 keeps one transaction
        open for the connection's whole life and, after any failed statement,
        rejects every later query until a rollback that nothing here issues.

        Raises:
            Exception: wrapping the underlying error if connecting fails.
        """
        try:
            self.DIFY_PG_USER = os.getenv("DIFY_PG_USER")
            self.DIFY_PG_PASSWORD = os.getenv("DIFY_PG_PASSWORD")
            self.DIFY_PG_HOST = os.getenv("DIFY_PG_HOST")
            self.DIFY_PG_PORT = os.getenv("DIFY_PG_PORT")
            self.DIFY_PG_DATABASE = os.getenv("DIFY_PG_DATABASE")
            self.connection = psycopg2.connect(
                user=self.DIFY_PG_USER,
                password=self.DIFY_PG_PASSWORD,
                host=self.DIFY_PG_HOST,
                port=self.DIFY_PG_PORT,
                database=self.DIFY_PG_DATABASE,
            )
            self.connection.autocommit = True
        except Exception as error:
            raise Exception(f"Error while connecting to PostgreSQL: {error}") from error

    def close_connection(self):
        """Close the connection if one is open; safe to call more than once."""
        with self.pg_sql_lock:
            if self.connection:
                self.connection.close()
                self.connection = None

    def _query(self, sql: str, params: tuple, action: str, *, many: bool = False):
        """Run a parameterised SELECT and return the result as column-name dicts.

        Args:
            sql: Query text with %s placeholders.
            params: Values bound to the placeholders.
            action: Phrase used in the error message ("Error while <action>: ...").
            many: If True return all rows as a list, otherwise only the first row.

        Returns:
            A dict (many=False) or a non-empty list of dicts (many=True), or
            None when no row matched.

        Raises:
            Exception: wrapping any database error or a closed connection.
        """
        with self.pg_sql_lock:
            try:
                if self.connection is None:
                    raise RuntimeError("database connection is closed")
                with self.connection.cursor() as cursor:
                    cursor.execute(sql, params)
                    rows = cursor.fetchall() if many else cursor.fetchone()
                    if not rows:
                        return None
                    colnames = [desc[0] for desc in cursor.description]
            except Exception as error:
                raise Exception(f"Error while {action}: {error}") from error
        if many:
            return [dict(zip(colnames, row)) for row in rows]
        return dict(zip(colnames, rows))

    def get_appinfo(self, appid: str) -> dict | None:
        """Return the `apps` row with the given id as a dict, or None if absent."""
        return self._query(
            "SELECT * FROM apps WHERE id = %s",
            (appid,),
            "getting appinfo",
        )

    def get_messages_info(self, appid: str, query: str) -> dict | None:
        """Return the most recent `messages` row for this app and exact query text.

        Returns None when no message matches.
        """
        return self._query(
            """
            SELECT * FROM messages
            WHERE app_id = %s AND query = %s
            ORDER BY created_at DESC
            """,
            (appid, query),
            "getting messages_info",
        )

    def get_messages_info_by_id(self, message_id: str) -> dict | None:
        """Return the `messages` row with the given id, or None if absent."""
        return self._query(
            "SELECT * FROM messages WHERE id = %s",
            (message_id,),
            "getting messages_info by id",
        )

    def get_workflow_node_executions_info(self, workflow_run_id: str) -> list[dict] | None:
        """Return all `workflow_node_executions` rows of one workflow run.

        Returns:
            A list with one dict per node execution, or None if there are none.
        """
        return self._query(
            "SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s",
            (workflow_run_id,),
            "getting workflow_node_executions_info",
            many=True,
        )

    def get_app_conversations(self, appid: str) -> list[dict] | None:
        """Return the distinct conversation ids of an app's non-debugger messages.

        The ids are read from the `messages` table. Each element is a dict with
        a single "conversation_id" key. Returns None if there are none.
        """
        return self._query(
            """
            SELECT DISTINCT conversation_id FROM messages
            WHERE app_id = %s AND invoke_from != 'debugger';
            """,
            (appid,),
            "getting app_conversations",
            many=True,
        )

    def get_conversation_messages(self, conversation_id: str) -> list[dict] | None:
        """Return the `messages` rows of a conversation whose status is 'normal'."""
        return self._query(
            "SELECT * FROM messages WHERE conversation_id = %s AND status = 'normal'",
            (conversation_id,),
            "getting conversation_messages",
            many=True,
        )

    def get_message_rating(self, msg_id: str):
        """Return the `rating` column from `message_feedbacks` for a message.

        Args:
            msg_id: Message id (UUID).

        Returns:
            The rating value (documented as a string), or None if no feedback exists.
        """
        row = self._query(
            "SELECT rating FROM message_feedbacks WHERE message_id = %s",
            (msg_id,),
            "getting message_rating",
        )
        return row["rating"] if row else None

    def get_workflow_run_info(self, workflow_run_id: str) -> dict | None:
        """Return the `workflow_runs` row with the given id, or None if absent."""
        return self._query(
            "SELECT * FROM workflow_runs WHERE id = %s;",
            (workflow_run_id,),
            "getting workflow_run_info",
        )


class DifyTool:
    """Helpers that gather Dify debug information through a PgSql instance.

    Combines app, message and workflow-node-execution data for a given message
    or query.
    """

    def __init__(self):
        self.dify_pgsql = PgSql()

    def __del__(self):
        """Close the database connection when the object is garbage-collected.

        Exceptions are printed rather than raised, because exceptions must not
        escape a destructor. If __init__ failed, there is nothing to close.
        """
        pgsql = getattr(self, "dify_pgsql", None)
        if pgsql is None:
            return
        try:
            pgsql.close_connection()
        except Exception as e:
            print(f"关闭数据库连接时出错: {e}")

    def get_message_debug_info_by_id(self, message_id: str) -> dict | None:
        """Return a message and its workflow node executions, looked up by message id.

        Returns:
            None if the message does not exist. Otherwise a dict with keys
            "messages_info" and "workflow_node_executions_info". The second
            value is None when the run has no node executions.

        Raises:
            Exception: wrapping any lookup error.
        """
        try:
            messages_info = self.dify_pgsql.get_messages_info_by_id(message_id)
            if not messages_info:
                return None
            node_executions = self.dify_pgsql.get_workflow_node_executions_info(
                messages_info["workflow_run_id"]
            )
        except Exception as e:
            raise Exception(f"Error in get_message_debug_info_by_id: {e}") from e
        return {
            "messages_info": messages_info,
            "workflow_node_executions_info": node_executions or None,
        }

    def get_message_debug_info_by_query(self, appid: str, query: str) -> dict | None:
        """Return the debug information for the latest message of an app with this query.

        Args:
            appid: Target app id.
            query: Exact user query text.

        Returns:
            A dict with "appinfo", "messages_info" and
            "workflow_node_executions_info". Returns None if the app, the
            message or its node executions cannot be found.

        Raises:
            Exception: wrapping any lookup error.
        """
        try:
            appinfo = self.dify_pgsql.get_appinfo(appid)
            if not appinfo:
                return None
            messages_info = self.dify_pgsql.get_messages_info(appid, query)
            if not messages_info:
                return None
            node_executions = self.dify_pgsql.get_workflow_node_executions_info(
                messages_info["workflow_run_id"]
            )
            if not node_executions:
                return None
        except Exception as e:
            raise Exception(f"Error in get_message_debug_info_by_query: {e}") from e
        return {
            "appinfo": appinfo,
            "messages_info": messages_info,
            "workflow_node_executions_info": node_executions,
        }

    def get_app_conversations(self, appid: str) -> list[dict] | None:
        """Return the app's distinct non-debugger conversation ids (see PgSql)."""
        return self.dify_pgsql.get_app_conversations(appid)

    def get_conversation_messages(self, conversation_id: str) -> list[dict] | None:
        """Return the 'normal' messages of a conversation."""
        return self.dify_pgsql.get_conversation_messages(conversation_id)

    def get_message_rating(self, msg_id):
        """Return the feedback rating of a message, or None."""
        return self.dify_pgsql.get_message_rating(msg_id)

    def get_workflow_run_info(self, workflow_run_id):
        """Return the workflow_runs row for the given run id, or None."""
        return self.dify_pgsql.get_workflow_run_info(workflow_run_id)


class BaseWorkflowChat:
    """Base class for sending chat messages to a Dify app and analysing the run."""

    def __init__(self, api_key: str, base_url: str):
        """Build the Dify chat client, the relevance-score parser and a DifyTool.

        Args:
            api_key: Dify API key.
            base_url: Base URL of the Dify API.
        """
        self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
        self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
        self.dify_tool = DifyTool()

    def create_chat_message(self, query: str):
        """Send `query` to the Dify app.

        Returns:
            tuple: (decoded JSON response, message id taken from "message_id").
        """
        response = self.chat_client.create_chat_message(
            inputs={}, query=query, user="AutoTestDifyChat"
        ).json()
        return response, response["message_id"]

    def calculate_score(self, query: str, content: str) -> int:
        """Ask an LLM to rate how relevant `content` is to `query`.

        The model, key and endpoint are read from LLM_MODEL_NAME, OPENAI_API_KEY
        and OPENAI_API_BASE.

        Returns:
            int: the parsed score (the prompt asks for 1-10, where 10 means fully
            relevant). Returns -1 if the call or the parsing fails. This is a
            deliberate best-effort so one bad segment does not abort a batch.
        """
        from rag2_0.tool.ModelTool import OpenAiLLM

        prompt = _SCORE_PROMPT_TEMPLATE.format(query=query, content=content)
        try:
            llm = OpenAiLLM(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_API_BASE"),
                model=os.getenv("LLM_MODEL_NAME"),
            )
            response = llm.invoke(user_prompt=prompt, need_retry=True)
            return self.content_source_parser.parse(response.content).score
        except Exception:
            return -1

    def get_retrieve_info(self, query: str, outputs: list[dict],
                          reranker_sorce_info: list | None = None) -> tuple:
        """Score every retrieved segment with the LLM and summarise the results.

        Args:
            query: The user question.
            outputs: Retrieved segments. Each item is read for the keys
                "segment_id", "segment_content" and "title".
            reranker_sorce_info: Optional list of {"segment_id", "score"} dicts
                holding reranker scores. Missing scores are shown as "无".

        Returns:
            tuple: (list of formatted titles, max score, min score, average score).
            The titles are in `outputs` order. Segments whose LLM scoring failed
            are skipped. All three numbers are 0 when no segment could be scored.
        """
        # First occurrence of a segment_id wins when the same id appears twice.
        rerank_scores = {}
        for info in reranker_sorce_info or []:
            rerank_scores.setdefault(info["segment_id"], info["score"])

        segments = [
            (
                item["segment_id"].strip(),
                item["segment_content"].strip(),
                item["title"].split("/")[-1],
            )
            for item in outputs
        ]

        scores = []
        retrieve_title = []
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.calculate_score, query=query, content=content)
                for _, content, _ in segments
            ]
            # Consume the futures in submission order so the title list is
            # deterministic and follows retrieval order; the scoring still
            # runs concurrently.
            for (segment_id, _, title), future in zip(segments, futures):
                score = future.result()
                if score == -1:
                    continue
                scores.append(score)
                if title:
                    rerank = rerank_scores.get(segment_id)
                    rerank_text = f"{rerank:.2f}分" if rerank is not None else "无"
                    retrieve_title.append(f"{title}--LLM得分({score}分)--重排得分({rerank_text})")

        if not scores:
            return retrieve_title, 0, 0, 0
        return retrieve_title, max(scores), min(scores), sum(scores) / len(scores)


class NewWorkflowChat(BaseWorkflowChat):
    """Runs questions through the new workflow and extracts its debug data."""

    def process_question(self, query: str) -> dict | None:
        """Ask the new workflow `query` and collect its answer and run details.

        Returns:
            A dict with the question, answer, rewrite, classification, slot info,
            retrieved titles and message id. Returns None if no workflow info
            could be found.

        Raises:
            RuntimeError: if the chat call returned an "error:" string.
        """
        response, message_id = self.create_chat_message(query)
        if isinstance(response, str) and response.startswith("error:"):
            raise RuntimeError(f"create_chat_message 出错:{response}")
        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None
        return {
            "问题": query,
            "新流程答案": answer,
            "新问题改写": workflow_info["问题改写"],
            "新问题分类": workflow_info["问题分类"],
            "槽点信息": workflow_info["槽点信息"],
            "新检索词条": workflow_info["检索词条"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict | None:
        """Extract the rewrite, classification, slots and retrieval info of a run.

        Args:
            query: The user question (used for LLM relevance scoring).
            message_id: Id of the message produced by the new workflow.

        Returns:
            A dict with "问题改写", "检索词条", "问题分类" and "槽点信息".
            Returns None if the message row does not exist.
        """
        message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
        if not message_info:
            return None
        nodes = message_info["workflow_node_executions_info"] or []

        # Pass 1: the reranker scores come from the first knowledge node. They
        # must be known before get_retrieve_info formats the titles.
        reranker_sorce = []
        for workflow_node in nodes:
            if workflow_node["title"] == "提取处理后的知识":
                source_kno = json.loads(workflow_node["outputs"])["source_kno"]
                reranker_sorce = [
                    {"score": item["metadata"]["score"], "segment_id": item["metadata"]["segment_id"]}
                    for item in source_kno
                ]
                break

        retrieve_title = []
        rewrite_query = ""
        vertical_classification = ""
        sub_classification = ""
        slot_info = ""
        for workflow_node in nodes:
            title = workflow_node["title"]
            if title == "提取处理后的知识":
                knowledge_list = json.loads(workflow_node["outputs"])["knowledge_list"]
                retrieve_title, _, _, _ = self.get_retrieve_info(
                    query=query, outputs=knowledge_list, reranker_sorce_info=reranker_sorce
                )
            elif title == "意图识别结果解析":
                rewrite_query = json.loads(workflow_node["outputs"])["optimize_query"]
                # inputs["llm_result"] is itself a JSON string, hence the double decode.
                llm_result = json.loads(json.loads(workflow_node["inputs"])["llm_result"])
                vertical_classification = llm_result["vertical_classification"]
                sub_classification = llm_result["sub_classification"]
                slot_info = json.dumps(llm_result["slot_filling"], ensure_ascii=False, indent=2)

        if not reranker_sorce:
            retrieve_content = "未检索知识库"
        elif not retrieve_title:
            retrieve_content = "知识与提问不相关,被丢弃"
        else:
            retrieve_content = "\n".join(retrieve_title)

        return {
            "问题改写": rewrite_query,
            "检索词条": retrieve_content,
            "问题分类": f"{vertical_classification} - {sub_classification}",
            "槽点信息": slot_info,
        }


class OldWorkFlowChat(BaseWorkflowChat):
    """Runs questions through the old workflow and extracts its debug data."""

    def process_question(self, query: str) -> dict | None:
        """Ask the old workflow `query` and collect its answer and run details.

        Returns:
            A dict with the question, answer, rewrite, retrieved titles and
            message id. Returns None on an "error:" response or when no
            workflow info is available.
        """
        response, message_id = self.create_chat_message(query)
        if isinstance(response, str) and response.startswith("error:"):
            return None
        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None
        return {
            "问题": query,
            "旧流程答案": answer,
            "旧问题改写": workflow_info["问题改写"],
            "旧检索词条": workflow_info["检索词条"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict | None:
        """Extract the query rewrite and retrieval titles of an old-workflow run.

        Returns:
            A dict with "问题改写" and "检索词条". Returns None if anything fails
            while reading or parsing the run. This is a deliberate best-effort.
        """
        retrieve_title = []
        rewrite_query = ""
        try:
            message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
            for workflow_node in message_info["workflow_node_executions_info"]:
                title = workflow_node["title"]
                if title == "知识检索结果后处理":
                    outputs = json.loads(workflow_node["outputs"])
                    # NOTE(review): the decoded outputs were also read as
                    # outputs["result"], which suggests a dict. get_retrieve_info
                    # expects a list of {"segment_id", "segment_content", "title"}
                    # dicts. Confirm the node's schema; a dict here ends in the
                    # except below.
                    retrieve_title, _, _, _ = self.get_retrieve_info(query=query, outputs=outputs)
                elif title == "问题优化结果解析":
                    rewrite_query = json.loads(workflow_node["outputs"])["optimize_query"]
        except Exception:
            return None
        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
        }


if __name__ == "__main__":
    pass