更改项目名称为"rag2_0",更新相关依赖项名称,增强PgSql类的线程安全性,添加锁机制以确保数据库连接的安全关闭,同时优化DifyTool类的资源管理,确保在对象销毁时自动关闭数据库连接。
This commit is contained in:
+1
-1
@@ -1,5 +1,5 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "queryrewrite"
|
name = "rag2_0"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
description = "Add your description here"
|
description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
from dify_client.client import ChatClient, CompletionClient, DifyClient
|
|
||||||
|
__all__ = ["ChatClient", "CompletionClient", "DifyClient"]
|
||||||
|
|
||||||
|
from .client import ChatClient, CompletionClient, DifyClient
|
||||||
|
|||||||
+106
-81
@@ -7,12 +7,13 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
|
|||||||
from rag2_0.dify.dify_client import ChatClient
|
from rag2_0.dify.dify_client import ChatClient
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
from langchain.output_parsers import PydanticOutputParser
|
from langchain.output_parsers import PydanticOutputParser
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
class ContentSource(BaseModel):
|
class ContentSource(BaseModel):
|
||||||
score: int = Field(description="相关性分数")
|
score: int = Field(description="相关性分数")
|
||||||
reason: str = Field(description="评分理由")
|
reason: str = Field(description="评分理由")
|
||||||
|
|
||||||
|
|
||||||
class PgSql:
|
class PgSql:
|
||||||
"""
|
"""
|
||||||
用于连接和操作 PostgreSQL 数据库的类。
|
用于连接和操作 PostgreSQL 数据库的类。
|
||||||
@@ -21,6 +22,7 @@ class PgSql:
|
|||||||
主要用于从 Dify 应用相关的表中获取数据。
|
主要用于从 Dify 应用相关的表中获取数据。
|
||||||
"""
|
"""
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
self.pg_sql_lock = Lock()
|
||||||
"""
|
"""
|
||||||
初始化 PgSql 实例并建立数据库连接。
|
初始化 PgSql 实例并建立数据库连接。
|
||||||
"""
|
"""
|
||||||
@@ -53,8 +55,10 @@ class PgSql:
|
|||||||
|
|
||||||
如果存在活动的连接,则关闭它。
|
如果存在活动的连接,则关闭它。
|
||||||
"""
|
"""
|
||||||
if self.connection:
|
with self.pg_sql_lock:
|
||||||
self.connection.close()
|
if self.connection:
|
||||||
|
self.connection.close()
|
||||||
|
self.connection = None
|
||||||
|
|
||||||
|
|
||||||
def get_appinfo(self, appid:str)->dict | None:
|
def get_appinfo(self, appid:str)->dict | None:
|
||||||
@@ -68,21 +72,22 @@ class PgSql:
|
|||||||
一个字典,其中键是列名,值是对应的应用数据。
|
一个字典,其中键是列名,值是对应的应用数据。
|
||||||
如果未找到应用或发生错误,则返回 None。
|
如果未找到应用或发生错误,则返回 None。
|
||||||
"""
|
"""
|
||||||
try:
|
with self.pg_sql_lock:
|
||||||
with self.connection.cursor() as cursor:
|
try:
|
||||||
cursor.execute(
|
with self.connection.cursor() as cursor:
|
||||||
"""
|
cursor.execute(
|
||||||
SELECT * FROM apps WHERE id = %s
|
"""
|
||||||
""",
|
SELECT * FROM apps WHERE id = %s
|
||||||
(appid,)
|
""",
|
||||||
)
|
(appid,)
|
||||||
result = cursor.fetchone()
|
)
|
||||||
if result:
|
result = cursor.fetchone()
|
||||||
colnames = [desc[0] for desc in cursor.description]
|
if result:
|
||||||
return dict(zip(colnames, result))
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
return None
|
return dict(zip(colnames, result))
|
||||||
except (Exception, psycopg2.Error) as error:
|
return None
|
||||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||||
|
|
||||||
|
|
||||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||||
@@ -97,41 +102,43 @@ class PgSql:
|
|||||||
一个字典,其中键是列名,值是对应的消息数据。
|
一个字典,其中键是列名,值是对应的消息数据。
|
||||||
如果未找到消息或发生错误,则返回 None。
|
如果未找到消息或发生错误,则返回 None。
|
||||||
"""
|
"""
|
||||||
try:
|
with self.pg_sql_lock:
|
||||||
with self.connection.cursor() as cursor:
|
try:
|
||||||
cursor.execute(
|
with self.connection.cursor() as cursor:
|
||||||
"""
|
cursor.execute(
|
||||||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
"""
|
||||||
""",
|
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||||||
(appid, query)
|
""",
|
||||||
)
|
(appid, query)
|
||||||
result = cursor.fetchone()
|
)
|
||||||
if result:
|
result = cursor.fetchone()
|
||||||
colnames = [desc[0] for desc in cursor.description]
|
if result:
|
||||||
return dict(zip(colnames, result))
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
return None
|
return dict(zip(colnames, result))
|
||||||
except (Exception, psycopg2.Error) as error:
|
return None
|
||||||
raise Exception(f"Error while getting messages_info: {error}")
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
raise Exception(f"Error while getting messages_info: {error}")
|
||||||
|
|
||||||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||||
"""
|
"""
|
||||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||||
"""
|
"""
|
||||||
try:
|
with self.pg_sql_lock:
|
||||||
with self.connection.cursor() as cursor:
|
try:
|
||||||
cursor.execute(
|
with self.connection.cursor() as cursor:
|
||||||
"""
|
cursor.execute(
|
||||||
SELECT * FROM messages WHERE id = %s
|
"""
|
||||||
""",
|
SELECT * FROM messages WHERE id = %s
|
||||||
(message_id, )
|
""",
|
||||||
)
|
(message_id, )
|
||||||
result = cursor.fetchone()
|
)
|
||||||
if result:
|
result = cursor.fetchone()
|
||||||
colnames = [desc[0] for desc in cursor.description]
|
if result:
|
||||||
return dict(zip(colnames, result))
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
return None
|
return dict(zip(colnames, result))
|
||||||
except (Exception, psycopg2.Error) as error:
|
return None
|
||||||
raise Exception(f"Error while getting messages_info by id: {error}")
|
except (Exception, psycopg2.Error) as error:
|
||||||
|
raise Exception(f"Error while getting messages_info by id: {error}")
|
||||||
|
|
||||||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||||
"""
|
"""
|
||||||
@@ -144,21 +151,22 @@ class PgSql:
|
|||||||
一个字典,其中键是列名,值是对应的节点执行数据。
|
一个字典,其中键是列名,值是对应的节点执行数据。
|
||||||
如果未找到执行信息或发生错误,则返回 None。
|
如果未找到执行信息或发生错误,则返回 None。
|
||||||
"""
|
"""
|
||||||
try:
|
with self.pg_sql_lock:
|
||||||
with self.connection.cursor() as cursor:
|
try:
|
||||||
cursor.execute(
|
with self.connection.cursor() as cursor:
|
||||||
"""
|
cursor.execute(
|
||||||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
"""
|
||||||
""",
|
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||||||
(workflow_run_id,)
|
""",
|
||||||
)
|
(workflow_run_id,)
|
||||||
|
)
|
||||||
result = cursor.fetchall()
|
result = cursor.fetchall()
|
||||||
if result:
|
if result:
|
||||||
colnames = [desc[0] for desc in cursor.description]
|
colnames = [desc[0] for desc in cursor.description]
|
||||||
return [dict(zip(colnames, row)) for row in result]
|
return [dict(zip(colnames, row)) for row in result]
|
||||||
return None
|
return None
|
||||||
except (Exception, psycopg2.Error) as error:
|
except (Exception, psycopg2.Error) as error:
|
||||||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||||||
|
|
||||||
class DifyTool:
|
class DifyTool:
|
||||||
"""
|
"""
|
||||||
@@ -167,17 +175,30 @@ class DifyTool:
|
|||||||
该类利用 PgSql 类从数据库中检索与特定应用和查询相关的
|
该类利用 PgSql 类从数据库中检索与特定应用和查询相关的
|
||||||
应用信息、消息详情以及工作流节点执行情况。
|
应用信息、消息详情以及工作流节点执行情况。
|
||||||
"""
|
"""
|
||||||
@staticmethod
|
|
||||||
def get_message_debug_info_by_id(message_id:str)->dict | None:
|
def __init__(self) -> None:
    """Create the tool and the single ``PgSql`` helper it reuses for every lookup.

    The helper is stored on ``self.dify_pgsql`` so the query methods of this
    class share one instance instead of building a new one per call.
    """
    self.dify_pgsql = PgSql()
