From c28fe97dcccac66d5d1ecd04da094c28e43cc8d9 Mon Sep 17 00:00:00 2001 From: ouyangyouzhang Date: Fri, 20 Jun 2025 16:32:51 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=94=B9=E9=A1=B9=E7=9B=AE=E5=90=8D?= =?UTF-8?q?=E7=A7=B0=E4=B8=BA"rag2=5F0"=EF=BC=8C=E6=9B=B4=E6=96=B0?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E4=BE=9D=E8=B5=96=E9=A1=B9=E5=90=8D=E7=A7=B0?= =?UTF-8?q?=EF=BC=8C=E5=A2=9E=E5=BC=BAPgSql=E7=B1=BB=E7=9A=84=E7=BA=BF?= =?UTF-8?q?=E7=A8=8B=E5=AE=89=E5=85=A8=E6=80=A7=EF=BC=8C=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E9=94=81=E6=9C=BA=E5=88=B6=E4=BB=A5=E7=A1=AE=E4=BF=9D=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5=E7=9A=84=E5=AE=89=E5=85=A8?= =?UTF-8?q?=E5=85=B3=E9=97=AD=EF=BC=8C=E5=90=8C=E6=97=B6=E4=BC=98=E5=8C=96?= =?UTF-8?q?DifyTool=E7=B1=BB=E7=9A=84=E8=B5=84=E6=BA=90=E7=AE=A1=E7=90=86?= =?UTF-8?q?=EF=BC=8C=E7=A1=AE=E4=BF=9D=E5=9C=A8=E5=AF=B9=E8=B1=A1=E9=94=80?= =?UTF-8?q?=E6=AF=81=E6=97=B6=E8=87=AA=E5=8A=A8=E5=85=B3=E9=97=AD=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=BA=93=E8=BF=9E=E6=8E=A5=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 2 +- rag2_0/dify/dify_client/__init__.py | 5 +- rag2_0/dify/dify_tool.py | 189 ++++++++++++++++------------ rag2_0/dify/test_dify_chatapi.py | 12 +- uv.lock | 2 +- 5 files changed, 123 insertions(+), 87 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 01f72de..b0bd223 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [project] -name = "queryrewrite" +name = "rag2_0" version = "0.1.0" description = "Add your description here" readme = "README.md" diff --git a/rag2_0/dify/dify_client/__init__.py b/rag2_0/dify/dify_client/__init__.py index 6fa9d19..7e4c8c2 100755 --- a/rag2_0/dify/dify_client/__init__.py +++ b/rag2_0/dify/dify_client/__init__.py @@ -1 +1,4 @@ -from dify_client.client import ChatClient, CompletionClient, DifyClient + +__all__ = ["ChatClient", "CompletionClient", "DifyClient"] + +from .client import ChatClient, CompletionClient, DifyClient diff --git a/rag2_0/dify/dify_tool.py b/rag2_0/dify/dify_tool.py index 009e33d..4afe3ce 100755 --- a/rag2_0/dify/dify_tool.py +++ b/rag2_0/dify/dify_tool.py @@ -7,12 +7,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed from rag2_0.dify.dify_client import ChatClient from pydantic import BaseModel, Field from langchain.output_parsers import PydanticOutputParser - +from threading import Lock class ContentSource(BaseModel): score: int = Field(description="相关性分数") reason: str = Field(description="评分理由") + class PgSql: """ 用于连接和操作 PostgreSQL 数据库的类。 @@ -21,6 +22,7 @@ class PgSql: 主要用于从 Dify 应用相关的表中获取数据。 """ def __init__(self): + self.pg_sql_lock = Lock() """ 初始化 PgSql 实例并建立数据库连接。 """ @@ -53,8 +55,10 @@ class PgSql: 如果存在活动的连接,则关闭它。 """ - if self.connection: - self.connection.close() + with self.pg_sql_lock: + if self.connection: + self.connection.close() + self.connection = None def get_appinfo(self, appid:str)->dict | None: @@ -68,21 +72,22 @@ class PgSql: 一个字典,其中键是列名,值是对应的应用数据。 如果未找到应用或发生错误,则返回 None。 """ - try: - with self.connection.cursor() as cursor: - cursor.execute( - """ - SELECT * FROM apps WHERE id = %s - """, - (appid,) - ) - result = cursor.fetchone() - if result: - colnames = [desc[0] for desc in cursor.description] - return dict(zip(colnames, result)) - return None - except (Exception, psycopg2.Error) as error: - raise Exception(f"Error while getting tenant_id by appid: {error}") + with self.pg_sql_lock: + try: + with self.connection.cursor() as cursor: + cursor.execute( + """ + SELECT * FROM apps WHERE id = %s + """, + (appid,) + ) + result = cursor.fetchone() + if result: + colnames = [desc[0] for desc in cursor.description] + return dict(zip(colnames, result)) + return None + except (Exception, psycopg2.Error) as error: + raise Exception(f"Error while getting tenant_id by appid: {error}") def get_messages_info(self, appid:str, query:str)->dict | None: @@ -97,41 +102,43 @@ class PgSql: 一个字典,其中键是列名,值是对应的消息数据。 如果未找到消息或发生错误,则返回 None。 """ - try: - with self.connection.cursor() as cursor: - cursor.execute( - """ - SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC - """, - (appid, query) - ) - result = cursor.fetchone() - if result: - colnames = [desc[0] for desc in cursor.description] - return dict(zip(colnames, result)) - return None - except (Exception, psycopg2.Error) as error: - raise Exception(f"Error while getting messages_info: {error}") + with self.pg_sql_lock: + try: + with self.connection.cursor() as cursor: + cursor.execute( + """ + SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC + """, + (appid, query) + ) + result = cursor.fetchone() + if result: + colnames = [desc[0] for desc in cursor.description] + return dict(zip(colnames, result)) + return None + except (Exception, psycopg2.Error) as error: + raise Exception(f"Error while getting messages_info: {error}") def get_messages_info_by_id(self, message_id:str)->dict | None: """ 根据消息 ID 从 'messages' 表中获取消息信息。 """ - try: - with self.connection.cursor() as cursor: - cursor.execute( - """ - SELECT * FROM messages WHERE id = %s - """, - (message_id, ) - ) - result = cursor.fetchone() - if result: - colnames = [desc[0] for desc in cursor.description] - return dict(zip(colnames, result)) - return None - except (Exception, psycopg2.Error) as error: - raise Exception(f"Error while getting messages_info by id: {error}") + with self.pg_sql_lock: + try: + with self.connection.cursor() as cursor: + cursor.execute( + """ + SELECT * FROM messages WHERE id = %s + """, + (message_id, ) + ) + result = cursor.fetchone() + if result: + colnames = [desc[0] for desc in cursor.description] + return dict(zip(colnames, result)) + return None + except (Exception, psycopg2.Error) as error: + raise Exception(f"Error while getting messages_info by id: {error}") def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None: """ @@ -144,21 +151,22 @@ class PgSql: 一个字典,其中键是列名,值是对应的节点执行数据。 如果未找到执行信息或发生错误,则返回 None。 """ - try: - with self.connection.cursor() as cursor: - cursor.execute( - """ - SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s - """, - (workflow_run_id,) - ) + with self.pg_sql_lock: + try: + with self.connection.cursor() as cursor: + cursor.execute( + """ + SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s + """, + (workflow_run_id,) + ) result = cursor.fetchall() if result: colnames = [desc[0] for desc in cursor.description] return [dict(zip(colnames, row)) for row in result] return None - except (Exception, psycopg2.Error) as error: - raise Exception(f"Error while getting workflow_node_executions_info: {error}") + except (Exception, psycopg2.Error) as error: + raise Exception(f"Error while getting workflow_node_executions_info: {error}") class DifyTool: """ @@ -167,17 +175,30 @@ class DifyTool: 该类利用 PgSql 类从数据库中检索与特定应用和查询相关的 应用信息、消息详情以及工作流节点执行情况。 """ - @staticmethod - def get_message_debug_info_by_id(message_id:str)->dict | None: + + def __init__(self): + self.dify_pgsql = PgSql() + + def __del__(self): + """ + 析构函数,在对象被销毁时自动关闭数据库连接。 + 确保在对象生命周期结束时释放数据库资源。 + """ + try: + self.dify_pgsql.close_connection() + except Exception as e: + # 析构函数中的异常不应该传播,所以这里只是简单记录 + print(f"关闭数据库连接时出错: {e}") + + def get_message_debug_info_by_id(self, message_id:str)->dict | None: """ 根据消息 ID 从 'messages' 表中获取消息信息。 """ - dify_pgsql = PgSql() try: - messages_info = dify_pgsql.get_messages_info_by_id(message_id) + messages_info = self.dify_pgsql.get_messages_info_by_id(message_id) if not messages_info: return None - workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id']) + workflow_node_executions_info = self.dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id']) if not workflow_node_executions_info: return None return { @@ -186,12 +207,8 @@ class DifyTool: } except Exception as e: raise Exception(f"Error in get_message_debug_info_by_id: {e}") - finally: - dify_pgsql.close_connection() - - @staticmethod - def get_message_debug_info_by_query(appid:str, query:str)->dict: + def get_message_debug_info_by_query(self, appid:str, query:str)->dict: """ 获取指定应用和查询相关的调试信息。 @@ -207,15 +224,14 @@ class DifyTool: "workflow_node_executions_info"键的字典,分别对应 查询到的应用数据、消息数据和节点执行数据。 """ - dify_pgsql = PgSql() try: - appinfo = dify_pgsql.get_appinfo(appid) + appinfo = self.dify_pgsql.get_appinfo(appid) if not appinfo: - return None - messages_info = dify_pgsql.get_messages_info(appid, query) + return None + messages_info = self.dify_pgsql.get_messages_info(appid, query) if not messages_info: return None - workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id']) + workflow_node_executions_info = self.dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id']) if not workflow_node_executions_info: return None return { @@ -225,8 +241,6 @@ class DifyTool: } except Exception as e: raise Exception(f"Error in get_message_debug_info_by_query: {e}") - finally: - dify_pgsql.close_connection() class BaseWorkflowChat: """ @@ -242,6 +256,14 @@ class BaseWorkflowChat: """ self.chat_client = ChatClient(api_key=api_key, base_url=base_url) self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource) + self.dify_tool = DifyTool() + + def __del__(self): + """ + 析构函数,在对象被销毁时自动关闭数据库连接。 + 确保在对象生命周期结束时释放数据库资源。 + """ + self.dify_tool.close_connection() def create_chat_message(self, query: str): """ @@ -369,6 +391,9 @@ class NewWorkflowChat(BaseWorkflowChat): """ 新工作流对话类,用于调用新工作流发送对话并解析获取相关数据 """ + def __init__(self, api_key: str, base_url: str): + super().__init__(api_key, base_url) + def process_question(self, query: str) -> dict: """ 处理问题,获取新工作流的回答和相关信息 @@ -425,7 +450,7 @@ class NewWorkflowChat(BaseWorkflowChat): reranker_sorce=[] try: # 先取出重排得分 - message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id) + message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id) for workflow_node in message_info["workflow_node_executions_info"]: if workflow_node["title"] == "软件知识检索聚合": retrieve_outputs = json.loads(workflow_node["inputs"])["result"] @@ -470,6 +495,10 @@ class OldWorkFlowChat(BaseWorkflowChat): """ 旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据 """ + + def __init__(self, api_key: str, base_url: str): + super().__init__(api_key, base_url) + def process_question(self, query: str) -> dict: """ 处理问题,获取旧工作流的回答和相关信息 @@ -520,7 +549,7 @@ class OldWorkFlowChat(BaseWorkflowChat): rewrite_query = "" try: - message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id) + message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id) for workflow_node in message_info["workflow_node_executions_info"]: if workflow_node["title"] == "知识检索结果后处理": outputs = json.loads(workflow_node["outputs"]) @@ -538,8 +567,4 @@ class OldWorkFlowChat(BaseWorkflowChat): } if __name__ == "__main__": - try: - result = DifyTool.get_message_debug_info_by_query("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?") - print(result) - except Exception as e: - print(f"执行出错: {e}") + pass diff --git a/rag2_0/dify/test_dify_chatapi.py b/rag2_0/dify/test_dify_chatapi.py index 0075faa..aba5179 100755 --- a/rag2_0/dify/test_dify_chatapi.py +++ b/rag2_0/dify/test_dify_chatapi.py @@ -69,6 +69,14 @@ class DifyComparisonTester: else: self.wiki_excel = None + self.dify_tool = DifyTool() + + def __del__(self): + """ + 析构函数,在对象被销毁时自动关闭数据库连接。 + 确保在对象生命周期结束时释放数据库资源。 + """ + self.dify_tool.close_connection() def get_llm(self): api_key = os.getenv("OPENAI_API_KEY") @@ -381,7 +389,7 @@ content: "{content}" """ try: # 使用DifyTool直接获取消息信息 - new_message_info = DifyTool.get_message_debug_info_by_id(message_id=new_message_id) + new_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=new_message_id) # 初始化变量 retrieve_title = [] @@ -428,7 +436,7 @@ content: "{content}" """ try: # 使用DifyTool直接获取消息信息 - old_message_info = DifyTool.get_message_debug_info_by_id(message_id=old_message_id) + old_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=old_message_id) # 初始化变量 retrieve_title = [] diff --git a/uv.lock b/uv.lock index 1ed9ea0..a085bd9 100644 --- a/uv.lock +++ b/uv.lock @@ -1416,7 +1416,7 @@ wheels = [ ] [[package]] -name = "queryrewrite" +name = "rag2-0" version = "0.1.0" source = { virtual = "." } dependencies = [