531 lines
19 KiB
Python
531 lines
19 KiB
Python
#!/usr/bin/env python
|
||
# -*- coding: utf-8 -*-
|
||
import psycopg2
|
||
import os
|
||
import json
|
||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||
from rag2_0.dify.dify_client import ChatClient
|
||
from pydantic import BaseModel, Field
|
||
from langchain.output_parsers import PydanticOutputParser
|
||
|
||
|
||
class ContentSource(BaseModel):
    # Structured reply expected from the LLM relevance grader; parsed by
    # PydanticOutputParser in BaseWorkflowChat.
    # No class docstring on purpose: pydantic would publish it as the
    # JSON-schema description and change the parser's format instructions.
    # The Field descriptions below are runtime strings and are kept verbatim.

    # Relevance score on the 1-10 scale requested by the grading prompt.
    # Out-of-range values fail validation, so the parse step raises and
    # calculate_score reports -1 instead of letting e.g. 0 or 15 skew the
    # max/min/avg statistics.
    score: int = Field(description="相关性分数", ge=1, le=10)
    # Short justification the LLM gives for the score.
    reason: str = Field(description="评分理由")
class PgSql:
|
||
"""
|
||
用于连接和操作 PostgreSQL 数据库的类。
|
||
|
||
该类封装了数据库连接、关闭连接以及执行特定查询的方法,
|
||
主要用于从 Dify 应用相关的表中获取数据。
|
||
"""
|
||
def __init__(self):
|
||
"""
|
||
初始化 PgSql 实例并建立数据库连接。
|
||
"""
|
||
self.connection = None
|
||
self.connect_sql()
|
||
|
||
def connect_sql(self):
|
||
"""
|
||
连接到 PostgreSQL 数据库。
|
||
|
||
使用预定义的凭据连接到 'dify' 数据库。
|
||
如果连接失败,会抛出异常。
|
||
"""
|
||
try:
|
||
# 连接数据库
|
||
self.connection = psycopg2.connect(
|
||
user="postgres",
|
||
password="difyai123456",
|
||
host="172.20.0.145",
|
||
port=5432,
|
||
database="dify"
|
||
)
|
||
|
||
except (Exception, psycopg2.Error) as error:
|
||
raise Exception(f"Error while connecting to PostgreSQL: {error}")
|
||
|
||
def close_connection(self):
|
||
"""
|
||
关闭当前的 PostgreSQL 数据库连接。
|
||
|
||
如果存在活动的连接,则关闭它。
|
||
"""
|
||
if self.connection:
|
||
self.connection.close()
|
||
|
||
|
||
def get_appinfo(self, appid:str)->dict | None:
|
||
"""
|
||
根据应用 ID 从 'apps' 表中获取应用信息。
|
||
|
||
Args:
|
||
appid: 目标应用的 ID。
|
||
|
||
Returns:
|
||
一个字典,其中键是列名,值是对应的应用数据。
|
||
如果未找到应用或发生错误,则返回 None。
|
||
"""
|
||
try:
|
||
with self.connection.cursor() as cursor:
|
||
cursor.execute(
|
||
"""
|
||
SELECT * FROM apps WHERE id = %s
|
||
""",
|
||
(appid,)
|
||
)
|
||
result = cursor.fetchone()
|
||
if result:
|
||
colnames = [desc[0] for desc in cursor.description]
|
||
return dict(zip(colnames, result))
|
||
return None
|
||
except (Exception, psycopg2.Error) as error:
|
||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||
|
||
|
||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||
"""
|
||
根据应用 ID 和查询内容从 'messages' 表中获取消息信息。
|
||
|
||
Args:
|
||
appid: 目标应用的 ID。
|
||
query: 用户查询的具体内容。
|
||
|
||
Returns:
|
||
一个字典,其中键是列名,值是对应的消息数据。
|
||
如果未找到消息或发生错误,则返回 None。
|
||
"""
|
||
try:
|
||
with self.connection.cursor() as cursor:
|
||
cursor.execute(
|
||
"""
|
||
SELECT * FROM messages WHERE app_id = %s AND query = %s ORDER BY created_at DESC
|
||
""",
|
||
(appid, query)
|
||
)
|
||
result = cursor.fetchone()
|
||
if result:
|
||
colnames = [desc[0] for desc in cursor.description]
|
||
return dict(zip(colnames, result))
|
||
return None
|
||
except (Exception, psycopg2.Error) as error:
|
||
raise Exception(f"Error while getting messages_info: {error}")
|
||
|
||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||
"""
|
||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||
"""
|
||
try:
|
||
with self.connection.cursor() as cursor:
|
||
cursor.execute(
|
||
"""
|
||
SELECT * FROM messages WHERE id = %s
|
||
""",
|
||
(message_id, )
|
||
)
|
||
result = cursor.fetchone()
|
||
if result:
|
||
colnames = [desc[0] for desc in cursor.description]
|
||
return dict(zip(colnames, result))
|
||
return None
|
||
except (Exception, psycopg2.Error) as error:
|
||
raise Exception(f"Error while getting messages_info by id: {error}")
|
||
|
||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||
"""
|
||
根据工作流运行 ID 从 'workflow_node_executions' 表中获取节点执行信息。
|
||
|
||
Args:
|
||
workflow_run_id: 目标工作流运行的 ID。
|
||
|
||
Returns:
|
||
一个字典,其中键是列名,值是对应的节点执行数据。
|
||
如果未找到执行信息或发生错误,则返回 None。
|
||
"""
|
||
try:
|
||
with self.connection.cursor() as cursor:
|
||
cursor.execute(
|
||
"""
|
||
SELECT * FROM workflow_node_executions WHERE workflow_run_id = %s
|
||
""",
|
||
(workflow_run_id,)
|
||
)
|
||
result = cursor.fetchall()
|
||
if result:
|
||
colnames = [desc[0] for desc in cursor.description]
|
||
return [dict(zip(colnames, row)) for row in result]
|
||
return None
|
||
except (Exception, psycopg2.Error) as error:
|
||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||
|
||
class DifyTool:
    """
    Helpers that collect debugging information about Dify app messages.

    Each method opens its own short-lived PgSql connection. It reads the
    message row and the node executions of the workflow run that produced it,
    and always closes the connection before returning.
    """

    @staticmethod
    def _collect_workflow_debug_info(dify_pgsql, messages_info: dict) -> dict | None:
        """Attach the workflow node executions to a message row.

        Returns ``{"messages_info", "workflow_node_executions_info"}``, or None
        when the message's workflow run has no recorded node executions.
        """
        workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(
            messages_info['workflow_run_id']
        )
        if not workflow_node_executions_info:
            return None
        return {
            "messages_info": messages_info,
            "workflow_node_executions_info": workflow_node_executions_info,
        }

    @staticmethod
    def get_message_debug_info_by_id(message_id: str) -> dict | None:
        """
        Collect debug information for one message, looked up by its ID.

        Args:
            message_id: ID of the row in the ``messages`` table.

        Returns:
            A dict with keys ``"messages_info"`` (the message row) and
            ``"workflow_node_executions_info"`` (the node execution rows of the
            message's workflow run). Returns None if the message does not exist
            or its workflow run has no node executions.

        Raises:
            Exception: If connecting fails, or if a query fails. Query errors
                are wrapped with the original error chained as ``__cause__``.
        """
        dify_pgsql = PgSql()
        try:
            messages_info = dify_pgsql.get_messages_info_by_id(message_id)
            if not messages_info:
                return None
            return DifyTool._collect_workflow_debug_info(dify_pgsql, messages_info)
        except Exception as e:
            raise Exception(f"Error in get_message_debug_info_by_id: {e}") from e
        finally:
            dify_pgsql.close_connection()

    @staticmethod
    def get_message_debug_info_by_query(appid: str, query: str) -> dict | None:
        """
        Collect debug information for the newest message of an app with a given query.

        Args:
            appid: ID of the application.
            query: Exact user query text.

        Returns:
            A dict with keys ``"appinfo"``, ``"messages_info"`` and
            ``"workflow_node_executions_info"``. Returns None if the app, the
            message, or the workflow node executions cannot be found.

        Raises:
            Exception: If connecting fails, or if a query fails. Query errors
                are wrapped with the original error chained as ``__cause__``.
        """
        dify_pgsql = PgSql()
        try:
            appinfo = dify_pgsql.get_appinfo(appid)
            if not appinfo:
                return None
            messages_info = dify_pgsql.get_messages_info(appid, query)
            if not messages_info:
                return None
            debug_info = DifyTool._collect_workflow_debug_info(dify_pgsql, messages_info)
            if debug_info is None:
                return None
            # "appinfo" first, to keep the original key order.
            return {"appinfo": appinfo, **debug_info}
        except Exception as e:
            raise Exception(f"Error in get_message_debug_info_by_query: {e}") from e
        finally:
            dify_pgsql.close_connection()
