更新环境配置,添加gevent和gunicorn依赖;新增chat_dify_by_workorder.py文件以处理工单对话逻辑;优化PgSql类中的异常处理,确保连接失败时抛出异常;改进意图识别API,使用单例模式管理意图识别器实例,增强线程安全性;新增workflow_chat.py文件以支持新工作流对话功能。
This commit is contained in:
@@ -0,0 +1,151 @@
|
||||
from rag2_0.dify.workflow_chat import NewWorkflowChat
|
||||
import pandas as pd
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
|
||||
|
||||
class ChatDifyByWorkorder:
|
||||
|
||||
def __init__(self, api_key=None, base_url="https://api.dify.ai/v1") -> None:
|
||||
"""
|
||||
初始化ChatDifyByWorkorder类
|
||||
|
||||
Args:
|
||||
api_key: Dify API密钥,默认为None
|
||||
base_url: Dify API的基础URL,默认为"https://api.dify.ai/v1"
|
||||
"""
|
||||
baseurl = "http://172.20.0.145/v1"
|
||||
new_workflow_api_key = "app-qxsSybCs7ABiKlC1JabTYVn6"
|
||||
self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
self.new_chat_answer = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
|
||||
|
||||
def get_soft_name(self, row) -> str:
|
||||
if "博微配网计价通D3" in row["产品线"]:
|
||||
return "博微配网计价通D3"
|
||||
elif "博微电力建设计价通软件" in row["产品线"]:
|
||||
return "电力建设计价通软件"
|
||||
elif "新能源系列" in row["产品线"] and "博微新型储能电站建设计价通C1软件" in row["产品名称"]:
|
||||
return "储能C1软件"
|
||||
elif "博微西藏计价通Z1" in row["产品线"]:
|
||||
return "西藏计价通Z1"
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-概预算" in row["产品名称"]:
|
||||
return "技改检修工程计价通T1"
|
||||
elif "博微技改检修计价通T1软件" in row["产品线"] and "技改检修计价通T1软件-清单" in row["产品名称"]:
|
||||
return "检修清单计价通T1"
|
||||
return ""
|
||||
|
||||
def process_query(self, q:str) -> dict:
|
||||
"""
|
||||
发送问题并获取回答及相关工作流信息
|
||||
|
||||
Args:
|
||||
q: 用户问题
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和工作流信息的字典
|
||||
"""
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 发送问题获取回答和消息ID
|
||||
result = self.new_chat.process_question(q)
|
||||
return result
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
continue
|
||||
else:
|
||||
raise e
|
||||
|
||||
def process_answer(self, q:str) -> dict:
|
||||
"""
|
||||
发送问题并获取回答及相关工作流信息
|
||||
|
||||
Args:
|
||||
q: 用户问题
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和工作流信息的字典
|
||||
"""
|
||||
retry_count = 0
|
||||
max_retries = 2
|
||||
|
||||
while retry_count <= max_retries:
|
||||
try:
|
||||
# 发送问题获取回答和消息ID
|
||||
result = self.new_chat_answer.process_question(q)
|
||||
return result
|
||||
except Exception as e:
|
||||
retry_count += 1
|
||||
if retry_count <= max_retries:
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
def process_row(self, row):
|
||||
"""处理单行数据"""
|
||||
soft_name = self.get_soft_name(row=row)
|
||||
if soft_name == "":
|
||||
return None
|
||||
|
||||
# 使用线程池并发执行查询
|
||||
with ThreadPoolExecutor() as executor:
|
||||
try:
|
||||
# 提交两个任务并获取Future对象
|
||||
query_future = executor.submit(self.process_query, q=f"{soft_name},{row['客户问题']}")
|
||||
answer_future = executor.submit(self.process_answer, q=f"{soft_name},{row['解决方案']}")
|
||||
|
||||
# 获取结果
|
||||
query_result = query_future.result()
|
||||
answer_result = answer_future.result()
|
||||
except Exception as e:
|
||||
print(f"处理工单 {row.get('工单编号', '未知')} 时发生错误: {str(e)}")
|
||||
return None
|
||||
|
||||
worker_id = str(row["工单编号"])
|
||||
if query_result is None or answer_result is None:
|
||||
print("处理对话出现错误")
|
||||
return None
|
||||
|
||||
worker_order_info = {
|
||||
"工单编号": worker_id,
|
||||
"用户问题": row['客户问题'],
|
||||
"解决方案": row['解决方案'],
|
||||
"AI回答": query_result["新流程答案"],
|
||||
"用户问题检索到的词条": query_result["新检索词条"],
|
||||
"解决方案检索到的词条": answer_result["新检索词条"],
|
||||
}
|
||||
return worker_order_info
|
||||
|
||||
def run(self, excel_path:str):
|
||||
df_data = pd.read_excel(excel_path)
|
||||
list_worker_order_info = []
|
||||
|
||||
# 创建进度条
|
||||
with tqdm(total=len(df_data), desc="处理工单") as pbar:
|
||||
# 创建线程池,最大并发数可以根据需要调整
|
||||
with ThreadPoolExecutor(max_workers=5) as executor:
|
||||
# 提交所有任务
|
||||
future_to_row = {executor.submit(self.process_row, row): idx for idx, row in df_data.iterrows()}
|
||||
|
||||
# 处理完成的任务
|
||||
for future in concurrent.futures.as_completed(future_to_row):
|
||||
result = future.result()
|
||||
if result is not None:
|
||||
list_worker_order_info.append(result)
|
||||
pbar.update(1)
|
||||
|
||||
return list_worker_order_info
|
||||
|
||||
|
||||
|
||||
if __name__=="__main__":
|
||||
worker_chat = ChatDifyByWorkorder()
|
||||
result = worker_chat.run(excel_path="data/excel/工单记录_均衡提取2000条.xlsx")
|
||||
# 可以选择保存结果到Excel
|
||||
if result:
|
||||
pd.DataFrame(result).to_excel("data/excel/工单处理结果.xlsx", index=False)
|
||||
+46
-33
@@ -23,7 +23,7 @@ class PgSql:
|
||||
连接到 PostgreSQL 数据库。
|
||||
|
||||
使用预定义的凭据连接到 'dify' 数据库。
|
||||
如果连接失败,会打印错误信息。
|
||||
如果连接失败,会抛出异常。
|
||||
"""
|
||||
try:
|
||||
# 连接数据库
|
||||
@@ -36,17 +36,16 @@ class PgSql:
|
||||
)
|
||||
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while connecting to PostgreSQL", error)
|
||||
raise Exception(f"Error while connecting to PostgreSQL: {error}")
|
||||
|
||||
def close_connection(self):
|
||||
"""
|
||||
关闭当前的 PostgreSQL 数据库连接。
|
||||
|
||||
如果存在活动的连接,则关闭它并打印确认信息。
|
||||
如果存在活动的连接,则关闭它。
|
||||
"""
|
||||
if self.connection:
|
||||
self.connection.close()
|
||||
print("PostgreSQL connection is closed")
|
||||
|
||||
|
||||
def get_appinfo(self, appid:str)->dict | None:
|
||||
@@ -74,7 +73,7 @@ class PgSql:
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting tenant_id by appid", error)
|
||||
raise Exception(f"Error while getting tenant_id by appid: {error}")
|
||||
|
||||
|
||||
def get_messages_info(self, appid:str, query:str)->dict | None:
|
||||
@@ -103,7 +102,7 @@ class PgSql:
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting messages_info", error)
|
||||
raise Exception(f"Error while getting messages_info: {error}")
|
||||
|
||||
def get_messages_info_by_id(self, message_id:str)->dict | None:
|
||||
"""
|
||||
@@ -123,7 +122,7 @@ class PgSql:
|
||||
return dict(zip(colnames, result))
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting messages_info", error)
|
||||
raise Exception(f"Error while getting messages_info by id: {error}")
|
||||
|
||||
def get_workflow_node_executions_info(self, workflow_run_id:str)->list[dict] | None:
|
||||
"""
|
||||
@@ -150,7 +149,7 @@ class PgSql:
|
||||
return [dict(zip(colnames, row)) for row in result]
|
||||
return None
|
||||
except (Exception, psycopg2.Error) as error:
|
||||
print("Error while getting workflow_node_executions_info", error)
|
||||
raise Exception(f"Error while getting workflow_node_executions_info: {error}")
|
||||
|
||||
class DifyTool:
|
||||
"""
|
||||
@@ -165,16 +164,21 @@ class DifyTool:
|
||||
根据消息 ID 从 'messages' 表中获取消息信息。
|
||||
"""
|
||||
dify_pgsql = PgSql()
|
||||
messages_info = dify_pgsql.get_messages_info_by_id(message_id)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
"messages_info": messages_info,
|
||||
"workflow_node_executions_info": workflow_node_executions_info
|
||||
}
|
||||
try:
|
||||
messages_info = dify_pgsql.get_messages_info_by_id(message_id)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
"messages_info": messages_info,
|
||||
"workflow_node_executions_info": workflow_node_executions_info
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in get_message_debug_info_by_id: {e}")
|
||||
finally:
|
||||
dify_pgsql.close_connection()
|
||||
|
||||
|
||||
@staticmethod
|
||||
@@ -195,21 +199,30 @@ class DifyTool:
|
||||
查询到的应用数据、消息数据和节点执行数据。
|
||||
"""
|
||||
dify_pgsql = PgSql()
|
||||
appinfo = dify_pgsql.get_appinfo(appid)
|
||||
if not appinfo:
|
||||
return None
|
||||
messages_info = dify_pgsql.get_messages_info(appid, query)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
"appinfo": appinfo,
|
||||
"messages_info": messages_info,
|
||||
"workflow_node_executions_info": workflow_node_executions_info
|
||||
}
|
||||
try:
|
||||
appinfo = dify_pgsql.get_appinfo(appid)
|
||||
if not appinfo:
|
||||
return None
|
||||
messages_info = dify_pgsql.get_messages_info(appid, query)
|
||||
if not messages_info:
|
||||
return None
|
||||
workflow_node_executions_info = dify_pgsql.get_workflow_node_executions_info(messages_info['workflow_run_id'])
|
||||
if not workflow_node_executions_info:
|
||||
return None
|
||||
return {
|
||||
"appinfo": appinfo,
|
||||
"messages_info": messages_info,
|
||||
"workflow_node_executions_info": workflow_node_executions_info
|
||||
}
|
||||
except Exception as e:
|
||||
raise Exception(f"Error in get_message_debug_info_by_query: {e}")
|
||||
finally:
|
||||
dify_pgsql.close_connection()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(DifyTool.get_message_debug_info_by_query("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?"))
|
||||
try:
|
||||
result = DifyTool.get_message_debug_info_by_query("ccf92b97-2789-4a3f-90e0-135a869a37c5", "电力建设计价通软件,导入结算后没有暂列金怎么办?要手动添加么?")
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"执行出错: {e}")
|
||||
|
||||
@@ -4,16 +4,32 @@ from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
import json
|
||||
import time
|
||||
import threading
|
||||
import datetime
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
app = Flask(__name__)
|
||||
|
||||
# 初始化意图识别器
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
# 创建线程锁,用于保护共享资源
|
||||
recognizer_lock = threading.Lock()
|
||||
|
||||
# 使用单例模式创建意图识别器
|
||||
class RecognizerSingleton:
|
||||
_instance = None
|
||||
_lock = threading.Lock()
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
if cls._instance is None:
|
||||
with cls._lock:
|
||||
if cls._instance is None:
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
cls._instance = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
return cls._instance
|
||||
|
||||
@app.route('/intent_recognize', methods=['POST'])
|
||||
def intent_recognize():
|
||||
@@ -22,10 +38,16 @@ def intent_recognize():
|
||||
query = data.get('query')
|
||||
if not query:
|
||||
return Response(json.dumps({"error": "缺少query参数"}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=400)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 获取单例实例并使用线程锁保护关键操作
|
||||
recognizer = RecognizerSingleton.get_instance()
|
||||
result = recognizer.process_query_with_slots(query)
|
||||
|
||||
end_time = time.time()
|
||||
print(f"意图识别耗时: {end_time - start_time:.2f}秒")
|
||||
current_time = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S %z")
|
||||
print(f"[{current_time}] [{os.getpid()}] [INFO] 意图识别耗时: {end_time - start_time:.2f}秒")
|
||||
|
||||
# 提取分类信息
|
||||
classification = result["classification"]
|
||||
@@ -63,7 +85,10 @@ def intent_recognize():
|
||||
}
|
||||
return Response(json.dumps(response_result, ensure_ascii=False), content_type='application/json; charset=utf-8')
|
||||
except Exception as e:
|
||||
print(f"意图识别出错: {str(e)}")
|
||||
return Response(json.dumps({"error": str(e)}, ensure_ascii=False), content_type='application/json; charset=utf-8', status=500)
|
||||
|
||||
if __name__ == "__main__":
|
||||
app.run(host="0.0.0.0", port=8001)
|
||||
# 开发环境使用Flask内置服务器
|
||||
# 生产环境使用gunicorn支持高并发 poetry run gunicorn -w 10 -k gevent -b 0.0.0.0:8001 rag2_0.dify.intent_recognition_api:app
|
||||
app.run(host="0.0.0.0", port=8001, threaded=True)
|
||||
@@ -0,0 +1,310 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from rag2_0.dify.dify_client import ChatClient, DifyClient
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from threading import Lock
|
||||
|
||||
class ContentSource(BaseModel):
|
||||
score: int = Field(description="相关性分数")
|
||||
reason: str = Field(description="评分理由")
|
||||
|
||||
class BaseWorkflowChat:
|
||||
"""
|
||||
工作流对话基类,封装了与Dify API交互的基本功能
|
||||
"""
|
||||
def __init__(self, api_key: str, base_url: str):
|
||||
"""
|
||||
初始化工作流对话基类
|
||||
|
||||
Args:
|
||||
api_key: Dify API的密钥
|
||||
base_url: Dify API的基础URL
|
||||
"""
|
||||
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
|
||||
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
||||
|
||||
def create_chat_message(self, query: str):
|
||||
"""
|
||||
创建聊天消息
|
||||
|
||||
Args:
|
||||
query: 问题内容
|
||||
|
||||
Returns:
|
||||
tuple: (聊天响应, 消息ID)
|
||||
"""
|
||||
try:
|
||||
response = self.chat_client.create_chat_message(inputs={}, query=query, user="AutoTestDifyChat").json()
|
||||
return response, response["message_id"]
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
def calculate_score(self, query: str, content: str) -> int:
|
||||
"""
|
||||
使用LLM判断query与content之间的相关性分数
|
||||
|
||||
Args:
|
||||
query (str): 用户问题
|
||||
content (str): 检索内容
|
||||
|
||||
Returns:
|
||||
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
|
||||
"""
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
|
||||
try:
|
||||
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
|
||||
|
||||
【评分标准】
|
||||
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
|
||||
8-9分:高度相关,核心要素匹配但存在少量信息缺失
|
||||
6-7分:部分相关,涉及相同主题但存在重要信息缺失
|
||||
4-5分:弱相关,仅次要信息点匹配
|
||||
1-3分:完全不相关或信息冲突
|
||||
|
||||
【评估维度】
|
||||
1. 主题一致性:核心主题/意图的匹配程度
|
||||
2. 内容覆盖度:是否涵盖query的关键要素
|
||||
3. 信息准确性:是否存在矛盾/错误信息
|
||||
4. 细节丰富度:是否提供query要求的详细信息
|
||||
|
||||
【输出格式】
|
||||
{{
|
||||
"score": 评分,
|
||||
"reason": "简明扼要的评分理由(中文)"
|
||||
}}
|
||||
|
||||
【示例】
|
||||
query: "新冠疫苗的常见副作用"
|
||||
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
|
||||
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
|
||||
|
||||
现在评估:
|
||||
query: "{query}"
|
||||
content: "{content}"
|
||||
"""
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model = os.getenv("LLM_MODEL_NAME")
|
||||
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
|
||||
response = llm.invoke(user_prompt=prompt, need_retry=True)
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
parsed_output = self.content_source_parser.parse(response.content)
|
||||
return parsed_output.score
|
||||
except Exception as e:
|
||||
return -1
|
||||
except Exception as e:
|
||||
return -1
|
||||
|
||||
def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
|
||||
"""
|
||||
获取检索信息并计算分数
|
||||
|
||||
Args:
|
||||
query (str): 用户问题
|
||||
outputs (dict): 检索输出结果
|
||||
|
||||
Returns:
|
||||
tuple: (检索内容列表, 最高分, 最低分, 平均分)
|
||||
"""
|
||||
max_score = 0
|
||||
min_score = 10
|
||||
total_score = 0
|
||||
valid_scores = 0
|
||||
retrieve_content = []
|
||||
|
||||
# 使用线程池并发计算分数
|
||||
with ThreadPoolExecutor() as executor:
|
||||
# 创建任务列表
|
||||
future_to_content = {}
|
||||
for result in outputs["result"]:
|
||||
content = result["content"].strip()
|
||||
future = executor.submit(self.calculate_score, query=query, content=content)
|
||||
future_to_content[future] = content
|
||||
|
||||
# 收集结果
|
||||
for future in as_completed(future_to_content):
|
||||
content = future_to_content[future]
|
||||
score = future.result()
|
||||
content_title = content.split("\n")[0]
|
||||
|
||||
if score != -1:
|
||||
max_score = max(max_score, score)
|
||||
min_score = min(min_score, score)
|
||||
total_score += score
|
||||
valid_scores += 1
|
||||
|
||||
if content_title:
|
||||
retrieve_content.append(content_title + f"--相关性得分({score}分)")
|
||||
|
||||
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
||||
return retrieve_content, max_score, min_score, avg_score
|
||||
|
||||
|
||||
class NewWorkflowChat(BaseWorkflowChat):
|
||||
"""
|
||||
新工作流对话类,用于调用新工作流发送对话并解析获取相关数据
|
||||
"""
|
||||
def process_question(self, query: str) -> dict:
|
||||
"""
|
||||
处理问题,获取新工作流的回答和相关信息
|
||||
|
||||
Args:
|
||||
query: 问题内容
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和相关信息的字典
|
||||
"""
|
||||
response, message_id = self.create_chat_message(query)
|
||||
|
||||
if isinstance(response, str) and response.startswith("error:"):
|
||||
raise RuntimeError(f"create_chat_message 出错:{response}")
|
||||
|
||||
answer = response["answer"]
|
||||
workflow_info = self.get_workflow_info(query, message_id)
|
||||
|
||||
if workflow_info is None:
|
||||
return None
|
||||
|
||||
result = {
|
||||
"问题": query,
|
||||
"新流程答案": answer,
|
||||
"新问题改写": workflow_info["问题改写"],
|
||||
"新问题分类": workflow_info["问题分类"],
|
||||
"槽点信息": workflow_info["槽点信息"],
|
||||
"新检索词条": workflow_info["检索词条"],
|
||||
"检索内容": workflow_info["检索内容"],
|
||||
"message_id":message_id
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_workflow_info(self, query: str, message_id: str) -> dict:
|
||||
"""
|
||||
获取新工作流的问题分类和检索信息
|
||||
|
||||
Args:
|
||||
query (str): 用户问题
|
||||
message_id (str): 新工作流的消息ID
|
||||
|
||||
Returns:
|
||||
dict: 包含问题分类结果的字典
|
||||
"""
|
||||
retrieve_title = []
|
||||
retrieve_content = []
|
||||
max_score = 0
|
||||
min_score = 0
|
||||
avg_score = 0
|
||||
rewrite_query = ""
|
||||
vertical_classification = ""
|
||||
sub_classification = ""
|
||||
slot_info = ""
|
||||
|
||||
try:
|
||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
retrieve_content = outputs["result"]
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
|
||||
json_result = json.loads(llm_result_json)
|
||||
vertical_classification = json_result['vertical_classification']
|
||||
sub_classification = json_result['sub_classification']
|
||||
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
return {
|
||||
"问题改写": rewrite_query,
|
||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||
"检索内容": retrieve_content,
|
||||
"问题分类": f"{vertical_classification} - {sub_classification}",
|
||||
"槽点信息": slot_info,
|
||||
|
||||
}
|
||||
|
||||
|
||||
class OldWorkFlowChat(BaseWorkflowChat):
|
||||
"""
|
||||
旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据
|
||||
"""
|
||||
def process_question(self, query: str) -> dict:
|
||||
"""
|
||||
处理问题,获取旧工作流的回答和相关信息
|
||||
|
||||
Args:
|
||||
query: 问题内容
|
||||
|
||||
Returns:
|
||||
dict: 包含问题、回答和相关信息的字典
|
||||
"""
|
||||
response, message_id = self.create_chat_message(query)
|
||||
|
||||
if isinstance(response, str) and response.startswith("error:"):
|
||||
return None
|
||||
|
||||
answer = response["answer"]
|
||||
workflow_info = self.get_workflow_info(query, message_id)
|
||||
|
||||
if workflow_info is None:
|
||||
return None
|
||||
|
||||
result = {
|
||||
"问题": query,
|
||||
"旧流程答案": answer,
|
||||
"旧问题改写": workflow_info["问题改写"],
|
||||
"旧检索词条": workflow_info["检索词条"],
|
||||
"检索内容": workflow_info["检索内容"],
|
||||
"message_id":message_id
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
def get_workflow_info(self, query: str, message_id: str) -> dict:
|
||||
"""
|
||||
获取旧工作流的问题改写和检索信息
|
||||
|
||||
Args:
|
||||
query (str): 用户问题
|
||||
message_id (str): 旧工作流的消息ID
|
||||
|
||||
Returns:
|
||||
dict: 包含问题改写和检索信息的字典
|
||||
"""
|
||||
retrieve_title = []
|
||||
retrieve_content = []
|
||||
max_score = 0
|
||||
min_score = 0
|
||||
avg_score = 0
|
||||
rewrite_query = ""
|
||||
|
||||
try:
|
||||
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
retrieve_content = outputs["result"]
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
return {
|
||||
"问题改写": rewrite_query,
|
||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||
"检索内容": retrieve_content,
|
||||
}
|
||||
Reference in New Issue
Block a user