调整PGSQL相关代码逻辑
This commit is contained in:
+252
-250
@@ -17,258 +17,260 @@ class ContentSource(BaseModel):
|
||||
reason: str = Field(description="评分理由")
|
||||
|
||||
|
||||
class PgSql:
|
||||
"""
|
||||
用于连接和操作 PostgreSQL 数据库的类。
|
||||
|
||||
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
|
||||
主要用于从 Dify 应用相关的表中获取数据。
|
||||
"""
|
||||
def __init__(self):
|
||||
self.pg_sql_lock = Lock()
|
||||
"""
|
||||
初始化 PgSql 实例并建立数据库连接。
|
||||
"""
|
||||
self.connection = None
|
||||
self.connect_sql()
|
||||
|
||||
def connect_sql(self):
|
||||
"""
|
||||
连接到 PostgreSQL 数据库。
|
||||
|
||||
使用预定义的凭据连接到 'dify' 数据库。
|
||||
如果连接失败,会抛出异常。
|
||||
"""
|
||||
try:
|
||||
self.DIFY_PG_USER = os.getenv("DIFY_PG_USER")
|
||||
self.DIFY_PG_PASSWORD = os.getenv("DIFY_PG_PASSWORD")
|
||||
self.DIFY_PG_HOST = os.getenv("DIFY_PG_HOST")
|
||||
self.DIFY_PG_PORT = os.getenv("DIFY_PG_PORT")
|
||||
self.DIFY_PG_DATABASE = os.getenv("DIFY_PG_DATABASE")
|
||||
# 连接数据库
|
||||
self.connection = psycopg2.connect(
|
||||
user=self.DIFY_PG_USER,
|
||||
password=self.DIFY_PG_PASSWORD,
|
||||
host=self.DIFY_PG_HOST,
|
||||
port=self.DIFY_PG_PORT,
|
||||
database=self.DIFY_PG_DATABASE
|
||||
)
|
||||
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while connecting to PostgreSQL: {error}")
|
||||
|
||||
def close_connection(self):
|
||||
"""
|
||||
关闭当前的 PostgreSQL 数据库连接。
|
||||
|
||||
如果存在活动的连接,则关闭它。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
|
||||
def get_appinfo(self, appid:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 从 'apps' 表中获取应用信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的应用数据。
|
||||
如果未找到应用或发生错误,则返回 None。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM apps WHERE id = %s
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||
|
||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
query: 用户查询的具体内容。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的消息数据。
|
||||
如果未找到消息或发生错误,则返回 None。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||
""",
|
||||
(appid, query)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info: {error}")
|
||||
|
||||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||
"""
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE id = %s
|
||||
""",
|
||||
(message_id, )
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info by id: {error}")
|
||||
|
||||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||
"""
|
||||
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
|
||||
|
||||
Args:
|
||||
workflow_run_id: 目标工作流运行的 ID。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的节点执行数据。
|
||||
如果未找到执行信息或发生错误,则返回 None。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||
""",
|
||||
(workflow_run_id,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||||
|
||||
def get_app_conversations(self, appid:str)->list[str] | None:
|
||||
"""
|
||||
根据应用 ID 从 'conversations' 表中获取应用会话信息。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT DISTINCT conversation_id
|
||||
FROM messages
|
||||
WHERE app_id = %s AND invoke_from != 'debugger';
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting app_conversations: {error}")
|
||||
|
||||
def get_conversation_messages(self, conversation_id:str)->list[dict] | None:
|
||||
"""
|
||||
根据会话 ID 从 'messages' 表中获取会话消息信息。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE conversation_id = %s AND status = 'normal'
|
||||
""",
|
||||
(conversation_id,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
|
||||
def get_message_rating(self, msg_id):
|
||||
"""
|
||||
通过msg_id从message_feedbacks中找到对应的rating。
|
||||
:param msg_id: 消息ID (UUID格式)
|
||||
:return: rating 字符串
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
rating = None
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
# 构建查询语句
|
||||
cursor.execute("""
|
||||
SELECT rating
|
||||
FROM message_feedbacks
|
||||
WHERE message_id = %s
|
||||
""",
|
||||
(msg_id,))
|
||||
# 执行查询
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
rating = row[0]
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
return rating
|
||||
|
||||
def get_workflow_run_info(self, workflow_run_id):
|
||||
"""
|
||||
通过msg_id从message_feedbacks中找到对应的rating。
|
||||
:param msg_id: 消息ID (UUID格式)
|
||||
:return: rating 字符串
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
rating = None
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
# 构建查询语句
|
||||
cursor.execute("""
|
||||
SELECT * FROM workflow_runs WHERE id=%s;
|
||||
""",
|
||||
(workflow_run_id,))
|
||||
# 执行查询
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
return None
|
||||
|
||||
class DifyTool:
|
||||
|
||||
class PgSql:
|
||||
"""
|
||||
用于连接和操作 PostgreSQL 数据库的类。
|
||||
|
||||
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
|
||||
主要用于从 Dify 应用相关的表中获取数据。
|
||||
"""
|
||||
def __init__(self):
|
||||
self.pg_sql_lock = Lock()
|
||||
"""
|
||||
初始化 PgSql 实例并建立数据库连接。
|
||||
"""
|
||||
self.connection = None
|
||||
self.connect_sql()
|
||||
|
||||
def connect_sql(self):
|
||||
"""
|
||||
连接到 PostgreSQL 数据库。
|
||||
|
||||
使用预定义的凭据连接到 'dify' 数据库。
|
||||
如果连接失败,会抛出异常。
|
||||
"""
|
||||
try:
|
||||
self.DIFY_PG_USER = os.getenv("DIFY_PG_USER")
|
||||
self.DIFY_PG_PASSWORD = os.getenv("DIFY_PG_PASSWORD")
|
||||
self.DIFY_PG_HOST = os.getenv("DIFY_PG_HOST")
|
||||
self.DIFY_PG_PORT = os.getenv("DIFY_PG_PORT")
|
||||
self.DIFY_PG_DATABASE = os.getenv("DIFY_PG_DATABASE")
|
||||
# 连接数据库
|
||||
self.connection = psycopg2.connect(
|
||||
user=self.DIFY_PG_USER,
|
||||
password=self.DIFY_PG_PASSWORD,
|
||||
host=self.DIFY_PG_HOST,
|
||||
port=self.DIFY_PG_PORT,
|
||||
database=self.DIFY_PG_DATABASE
|
||||
)
|
||||
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while connecting to PostgreSQL: {error}")
|
||||
|
||||
def close_connection(self):
|
||||
"""
|
||||
关闭当前的 PostgreSQL 数据库连接。
|
||||
|
||||
如果存在活动的连接,则关闭它。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
self.connection = None
|
||||
|
||||
def get_appinfo(self, appid:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 从 'apps' 表中获取应用信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的应用数据。
|
||||
如果未找到应用或发生错误,则返回 None。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM apps WHERE id = %s
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||
|
||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||
"""
|
||||
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||||
|
||||
Args:
|
||||
appid: 目标应用的 ID。
|
||||
query: 用户查询的具体内容。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的消息数据。
|
||||
如果未找到消息或发生错误,则返回 None。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||
""",
|
||||
(appid, query)
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info: {error}")
|
||||
|
||||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||
"""
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE id = %s
|
||||
""",
|
||||
(message_id, )
|
||||
)
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting messages_info by id: {error}")
|
||||
|
||||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||
"""
|
||||
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
|
||||
|
||||
Args:
|
||||
workflow_run_id: 目标工作流运行的 ID。
|
||||
|
||||
Returns:
|
||||
一个字典,其中键是列名,值是对应的节点执行数据。
|
||||
如果未找到执行信息或发生错误,则返回 None。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||
""",
|
||||
(workflow_run_id,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||||
|
||||
def get_app_conversations(self, appid:str)->list[str] | None:
|
||||
"""
|
||||
根据应用 ID 从 'conversations' 表中获取应用会话信息。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT DISTINCT conversation_id
|
||||
FROM messages
|
||||
WHERE app_id = %s AND invoke_from != 'debugger';
|
||||
""",
|
||||
(appid,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting app_conversations: {error}")
|
||||
|
||||
def get_conversation_messages(self, conversation_id:str)->list[dict] | None:
|
||||
"""
|
||||
根据会话 ID 从 'messages' 表中获取会话消息信息。
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT * FROM messages WHERE conversation_id = %s AND status = 'normal'
|
||||
""",
|
||||
(conversation_id,)
|
||||
)
|
||||
result = cursor.fetchall()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
|
||||
def get_message_rating(self, msg_id):
|
||||
"""
|
||||
通过msg_id从message_feedbacks中找到对应的rating。
|
||||
:param msg_id: 消息ID (UUID格式)
|
||||
:return: rating 字符串
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
rating = None
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
# 构建查询语句
|
||||
cursor.execute("""
|
||||
SELECT rating
|
||||
FROM message_feedbacks
|
||||
WHERE message_id = %s
|
||||
""",
|
||||
(msg_id,))
|
||||
# 执行查询
|
||||
row = cursor.fetchone()
|
||||
|
||||
if row:
|
||||
rating = row[0]
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
return rating
|
||||
|
||||
def get_workflow_run_info(self, workflow_run_id):
|
||||
"""
|
||||
通过msg_id从message_feedbacks中找到对应的rating。
|
||||
:param msg_id: 消息ID (UUID格式)
|
||||
:return: rating 字符串
|
||||
"""
|
||||
with self.pg_sql_lock:
|
||||
rating = None
|
||||
try:
|
||||
with self.connection.cursor() as cursor:
|
||||
# 构建查询语句
|
||||
cursor.execute("""
|
||||
SELECT * FROM workflow_runs WHERE id=%s;
|
||||
""",
|
||||
(workflow_run_id,))
|
||||
# 执行查询
|
||||
result = cursor.fetchone()
|
||||
if result:
|
||||
colnames = [desc[0] for desc in cursor.description]
|
||||
return dict(zip(colnames, result))
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
raise Exception(f"Error while getting conversation_messages: {error}")
|
||||
return None
|
||||
|
||||
"""
|
||||
提供用于获取 Dify 应用调试信息的工具类。
|
||||
|
||||
|
||||
Reference in New Issue
Block a user