Files
QueryRewrite/rag2_0/dify/dify_tool.py
T

531 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import psycopg2
import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from rag2_0.dify.dify_client import ChatClient
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
class ContentSource(BaseModel):
    # Structured relevance verdict: an integer score plus a textual reason.
    # Used in this file as the target model of a PydanticOutputParser.

    # Relevance score.
    score: int = Field(description="相关性分数")
    # Justification for the score.
    reason: str = Field(description="评分理由")
class PgSql:
"""
用于连接和操作 PostgreSQL 数据库的类。
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
主要用于从 Dify 应用相关的表中获取数据。
"""
def __init__(self):
"""
初始化 PgSql 实例并建立数据库连接。
"""
self.connection = None
self.connect_sql()
def connect_sql(self):
"""
连接到 PostgreSQL 数据库。
使用预定义的凭据连接到 'dify' 数据库。
如果连接失败,会抛出异常。
"""
try:
# 连接数据库
self.connection = psycopg2.connect(
user="postgres",
password="difyai123456",
host="172.20.0.145",
port=5432,
database="dify"
)
except (Exception, psycopg2.Error) as error:
raise Exception(f"Error while connecting to PostgreSQL: {error}")
def close_connection(self):
"""
关闭当前的 PostgreSQL 数据库连接。
如果存在活动的连接,则关闭它。
"""
if self.connection:
self.connection.close()
def get_appinfo(self, appid:str)->dict | None:
"""
根据应用 ID 从 'apps' 表中获取应用信息。
Args:
appid: 目标应用的 ID。
Returns:
一个字典,其中键是列名,值是对应的应用数据。
如果未找到应用或发生错误,则返回 None。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM apps WHERE id = %s
""",
(appid,)
)
result = cursor.fetchone()
if result:
colnames = [desc[0] for desc in cursor.description]
return dict(zip(colnames, result))
return None
except (Exception, psycopg2.Error) as error:
raise Exception(f"Error while getting tenant_id by appid: {error}")
def get_messages_info(self, appid:str, query:str)->dict | None:
"""
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
Args:
appid: 目标应用的 ID。
query: 用户查询的具体内容。
Returns:
一个字典,其中键是列名,值是对应的消息数据。
如果未找到消息或发生错误,则返回 None。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
""",
(appid, query)
)
result = cursor.fetchone()
if result:
colnames = [desc[0] for desc in cursor.description]
return dict(zip(colnames, result))
return None
except (Exception, psycopg2.Error) as error:
raise Exception(f"Error while getting messages_info: {error}")
def get_messages_info_by_id(self, message_id:str)->dict | None:
"""
根据消息 ID 从 'messages' 表中获取消息信息。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM messages WHERE id = %s
""",
(message_id, )
)
result = cursor.fetchone()
if result:
colnames = [desc[0] for desc in cursor.description]
return dict(zip(colnames, result))
return None
except (Exception, psycopg2.Error) as error:
raise Exception(f"Error while getting messages_info by id: {error}")
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
"""
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
Args:
workflow_run_id: 目标工作流运行的 ID。
Returns:
一个字典,其中键是列名,值是对应的节点执行数据。
如果未找到执行信息或发生错误,则返回 None。
"""
try:
with self.connection.cursor() as cursor:
cursor.execute(
"""
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
""",
(workflow_run_id,)
)
result = cursor.fetchall()
if result:
colnames = [desc[0] for desc in cursor.description]
return [dict(zip(colnames, row)) for row in result]
return None
except (Exception, psycopg2.Error) as error:
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
class DifyTool:
    """
    Helpers that collect debug information about a Dify message.

    Each helper opens a short-lived PgSql connection, reads the message row
    and the node executions of its workflow run, and always closes the
    connection before returning.
    """

    @staticmethod
    def get_message_debug_info_by_id(message_id: str) -> dict | None:
        """
        Collect the message row and its workflow node executions by message ID.

        Args:
            message_id: ID of the target message.

        Returns:
            A dict with the keys "messages_info" and
            "workflow_node_executions_info", or None when the message or its
            node executions cannot be found.

        Raises:
            Exception: If any database lookup fails.
        """
        dify_pgsql = PgSql()
        try:
            messages_info = dify_pgsql.get_messages_info_by_id(message_id)
            if not messages_info:
                return None
            workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(
                messages_info['workflow_run_id']
            )
            if not workflow_node_executions_info:
                return None
            return {
                "messages_info": messages_info,
                "workflow_node_executions_info": workflow_node_executions_info
            }
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"Error in get_message_debug_info_by_id: {e}") from e
        finally:
            dify_pgsql.close_connection()

    @staticmethod
    def get_message_debug_info_by_query(appid: str, query: str) -> dict | None:
        """
        Collect app, message and node-execution data for an app and query text.

        The message used is the one PgSql.get_messages_info returns for the
        given app ID and query text.

        Args:
            appid: ID of the target application.
            query: Exact query text of the message to look up.

        Returns:
            A dict with the keys "appinfo", "messages_info" and
            "workflow_node_executions_info", or None when the app, the
            message or its node executions cannot be found.

        Raises:
            Exception: If any database lookup fails.
        """
        dify_pgsql = PgSql()
        try:
            appinfo = dify_pgsql.get_appinfo(appid)
            if not appinfo:
                return None
            messages_info = dify_pgsql.get_messages_info(appid, query)
            if not messages_info:
                return None
            workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(
                messages_info['workflow_run_id']
            )
            if not workflow_node_executions_info:
                return None
            return {
                "appinfo": appinfo,
                "messages_info": messages_info,
                "workflow_node_executions_info": workflow_node_executions_info
            }
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"Error in get_message_debug_info_by_query: {e}") from e
        finally:
            dify_pgsql.close_connection()