|
||||||
|
|
||||||
|
def __del__(self):
    """Finalizer: close the connection held by ``self.dify_pgsql``.

    Any failure, including the attribute being absent, is caught and
    reported with ``print`` rather than raised.
    """
    try:
        pgsql = self.dify_pgsql
        pgsql.close_connection()
    except Exception as exc:
        # Errors must not escape a finalizer, so the failure is only reported.
        print(f"关闭数据库连接时出错: {exc}")
|
||||||
|
|
||||||
|
def get_message_debug_info_by_id(self, message_id:str)->dict | None:
|
||||||
"""
|
"""
|
||||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||||
"""
|
"""
|
||||||
dify_pgsql = PgSql()
|
|
||||||
try:
|
try:
|
||||||
messages_info = dify_pgsql.get_messages_info_by_id(message_id)
|
messages_info = self.dify_pgsql.get_messages_info_by_id(message_id)
|
||||||
if not messages_info:
|
if not messages_info:
|
||||||
return None
|
return None
|
||||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
workflow_node_executions_info = self.dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||||
if not workflow_node_executions_info:
|
if not workflow_node_executions_info:
|
||||||
return None
|
return None
|
||||||
return {
|
return {
|
||||||
@@ -186,12 +207,8 @@ class DifyTool:
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Error in get_message_debug_info_by_id: {e}")
|
raise Exception(f"Error in get_message_debug_info_by_id: {e}")
|
||||||
finally:
|
|
||||||
dify_pgsql.close_connection()
|
|
||||||
|
|
||||||
|
def get_message_debug_info_by_query(self, appid:str, query:str)->dict:
|
||||||
@staticmethod
|
|
||||||
def get_message_debug_info_by_query(appid:str, query:str)->dict:
|
|
||||||
"""
|
"""
|
||||||
获取指定应用和查询相关的调试信息。
|
获取指定应用和查询相关的调试信息。
|
||||||
|
|
||||||
@@ -207,15 +224,14 @@ class DifyTool:
|
|||||||
"workflow_node_executions_info"键的字典,分别对应
|
"workflow_node_executions_info"键的字典,分别对应
|
||||||
查询到的应用数据、消息数据和节点执行数据。
|
查询到的应用数据、消息数据和节点执行数据。
|
||||||
"""
|
"""
|
||||||
dify_pgsql = PgSql()
|
|
||||||
try:
|
try:
|
||||||
appinfo = dify_pgsql.get_appinfo(appid)
|
appinfo = self.dify_pgsql.get_appinfo(appid)
|
||||||
if not appinfo:
|
if not appinfo:
|
||||||
return None
|
return None
|
||||||
messages_info = dify_pgsql.get_messages_info(appid, query)
|
messages_info = self.dify_pgsql.get_messages_info(appid, query)
|
||||||
if not messages_info:
|
if not messages_info:
|
||||||
return None
|
return None
|
||||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
workflow_node_executions_info = self.dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||||
if not workflow_node_executions_info:
|
if not workflow_node_executions_info:
|
||||||
return None
|
return None
|
||||||
return {
|
return {
|
||||||
@@ -225,8 +241,6 @@ class DifyTool:
|
|||||||
}
|
}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
||||||
finally:
|
|
||||||
dify_pgsql.close_connection()
|
|
||||||
|
|
||||||
class BaseWorkflowChat:
|
class BaseWorkflowChat:
|
||||||
"""
|
"""
|
||||||
@@ -242,6 +256,14 @@ class BaseWorkflowChat:
|
|||||||
"""
|
"""
|
||||||
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
|
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
|
||||||
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
||||||
|
self.dify_tool = DifyTool()
|
||||||
|
|
||||||
|
def __del__(self):
    """Finalizer: release the database connection reachable via ``self.dify_tool``.

    The previous body called ``self.dify_tool.close_connection()``. As far as
    this change shows, ``DifyTool`` keeps its connection helper on
    ``dify_pgsql`` and has no ``close_connection`` method of its own, so that
    call raised ``AttributeError`` whenever the finalizer ran and the
    connection was never closed from here.

    The closer is now resolved defensively:
      * ``dify_tool.close_connection`` is used if the tool provides one;
      * otherwise ``dify_tool.dify_pgsql.close_connection`` is used;
      * if neither exists (for example ``__init__`` failed before
        ``dify_tool`` was assigned) the finalizer does nothing.

    Errors raised while closing are reported with ``print`` and never raised.
    """
    tool = getattr(self, "dify_tool", None)
    closer = getattr(tool, "close_connection", None)
    if closer is None:
        # Fall back to the PgSql helper that the tool owns.
        closer = getattr(getattr(tool, "dify_pgsql", None), "close_connection", None)
    if closer is None:
        return
    try:
        closer()
    except Exception as e:
        # Errors must not escape a finalizer, so the failure is only reported.
        print(f"关闭数据库连接时出错: {e}")
|
||||||
|
|
||||||
def create_chat_message(self, query: str):
|
def create_chat_message(self, query: str):
|
||||||
"""
|
"""
|
||||||
@@ -369,6 +391,9 @@ class NewWorkflowChat(BaseWorkflowChat):
|
|||||||
"""
|
"""
|
||||||
新工作流对话类,用于调用新工作流发送对话并解析获取相关数据
|
新工作流对话类,用于调用新工作流发送对话并解析获取相关数据
|
||||||
"""
|
"""
|
||||||
|
def __init__(self, api_key: str, base_url: str):
    """Build the new-workflow chat wrapper by delegating to the parent initializer.

    Args:
        api_key: Forwarded unchanged to the parent ``__init__``.
        base_url: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(api_key, base_url)
|
||||||
|
|
||||||
def process_question(self, query: str) -> dict:
|
def process_question(self, query: str) -> dict:
|
||||||
"""
|
"""
|
||||||
处理问题,获取新工作流的回答和相关信息
|
处理问题,获取新工作流的回答和相关信息
|
||||||
@@ -425,7 +450,7 @@ class NewWorkflowChat(BaseWorkflowChat):
|
|||||||
reranker_sorce=[]
|
reranker_sorce=[]
|
||||||
try:
|
try:
|
||||||
# 先取出重排得分
|
# 先取出重排得分
|
||||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
|
||||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||||
if workflow_node["title"] == "软件知识检索聚合":
|
if workflow_node["title"] == "软件知识检索聚合":
|
||||||
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
retrieve_outputs = json.loads(workflow_node["inputs"])["result"]
|
||||||
@@ -470,6 +495,10 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
|||||||
"""
|
"""
|
||||||
旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据
|
旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, api_key: str, base_url: str):
    """Build the old-workflow chat wrapper by delegating to the parent initializer.

    Args:
        api_key: Forwarded unchanged to the parent ``__init__``.
        base_url: Forwarded unchanged to the parent ``__init__``.
    """
    super().__init__(api_key, base_url)
|
||||||
|
|
||||||
def process_question(self, query: str) -> dict:
|
def process_question(self, query: str) -> dict:
|
||||||
"""
|
"""
|
||||||
处理问题,获取旧工作流的回答和相关信息
|
处理问题,获取旧工作流的回答和相关信息
|
||||||
@@ -520,7 +549,7 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
|||||||
rewrite_query = ""
|
rewrite_query = ""
|
||||||
|
|
||||||
try:
|
try:
|
||||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
message_info = self.dify_tool.get_message_debug_info_by_id(message_id=message_id)
|
||||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||||
if workflow_node["title"] == "知识检索结果后处理":
|
if workflow_node["title"] == "知识检索结果后处理":
|
||||||
outputs = json.loads(workflow_node["outputs"])
|
outputs = json.loads(workflow_node["outputs"])
|
||||||
@@ -538,8 +567,4 @@ class OldWorkFlowChat(BaseWorkflowChat):
|
|||||||
}
|
}
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
try:
|
pass
|
||||||
result = DifyTool.get_message_debug_info_by_query("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?")
|
|
||||||
print(result)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"执行出错: {e}")
|
|
||||||
|
|||||||
@@ -69,6 +69,14 @@ class DifyComparisonTester:
|
|||||||
else:
|
else:
|
||||||
self.wiki_excel = None
|
self.wiki_excel = None
|
||||||
|
|
||||||
|
self.dify_tool = DifyTool()
|
||||||
|
|
||||||
|
def __del__(self):
    """Finalizer: release the database connection reachable via ``self.dify_tool``.

    The previous body called ``self.dify_tool.close_connection()``. As far as
    this change shows, ``DifyTool`` stores its connection helper on
    ``dify_pgsql`` and has no ``close_connection`` method of its own, so that
    call raised ``AttributeError`` inside the finalizer and closed nothing.

    The closer is resolved defensively: ``dify_tool.close_connection`` if it
    exists, else ``dify_tool.dify_pgsql.close_connection``. If neither is
    available (for example ``__init__`` failed before ``dify_tool`` was
    assigned) the finalizer does nothing. Errors raised while closing are
    reported with ``print`` and never raised.
    """
    tool = getattr(self, "dify_tool", None)
    closer = getattr(tool, "close_connection", None)
    if closer is None:
        # Fall back to the PgSql helper that the tool owns.
        closer = getattr(getattr(tool, "dify_pgsql", None), "close_connection", None)
    if closer is None:
        return
    try:
        closer()
    except Exception as e:
        # Errors must not escape a finalizer, so the failure is only reported.
        print(f"关闭数据库连接时出错: {e}")
|
||||||
|
|
||||||
def get_llm(self):
|
def get_llm(self):
|
||||||
api_key = os.getenv("OPENAI_API_KEY")
|
api_key = os.getenv("OPENAI_API_KEY")
|
||||||
@@ -381,7 +389,7 @@ content: "{content}"
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用DifyTool直接获取消息信息
|
# 使用DifyTool直接获取消息信息
|
||||||
new_message_info = DifyTool.get_message_debug_info_by_id(message_id=new_message_id)
|
new_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=new_message_id)
|
||||||
|
|
||||||
# 初始化变量
|
# 初始化变量
|
||||||
retrieve_title = []
|
retrieve_title = []
|
||||||
@@ -428,7 +436,7 @@ content: "{content}"
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
# 使用DifyTool直接获取消息信息
|
# 使用DifyTool直接获取消息信息
|
||||||
old_message_info = DifyTool.get_message_debug_info_by_id(message_id=old_message_id)
|
old_message_info = self.dify_tool.get_message_debug_info_by_id(message_id=old_message_id)
|
||||||
|
|
||||||
# 初始化变量
|
# 初始化变量
|
||||||
retrieve_title = []
|
retrieve_title = []
|
||||||
|
|||||||
Reference in New Issue
Block a user