class BaseWorkflowChat:
|
||
"""
|
||
工作流对话基类,封装了与Dify API交互的基本功能
|
||
"""
|
||
def __init__(self, api_key: str, base_url: str):
|
||
"""
|
||
初始化工作流对话基类
|
||
|
||
Args:
|
||
api_key: Dify API的密钥
|
||
base_url: Dify API的基础URL
|
||
"""
|
||
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
|
||
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
||
|
||
def create_chat_message(self, query: str):
|
||
"""
|
||
创建聊天消息
|
||
|
||
Args:
|
||
query: 问题内容
|
||
|
||
Returns:
|
||
tuple: (聊天响应, 消息ID)
|
||
"""
|
||
try:
|
||
response = self.chat_client.create_chat_message(inputs={}, query=query, user="AutoTestDifyChat").json()
|
||
return response, response["message_id"]
|
||
except Exception as e:
|
||
raise e
|
||
|
||
def calculate_score(self, query: str, content: str) -> int:
|
||
"""
|
||
使用LLM判断query与content之间的相关性分数
|
||
|
||
Args:
|
||
query (str): 用户问题
|
||
content (str): 检索内容
|
||
|
||
Returns:
|
||
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
|
||
"""
|
||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||
|
||
try:
|
||
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
|
||
|
||
【评分标准】
|
||
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
|
||
8-9分:高度相关,核心要素匹配但存在少量信息缺失
|
||
6-7分:部分相关,涉及相同主题但存在重要信息缺失
|
||
4-5分:弱相关,仅次要信息点匹配
|
||
1-3分:完全不相关或信息冲突
|
||
|
||
【评估维度】
|
||
1. 主题一致性:核心主题/意图的匹配程度
|
||
2. 内容覆盖度:是否涵盖query的关键要素
|
||
3. 信息准确性:是否存在矛盾/错误信息
|
||
4. 细节丰富度:是否提供query要求的详细信息
|
||
|
||
【输出格式】
|
||
{{
|
||
"score": 评分,
|
||
"reason": "简明扼要的评分理由(中文)"
|
||
}}
|
||
|
||
【示例】
|
||
query: "新冠疫苗的常见副作用"
|
||
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
|
||
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
|
||
|
||
现在评估:
|
||
query: "{query}"
|
||
content: "{content}"
|
||
"""
|
||
api_key = os.getenv("OPENAI_API_KEY")
|
||
base_url = os.getenv("OPENAI_API_BASE")
|
||
model = os.getenv("LLM_MODEL_NAME")
|
||
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
|
||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||
|
||
# 解析JSON响应
|
||
try:
|
||
parsed_output = self.content_source_parser.parse(response.content)
|
||
return parsed_output.score
|
||
except Exception as e:
|
||
return -1
|
||
except Exception as e:
|
||
return -1
|
||
|
||
def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
|
||
"""
|
||
获取检索信息并计算分数
|
||
|
||
Args:
|
||
query (str): 用户问题
|
||
outputs (dict): 检索输出结果
|
||
|
||
Returns:
|
||
tuple: (检索内容列表, 最高分, 最低分, 平均分)
|
||
"""
|
||
max_score = 0
|
||
min_score = 10
|
||
total_score = 0
|
||
valid_scores = 0
|
||
retrieve_content = []
|
||
|
||
# 使用线程池并发计算分数
|
||
with ThreadPoolExecutor() as executor:
|
||
# 创建任务列表
|
||
future_to_content = {}
|
||
for result in outputs["result"]:
|
||
content = result["content"].strip()
|
||
future = executor.submit(self.calculate_score, query=query, content=content)
|
||
future_to_content[future] = content
|
||
|
||
# 收集结果
|
||
for future in as_completed(future_to_content):
|
||
content = future_to_content[future]
|
||
score = future.result()
|
||
content_title = content.split("\n")[0]
|
||
|
||
if score != -1:
|
||
max_score = max(max_score, score)
|
||
min_score = min(min_score, score)
|
||
total_score += score
|
||
valid_scores += 1
|
||
|
||
if content_title:
|
||
retrieve_content.append(content_title + f"--相关性得分({score}分)")
|
||
|
||
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
||
return retrieve_content, max_score, min_score, avg_score
|
||
|
||
|
||
class NewWorkflowChat(BaseWorkflowChat):
    """
    Runs a query through the *new* Dify workflow.

    Extracts the answer, query rewrite, classification, slot-filling result
    and retrieved knowledge (with per-chunk relevance scores).
    """

    def process_question(self, query: str) -> dict | None:
        """
        Ask the new workflow ``query`` and gather its answer and debug details.

        Args:
            query: Question text.

        Returns:
            dict: Keys "问题", "新流程答案", "新问题改写", "新问题分类",
            "槽点信息", "新检索词条", "检索内容" and "message_id".
            Returns None if the workflow's debug information cannot be found.

        Raises:
            RuntimeError: If ``create_chat_message`` returns an error string.
        """
        response, message_id = self.create_chat_message(query)

        if isinstance(response, str) and response.startswith("error:"):
            raise RuntimeError(f"create_chat_message 出错:{response}")

        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None

        return {
            "问题": query,
            "新流程答案": answer,
            "新问题改写": workflow_info["问题改写"],
            "新问题分类": workflow_info["问题分类"],
            "槽点信息": workflow_info["槽点信息"],
            "新检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict | None:
        """
        Extract the rewrite, classification, slots and retrieval results of a run.

        Scans the node executions of the message's workflow run. Two nodes
        are read, matched by title:
        - "知识检索结果后处理": retrieval results, which are also LLM-graded;
        - "问题优化结果解析": query rewrite and classification.

        Args:
            query: User question (used to grade the retrieved chunks).
            message_id: ID of the message produced by the new workflow.

        Returns:
            dict: Keys "问题改写", "检索词条", "检索内容", "问题分类" and
            "槽点信息". Returns None if the message or its node executions
            are not in the database.

        Raises:
            Exception: Database errors, plus JSON/key errors on unexpected
                node payloads, propagate.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""
        vertical_classification = ""
        sub_classification = ""
        slot_info = ""

        message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
        if message_info is None:
            # Previously this crashed with an opaque TypeError on the subscript below.
            return None

        for workflow_node in message_info["workflow_node_executions_info"]:
            node_title = workflow_node["title"]
            if node_title == "知识检索结果后处理":
                outputs = json.loads(workflow_node["outputs"])
                # Only the titles are used; the score statistics are discarded.
                retrieve_title, _, _, _ = self.get_retrieve_info(query=query, outputs=outputs)
                retrieve_content = outputs["result"]
            elif node_title == "问题优化结果解析":
                outputs = json.loads(workflow_node["outputs"])
                rewrite_query = outputs["optimize_query"]
                # The node's "llm_result" input is itself a JSON-encoded string.
                llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
                json_result = json.loads(llm_result_json)
                vertical_classification = json_result['vertical_classification']
                sub_classification = json_result['sub_classification']
                slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
            "问题分类": f"{vertical_classification} - {sub_classification}",
            "槽点信息": slot_info,
        }