class BaseWorkflowChat:
    """
    Base class for workflow chats; wraps the basic interaction with the Dify API
    and the LLM-based relevance scoring of retrieved content.
    """

    def __init__(self, api_key: str, base_url: str):
        """
        Create the Dify chat client and the parser for relevance verdicts.

        Args:
            api_key: Dify API key.
            base_url: Base URL of the Dify API.
        """
        self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
        self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)

    def create_chat_message(self, query: str):
        """
        Send a chat message through the Dify client.

        Args:
            query: Question text.

        Returns:
            tuple: (decoded JSON response, value of its "message_id" key)

        Raises:
            Any exception from the client call, the JSON decoding or a
            missing "message_id" key propagates unchanged.
        """
        response = self.chat_client.create_chat_message(
            inputs={}, query=query, user="AutoTestDifyChat"
        ).json()
        return response, response["message_id"]

    def calculate_score(self, query: str, content: str) -> int:
        """
        Ask an LLM to rate how relevant ``content`` is to ``query``.

        The model settings are read from the OPENAI_API_KEY, OPENAI_API_BASE
        and LLM_MODEL_NAME environment variables.

        Args:
            query (str): User question.
            content (str): Retrieved content.

        Returns:
            int: The score parsed from the LLM reply (the prompt asks for
            1-10, 10 = fully relevant), or -1 if calling the LLM or parsing
            its reply fails.
        """
        # Imported inside the function, as in the original code.
        from rag2_0.tool.ModelTool import OpenAiLLM
        try:
            prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突
【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息
【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}
【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
现在评估:
query: "{query}"
content: "{content}"
"""
            api_key = os.getenv("OPENAI_API_KEY")
            base_url = os.getenv("OPENAI_API_BASE")
            model = os.getenv("LLM_MODEL_NAME")
            llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
            response = llm.invoke(user_prompt=prompt, need_retry=True)
            # Parse the JSON reply into a ContentSource and keep only the score.
            return self.content_source_parser.parse(response.content).score
        except Exception:
            # Deliberate best effort: scoring is auxiliary, so any failure is
            # reported through the documented -1 sentinel instead of raising.
            return -1

    def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
        """
        Score every retrieved item against the query and summarize the scores.

        Scores are computed concurrently, but the returned entries keep the
        order of ``outputs["result"]`` so the output is deterministic.

        Args:
            query (str): User question.
            outputs (dict): Retrieval output; ``outputs["result"]`` is a list
                of items that each carry a "content" string.

        Returns:
            tuple: (entries, max score, min score, average score).
            Each entry is the first line of an item's content followed by its
            score; items whose first line is empty are left out. The
            statistics only consider scores other than -1 and are all 0 when
            no item was scored successfully.
        """
        contents = [item["content"].strip() for item in outputs["result"]]

        # Fan the scoring out to a thread pool; leaving the `with` block waits
        # for every task to finish.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.calculate_score, query=query, content=content)
                for content in contents
            ]
        scores = [future.result() for future in futures]

        retrieve_content = []
        for content, score in zip(contents, scores):
            content_title = content.split("\n")[0]
            if content_title:
                retrieve_content.append(content_title + f"--相关性得分({score}分)")

        valid_scores = [score for score in scores if score != -1]
        if not valid_scores:
            return retrieve_content, 0, 0, 0
        avg_score = sum(valid_scores) / len(valid_scores)
        return retrieve_content, max(valid_scores), min(valid_scores), avg_score
class NewWorkflowChat(BaseWorkflowChat):
    """
    Chat against the new workflow and extract its answer plus debug details.
    """

    def process_question(self, query: str) -> dict | None:
        """
        Send the question to the new workflow and gather the related details.

        Args:
            query: Question text.

        Returns:
            dict: The question, the answer, the rewritten query, the
            classification, the slot info, the retrieval entries, the
            retrieved content and the message ID; or None when no workflow
            debug information is available for the message.

        Raises:
            RuntimeError: If the chat response is a string starting with "error:".
        """
        response, message_id = self.create_chat_message(query)
        if isinstance(response, str) and response.startswith("error:"):
            raise RuntimeError(f"create_chat_message 出错:{response}")
        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None
        return {
            "问题": query,
            "新流程答案": answer,
            "新问题改写": workflow_info["问题改写"],
            "新问题分类": workflow_info["问题分类"],
            "槽点信息": workflow_info["槽点信息"],
            "新检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict | None:
        """
        Read the rewrite, classification, slot and retrieval data of a message.

        The data is taken from the workflow node executions of the message:
        the node titled "知识检索结果后处理" provides the retrieval results and
        the node titled "问题优化结果解析" provides the rewritten query, the
        classification and the slot filling.

        Args:
            query (str): User question.
            message_id (str): Message ID in the new workflow.

        Returns:
            dict: Rewritten query, retrieval entries, retrieved content,
            classification and slot info; or None when no debug information
            is found for ``message_id``.

        Raises:
            Any exception from the database lookup or from decoding the node
            data propagates unchanged.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""
        vertical_classification = ""
        sub_classification = ""
        slot_info = ""

        message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
        if message_info is None:
            # Nothing recorded for this message: report "no info" instead of
            # failing on None[...] further down.
            return None

        for workflow_node in message_info["workflow_node_executions_info"]:
            if workflow_node["title"] == "知识检索结果后处理":
                outputs = json.loads(workflow_node["outputs"])
                # Only the scored entry list is used; the statistics are ignored.
                retrieve_title, _, _, _ = self.get_retrieve_info(query=query, outputs=outputs)
                retrieve_content = outputs["result"]
            elif workflow_node["title"] == "问题优化结果解析":
                outputs = json.loads(workflow_node["outputs"])
                rewrite_query = outputs["optimize_query"]
                # "llm_result" is itself a JSON string nested inside the inputs JSON.
                llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
                json_result = json.loads(llm_result_json)
                vertical_classification = json_result['vertical_classification']
                sub_classification = json_result['sub_classification']
                slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
            "问题分类": f"{vertical_classification} - {sub_classification}",
            "槽点信息": slot_info,
        }
class OldWorkFlowChat(BaseWorkflowChat):
    """
    Chat against the old workflow and extract its answer plus debug details.
    """

    def process_question(self, query: str) -> dict | None:
        """
        Send the question to the old workflow and gather the related details.

        Args:
            query: Question text.

        Returns:
            dict: The question, the answer, the rewritten query, the
            retrieval entries, the retrieved content and the message ID; or
            None when the chat response is a string starting with "error:"
            or no workflow information could be read.
        """
        response, message_id = self.create_chat_message(query)
        if isinstance(response, str) and response.startswith("error:"):
            return None
        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None
        return {
            "问题": query,
            "旧流程答案": answer,
            "旧问题改写": workflow_info["问题改写"],
            "旧检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict | None:
        """
        Read the rewritten query and the retrieval data of a message.

        The data is taken from the workflow node executions of the message:
        the node titled "知识检索结果后处理" provides the retrieval results and
        the node titled "问题优化结果解析" provides the rewritten query.

        Args:
            query (str): User question.
            message_id (str): Message ID in the old workflow.

        Returns:
            dict: Rewritten query, retrieval entries and retrieved content;
            or None when no debug information is found or reading it fails.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""
        try:
            message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
            if message_info is None:
                # Explicit "not found" path rather than relying on a
                # swallowed TypeError from None[...].
                return None
            for workflow_node in message_info["workflow_node_executions_info"]:
                if workflow_node["title"] == "知识检索结果后处理":
                    outputs = json.loads(workflow_node["outputs"])
                    # Only the scored entry list is used; the statistics are ignored.
                    retrieve_title, _, _, _ = self.get_retrieve_info(query=query, outputs=outputs)
                    retrieve_content = outputs["result"]
                elif workflow_node["title"] == "问题优化结果解析":
                    outputs = json.loads(workflow_node["outputs"])
                    rewrite_query = outputs["optimize_query"]
        except Exception:
            # Deliberate best effort kept from the original: any failure while
            # reading the debug data is reported to the caller as None.
            return None
        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
        }
if __name__ == "__main__":
    # Manual smoke test: look up the debug info of one known app/query pair
    # and print it; any failure is printed instead of raised.
    demo_app_id = "ccf92b97-2789-4a3f-90e0-135a869a37c5"
    demo_query = "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?"
    try:
        print(DifyTool.get_message_debug_info_by_query(demo_app_id, demo_query))
    except Exception as e:
        print(f"执行出错: {e}")