class OldWorkFlowChat(BaseWorkflowChat):
    """
    Runs a query through the *old* Dify workflow.

    Extracts the answer, query rewrite and retrieved knowledge (with
    per-chunk relevance scores).
    """

    def process_question(self, query: str) -> dict | None:
        """
        Ask the old workflow ``query`` and gather its answer and debug details.

        Args:
            query: Question text.

        Returns:
            dict: Keys "问题", "旧流程答案", "旧问题改写", "旧检索词条",
            "检索内容" and "message_id". Returns None if
            ``create_chat_message`` returns an error string, or if the debug
            information cannot be collected.
        """
        response, message_id = self.create_chat_message(query)

        if isinstance(response, str) and response.startswith("error:"):
            return None

        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None

        return {
            "问题": query,
            "旧流程答案": answer,
            "旧问题改写": workflow_info["问题改写"],
            "旧检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict | None:
        """
        Extract the query rewrite and retrieval results of an old-workflow run.

        Scans the node executions of the message's workflow run. Two nodes
        are read, matched by title:
        - "知识检索结果后处理": retrieval results, which are also LLM-graded;
        - "问题优化结果解析": query rewrite.

        Args:
            query: User question (used to grade the retrieved chunks).
            message_id: ID of the message produced by the old workflow.

        Returns:
            dict: Keys "问题改写", "检索词条" and "检索内容". Returns None if
            the message cannot be found or any error occurs. This method is
            deliberately best-effort: every exception is swallowed.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""

        try:
            message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
            if message_info is None:
                return None
            for workflow_node in message_info["workflow_node_executions_info"]:
                node_title = workflow_node["title"]
                if node_title == "知识检索结果后处理":
                    outputs = json.loads(workflow_node["outputs"])
                    # Only the titles are used; the score statistics are discarded.
                    retrieve_title, _, _, _ = self.get_retrieve_info(query=query, outputs=outputs)
                    retrieve_content = outputs["result"]
                elif node_title == "问题优化结果解析":
                    outputs = json.loads(workflow_node["outputs"])
                    rewrite_query = outputs["optimize_query"]
        except Exception:
            return None

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
        }
if __name__ == "__main__":
    # Manual smoke test: dump the debug info for one known app/query pair.
    demo_app_id = "ccf92b97-2789-4a3f-90e0-135a869a37c5"
    demo_query = "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?"
    try:
        print(DifyTool.get_message_debug_info_by_query(demo_app_id, demo_query))
    except Exception as exc:
        print(f"执行出错: {exc}")