新增对话处理功能,优化意图识别逻辑,添加结果保存至Excel的功能,更新依赖项以支持新的数据库驱动和ORM,重构代码以提高可读性和维护性,删除冗余文件以简化项目结构。
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from rag2_0.dify.workflow_chat import NewWorkflowChat
|
||||
from rag2_0.dify.dify_tool import NewWorkflowChat
|
||||
import pandas as pd
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from tqdm import tqdm
|
||||
|
||||
+304
-2
@@ -1,8 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
import psycopg2
|
||||
from psycopg2 import sql
|
||||
import os
|
||||
import json
|
||||
from datetime import timezone, timedelta
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from rag2_0.dify.dify_client import ChatClient
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
|
||||
|
||||
class ContentSource(BaseModel):
    """Schema with a relevance score and the reasoning behind it.

    Used in this file as the ``pydantic_object`` of a ``PydanticOutputParser``.
    """

    # Relevance score (integer).
    score: int = Field(..., description="相关性分数")
    # Reason given for the score.
    reason: str = Field(..., description="评分理由")
|
||||
|
||||
class PgSql:
|
||||
"""
|
||||
@@ -219,6 +228,299 @@ class DifyTool:
|
||||
finally:
|
||||
dify_pgsql.close_connection()
|
||||
|
||||
class BaseWorkflowChat:
    """Base class for workflow chats wrapping the basic Dify API interaction.

    Holds the chat client and the output parser used for relevance scoring,
    and provides the helpers shared by the concrete workflow chat classes.
    """

    def __init__(self, api_key: str, base_url: str):
        """Create the chat client and the parser for the scoring output.

        Args:
            api_key: Dify API key.
            base_url: Base URL of the Dify API.
        """
        self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
        self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)

    def create_chat_message(self, query: str):
        """Send ``query`` through the chat client and return the reply.

        Args:
            query: Question text.

        Returns:
            tuple: ``(response, message_id)`` where ``response`` is the result
            of calling ``.json()`` on the client's reply and ``message_id`` is
            ``response["message_id"]``.

        Any exception raised by the client call, by ``.json()`` or by the
        ``"message_id"`` lookup propagates to the caller unchanged.
        """
        response = self.chat_client.create_chat_message(
            inputs={}, query=query, user="AutoTestDifyChat"
        ).json()
        return response, response["message_id"]

    def calculate_score(self, query: str, content: str) -> int:
        """Ask an LLM to rate how relevant ``content`` is to ``query``.

        Args:
            query: User question.
            content: Retrieved content to rate.

        Returns:
            int: The ``score`` field parsed from the LLM reply (the prompt asks
            for 1-10, 10 = fully relevant), or ``-1`` when the LLM call or the
            parsing of its reply fails.
        """
        import logging

        # Imported lazily, as in the original; an import failure is not
        # converted into -1 and propagates to the caller.
        from rag2_0.tool.ModelTool import OpenAiLLM

        try:
            prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。

【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突

【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息

【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}

【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}

现在评估:
query: "{query}"
content: "{content}"
"""
            llm = OpenAiLLM(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_API_BASE"),
                model=os.getenv("LLM_MODEL_NAME"),
            )
            response = llm.invoke(user_prompt=prompt, need_retry=True)
            # Parse the JSON reply into a ContentSource and keep only the score.
            return self.content_source_parser.parse(response.content).score
        except Exception:
            # Scoring is best-effort: keep the -1 sentinel, but leave a trace
            # instead of swallowing the failure silently.
            logging.getLogger(__name__).exception("Relevance scoring failed")
            return -1

    def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
        """Score every retrieved item concurrently and summarise the scores.

        Args:
            query: User question.
            outputs: Retrieval output; ``outputs["result"]`` must be an
                iterable of mappings that each carry a ``"content"`` string.

        Returns:
            tuple: ``(titles, max_score, min_score, avg_score)``.
            ``titles`` holds, in the same order as ``outputs["result"]``, the
            first line of each non-empty content followed by its score (failed
            scorings show ``-1``). The statistics only consider scores other
            than ``-1`` and are all ``0`` when there is no such score.
        """
        contents = [item["content"].strip() for item in outputs["result"]]
        if not contents:
            return [], 0, 0, 0

        # Score concurrently, but read the results back in submission order so
        # the returned titles keep the retrieval ranking.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.calculate_score, query=query, content=content)
                for content in contents
            ]
            scores = [future.result() for future in futures]

        retrieve_content = []
        for content, score in zip(contents, scores):
            content_title = content.split("\n")[0]
            if content_title:
                retrieve_content.append(content_title + f"--相关性得分({score}分)")

        valid_scores = [score for score in scores if score != -1]
        if not valid_scores:
            return retrieve_content, 0, 0, 0

        avg_score = sum(valid_scores) / len(valid_scores)
        return retrieve_content, max(valid_scores), min(valid_scores), avg_score
|
||||
|
||||
|
||||
class NewWorkflowChat(BaseWorkflowChat):
    """Chat class for the new workflow.

    Sends a question to the new workflow and parses the rewrite,
    classification, slot and retrieval data out of its node executions.
    """

    def process_question(self, query: str) -> dict:
        """Ask the new workflow ``query`` and gather its answer and metadata.

        Args:
            query: Question text.

        Returns:
            dict: The question, the answer, the rewritten question, the
            classification, the slot info, the retrieved titles, the raw
            retrieval results and the ``message_id``.

        Exceptions from ``create_chat_message`` or ``get_workflow_info``
        propagate to the caller.
        """
        # create_chat_message already indexes the response by "message_id",
        # so a string "error:" response can never reach this point.
        response, message_id = self.create_chat_message(query)
        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)

        return {
            "问题": query,
            "新流程答案": answer,
            "新问题改写": workflow_info["问题改写"],
            "新问题分类": workflow_info["问题分类"],
            "槽点信息": workflow_info["槽点信息"],
            "新检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict:
        """Extract classification and retrieval data for one message.

        Walks the ``workflow_node_executions_info`` entries returned by
        ``DifyTool.get_message_debug_info_by_id`` and reads two nodes by
        title: the retrieval post-processing node and the query-optimisation
        parsing node. A node that is absent leaves its fields at their empty
        defaults.

        Args:
            query: User question, used to score the retrieved contents.
            message_id: Message id of the new workflow run.

        Returns:
            dict: Rewritten question, retrieved titles (or a placeholder when
            nothing was retrieved), raw retrieval results, classification as
            ``"<vertical> - <sub>"`` and the slot info as a JSON string.

        Any lookup or JSON decoding error propagates to the caller.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""
        vertical_classification = ""
        sub_classification = ""
        slot_info = ""

        message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
        for workflow_node in message_info["workflow_node_executions_info"]:
            title = workflow_node["title"]
            if title == "知识检索结果后处理":
                outputs = json.loads(workflow_node["outputs"])
                # Only the scored titles are used here; the score statistics
                # returned alongside them are discarded.
                retrieve_title, _max, _min, _avg = self.get_retrieve_info(query=query, outputs=outputs)
                retrieve_content = outputs["result"]
            elif title == "问题优化结果解析":
                outputs = json.loads(workflow_node["outputs"])
                rewrite_query = outputs["optimize_query"]
                # The node input carries the LLM result as a JSON string
                # nested inside the JSON-encoded inputs.
                llm_result_json = json.loads(workflow_node["inputs"])["llm_result"]
                json_result = json.loads(llm_result_json)
                vertical_classification = json_result["vertical_classification"]
                sub_classification = json_result["sub_classification"]
                slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
            "问题分类": f"{vertical_classification} - {sub_classification}",
            "槽点信息": slot_info,
        }
|
||||
|
||||
class OldWorkFlowChat(BaseWorkflowChat):
    """Chat class for the old workflow.

    Sends a question to the old workflow and parses the rewrite and
    retrieval data out of its node executions.
    """

    def process_question(self, query: str) -> "dict | None":
        """Ask the old workflow ``query`` and gather its answer and metadata.

        Args:
            query: Question text.

        Returns:
            dict | None: The question, the answer, the rewritten question, the
            retrieved titles, the raw retrieval results and the
            ``message_id``; ``None`` when the workflow info could not be read.

        Exceptions from ``create_chat_message`` propagate to the caller.
        """
        # create_chat_message already indexes the response by "message_id",
        # so a string "error:" response can never reach this point.
        response, message_id = self.create_chat_message(query)
        answer = response["answer"]

        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None

        return {
            "问题": query,
            "旧流程答案": answer,
            "旧问题改写": workflow_info["问题改写"],
            "旧检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> "dict | None":
        """Extract the rewritten question and retrieval data for one message.

        Walks the ``workflow_node_executions_info`` entries returned by
        ``DifyTool.get_message_debug_info_by_id`` and reads two nodes by
        title: the retrieval post-processing node and the query-optimisation
        parsing node. A node that is absent leaves its fields at their empty
        defaults.

        Args:
            query: User question, used to score the retrieved contents.
            message_id: Message id of the old workflow run.

        Returns:
            dict | None: Rewritten question, retrieved titles (or a
            placeholder when nothing was retrieved) and raw retrieval
            results; ``None`` when fetching or parsing fails.
        """
        import logging

        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""

        try:
            message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
            for workflow_node in message_info["workflow_node_executions_info"]:
                title = workflow_node["title"]
                if title == "知识检索结果后处理":
                    outputs = json.loads(workflow_node["outputs"])
                    # Only the scored titles are used here; the score
                    # statistics returned alongside them are discarded.
                    retrieve_title, _max, _min, _avg = self.get_retrieve_info(query=query, outputs=outputs)
                    retrieve_content = outputs["result"]
                elif title == "问题优化结果解析":
                    outputs = json.loads(workflow_node["outputs"])
                    rewrite_query = outputs["optimize_query"]
        except Exception:
            # Callers rely on None to signal failure; keep that contract but
            # record why the lookup failed instead of hiding it.
            logging.getLogger(__name__).exception(
                "Failed to read old workflow info for message %s", message_id
            )
            return None

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
        }
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from rag2_0.dify.dify_client import ChatClient, DifyClient
|
||||
from rag2_0.dify.dify_client import DifyClient
|
||||
from rag2_0.dify.dify_tool import NewWorkflowChat, OldWorkFlowChat
|
||||
import pandas as pd
|
||||
# 使用线程池并发执行
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
@@ -44,8 +45,9 @@ class DifyComparisonTester:
|
||||
max_workers: 最大工作线程数
|
||||
"""
|
||||
self.excel_path = excel_path
|
||||
self.old_chat = ChatClient(api_key=old_workflow_api_key, base_url=baseurl)
|
||||
self.new_chat = ChatClient(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
# 使用NewWorkflowChat和OldWorkFlowChat代替ChatClient
|
||||
self.old_chat = OldWorkFlowChat(api_key=old_workflow_api_key, base_url=baseurl)
|
||||
self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
|
||||
|
||||
# 评判相关参数
|
||||
self.output_path = output_path or os.path.join(os.path.dirname(self.excel_path), "dify问答_综合评判结果.xlsx")
|
||||
@@ -78,13 +80,13 @@ class DifyComparisonTester:
|
||||
"""
|
||||
def get_old_answer():
|
||||
try:
|
||||
return self.old_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
|
||||
return self.old_chat.process_question(query=q)
|
||||
except Exception as e:
|
||||
return f"error: {str(e)}"
|
||||
|
||||
def get_new_answer():
|
||||
try:
|
||||
return self.new_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
|
||||
return self.new_chat.process_question(query=q)
|
||||
except Exception as e:
|
||||
return f"error: {str(e)}"
|
||||
|
||||
@@ -95,14 +97,15 @@ class DifyComparisonTester:
|
||||
try:
|
||||
old_result = future_old.result()
|
||||
new_result = future_new.result()
|
||||
old_message_id = old_result["message_id"]
|
||||
new_message_id = new_result["message_id"]
|
||||
|
||||
if isinstance(old_result, str) and old_result.startswith("error:"):
|
||||
return None, None
|
||||
if isinstance(new_result, str) and new_result.startswith("error:"):
|
||||
return None, None
|
||||
|
||||
old_answer = old_result["answer"]
|
||||
new_answer = new_result["answer"]
|
||||
except Exception as e:
|
||||
return None, None, None
|
||||
return {"问题": q, "旧流程答案": old_answer, "新流程答案": new_answer}, old_message_id, new_message_id
|
||||
return future_old, future_new
|
||||
|
||||
def find_wiki_link(self, query) -> str | None:
|
||||
"""
|
||||
@@ -407,22 +410,24 @@ content: "{content}"
|
||||
Returns:
|
||||
dict: 包含问题分类结果的字典
|
||||
"""
|
||||
retrieve_title=[]
|
||||
retrieve_content=[]
|
||||
max_score=0
|
||||
min_score=0
|
||||
avg_score=0
|
||||
rewrite_query=""
|
||||
vertical_classification=""
|
||||
sub_classification=""
|
||||
slot_info=""
|
||||
try:
|
||||
# 使用DifyTool直接获取消息信息
|
||||
new_message_info = DifyTool.get_message_debug_info_by_id(message_id=new_message_id)
|
||||
|
||||
# 初始化变量
|
||||
retrieve_title = []
|
||||
retrieve_content = []
|
||||
rewrite_query = ""
|
||||
vertical_classification = ""
|
||||
sub_classification = ""
|
||||
slot_info = ""
|
||||
|
||||
# 解析工作流节点信息
|
||||
for workflow_node in new_message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
retrieve_content=outputs["result"]
|
||||
retrieve_content = outputs["result"]
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
@@ -430,20 +435,21 @@ content: "{content}"
|
||||
json_result = json.loads(llm_result_json)
|
||||
vertical_classification = json_result['vertical_classification']
|
||||
sub_classification = json_result['sub_classification']
|
||||
slot_info=json.dumps(json_result["slot_filling"],ensure_ascii=False,indent=2)
|
||||
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
return {
|
||||
"问题改写": rewrite_query,
|
||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||
"检索内容": retrieve_content,
|
||||
"问题分类": f"{vertical_classification} - {sub_classification}",
|
||||
"槽点信息":slot_info
|
||||
"槽点信息": slot_info
|
||||
}
|
||||
|
||||
def get_old_workflow_info(self, query:str, old_message_id:str) -> dict:
|
||||
"""
|
||||
获取新流程的问题分类
|
||||
获取旧流程的问题分类
|
||||
|
||||
Args:
|
||||
query (str): 用户问题
|
||||
@@ -452,24 +458,27 @@ content: "{content}"
|
||||
Returns:
|
||||
dict: 包含问题分类结果的字典
|
||||
"""
|
||||
retrieve_title=[]
|
||||
retrieve_content=[]
|
||||
max_score=0
|
||||
min_score=0
|
||||
avg_score=0
|
||||
rewrite_query=""
|
||||
try:
|
||||
# 使用DifyTool直接获取消息信息
|
||||
old_message_info = DifyTool.get_message_debug_info_by_id(message_id=old_message_id)
|
||||
|
||||
# 初始化变量
|
||||
retrieve_title = []
|
||||
retrieve_content = []
|
||||
rewrite_query = ""
|
||||
|
||||
# 解析工作流节点信息
|
||||
for workflow_node in old_message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
retrieve_content=outputs["result"]
|
||||
retrieve_content = outputs["result"]
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
return {
|
||||
"问题改写": rewrite_query,
|
||||
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
|
||||
@@ -512,13 +521,13 @@ content: "{content}"
|
||||
dict: 包含问题、回答和评判结果的字典
|
||||
"""
|
||||
# 获取基本的问题和回答
|
||||
basic_result, old_message_id, new_message_id = self.process_question(q)
|
||||
if basic_result is None:
|
||||
future_old, future_new = self.process_question(q)
|
||||
if future_old is None or future_new is None:
|
||||
return None
|
||||
|
||||
query = basic_result["问题"]
|
||||
old_answer = basic_result["旧流程答案"]
|
||||
new_answer = basic_result["新流程答案"]
|
||||
query = future_old["问题"]
|
||||
old_answer = future_old["旧流程答案"]
|
||||
new_answer = future_new["新流程答案"]
|
||||
|
||||
# 获取词条链接和标准答案
|
||||
wiki_url = self.find_wiki_link(query)
|
||||
@@ -540,33 +549,23 @@ content: "{content}"
|
||||
|
||||
if judge_result is None:
|
||||
judge_result = ""
|
||||
|
||||
# retrieve_title_score = self.get_retrieve_title_similarity(old_retrieve_content=old_workflow_info["检索内容"], new_retrieve_content=new_workflow_info["检索内容"])
|
||||
|
||||
# 并行获取新旧流程信息
|
||||
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||
future_new = executor.submit(self.get_new_workflow_info, query=query, new_message_id=new_message_id)
|
||||
future_old = executor.submit(self.get_old_workflow_info, query=query, old_message_id=old_message_id)
|
||||
|
||||
try:
|
||||
new_workflow_info = future_new.result()
|
||||
old_workflow_info = future_old.result()
|
||||
except Exception as e:
|
||||
print(f"处理问题 '{query}' 获取工作流信息时发生错误: {str(e)}")
|
||||
return None
|
||||
retrieve_title_score=self.get_retrieve_title_similarity(old_retrieve_content=old_workflow_info["检索内容"], new_retrieve_content=new_workflow_info["检索内容"])
|
||||
# 返回结果
|
||||
return {
|
||||
"问题": query,
|
||||
"新问题改写": new_workflow_info["问题改写"],
|
||||
"旧问题改写": old_workflow_info["问题改写"],
|
||||
"新问题分类": new_workflow_info["问题分类"],
|
||||
"槽点信息":new_workflow_info["槽点信息"],
|
||||
"新问题改写": future_new["问题改写"],
|
||||
"旧问题改写": future_old["问题改写"],
|
||||
"新问题分类": future_new["问题分类"],
|
||||
"槽点信息": future_new["槽点信息"],
|
||||
"新流程答案": new_answer,
|
||||
"旧流程答案": old_answer,
|
||||
"回答判断": judge_result,
|
||||
"词条检索相似度": retrieve_title_score,
|
||||
# "词条检索相似度": retrieve_title_score,
|
||||
"答案词条": answer_title if answer_title else "",
|
||||
"新检索词条": new_workflow_info["检索词条"],
|
||||
"旧检索词条": old_workflow_info["检索词条"],
|
||||
"新检索词条": future_new["检索词条"],
|
||||
"旧检索词条": future_old["检索词条"],
|
||||
}
|
||||
|
||||
def run_comparison(self, with_judge=False):
|
||||
@@ -670,5 +669,7 @@ if __name__ == "__main__":
|
||||
print(f"对比评判结果已保存至: {output_file}")
|
||||
|
||||
# 单个问题测试示例
|
||||
# c = DifyChat(baseurl="http://172.20.0.145/v1", api_key="app-LjJaeLoAfqa6aoGzqU9UvxSf")
|
||||
# c.chat("如何新建配电线路工程")
|
||||
# 使用新的工作流类进行测试
|
||||
# new_chat = NewWorkflowChat(api_key="app-qxsSybCs7ABiKlC1JabTYVn6", base_url="http://172.20.0.145/v1")
|
||||
# result = new_chat.process_question("如何新建配电线路工程")
|
||||
# print(json.dumps(result, ensure_ascii=False, indent=2))
|
||||
|
||||
@@ -1,310 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from rag2_0.dify.dify_client import ChatClient, DifyClient
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from threading import Lock
|
||||
|
||||
class ContentSource(BaseModel):
    """Schema with a relevance score and the reasoning behind it.

    Used in this file as the ``pydantic_object`` of a ``PydanticOutputParser``.
    """

    # Relevance score (integer).
    score: int = Field(..., description="相关性分数")
    # Reason given for the score.
    reason: str = Field(..., description="评分理由")
|
||||
|
||||
class BaseWorkflowChat:
    """Base class for workflow chats wrapping the basic Dify API interaction.

    Holds the chat client and the output parser used for relevance scoring,
    and provides the helpers shared by the concrete workflow chat classes.
    """

    def __init__(self, api_key: str, base_url: str):
        """Create the chat client and the parser for the scoring output.

        Args:
            api_key: Dify API key.
            base_url: Base URL of the Dify API.
        """
        self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
        self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)

    def create_chat_message(self, query: str):
        """Send ``query`` through the chat client and return the reply.

        Args:
            query: Question text.

        Returns:
            tuple: ``(response, message_id)`` where ``response`` is the result
            of calling ``.json()`` on the client's reply and ``message_id`` is
            ``response["message_id"]``.

        Any exception raised by the client call, by ``.json()`` or by the
        ``"message_id"`` lookup propagates to the caller unchanged.
        """
        response = self.chat_client.create_chat_message(
            inputs={}, query=query, user="AutoTestDifyChat"
        ).json()
        return response, response["message_id"]

    def calculate_score(self, query: str, content: str) -> int:
        """Ask an LLM to rate how relevant ``content`` is to ``query``.

        Args:
            query: User question.
            content: Retrieved content to rate.

        Returns:
            int: The ``score`` field parsed from the LLM reply (the prompt asks
            for 1-10, 10 = fully relevant), or ``-1`` when the LLM call or the
            parsing of its reply fails.
        """
        import logging

        # Imported lazily, as in the original; an import failure is not
        # converted into -1 and propagates to the caller.
        from rag2_0.tool.ModelTool import OpenAiLLM

        try:
            prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。

【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突

【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息

【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}

【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}

现在评估:
query: "{query}"
content: "{content}"
"""
            llm = OpenAiLLM(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_API_BASE"),
                model=os.getenv("LLM_MODEL_NAME"),
            )
            response = llm.invoke(user_prompt=prompt, need_retry=True)
            # Parse the JSON reply into a ContentSource and keep only the score.
            return self.content_source_parser.parse(response.content).score
        except Exception:
            # Scoring is best-effort: keep the -1 sentinel, but leave a trace
            # instead of swallowing the failure silently.
            logging.getLogger(__name__).exception("Relevance scoring failed")
            return -1

    def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
        """Score every retrieved item concurrently and summarise the scores.

        Args:
            query: User question.
            outputs: Retrieval output; ``outputs["result"]`` must be an
                iterable of mappings that each carry a ``"content"`` string.

        Returns:
            tuple: ``(titles, max_score, min_score, avg_score)``.
            ``titles`` holds, in the same order as ``outputs["result"]``, the
            first line of each non-empty content followed by its score (failed
            scorings show ``-1``). The statistics only consider scores other
            than ``-1`` and are all ``0`` when there is no such score.
        """
        contents = [item["content"].strip() for item in outputs["result"]]
        if not contents:
            return [], 0, 0, 0

        # Score concurrently, but read the results back in submission order so
        # the returned titles keep the retrieval ranking.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.calculate_score, query=query, content=content)
                for content in contents
            ]
            scores = [future.result() for future in futures]

        retrieve_content = []
        for content, score in zip(contents, scores):
            content_title = content.split("\n")[0]
            if content_title:
                retrieve_content.append(content_title + f"--相关性得分({score}分)")

        valid_scores = [score for score in scores if score != -1]
        if not valid_scores:
            return retrieve_content, 0, 0, 0

        avg_score = sum(valid_scores) / len(valid_scores)
        return retrieve_content, max(valid_scores), min(valid_scores), avg_score
|
||||
|
||||
|
||||
class NewWorkflowChat(BaseWorkflowChat):
    """Chat class for the new workflow.

    Sends a question to the new workflow and parses the rewrite,
    classification, slot and retrieval data out of its node executions.
    """

    def process_question(self, query: str) -> dict:
        """Ask the new workflow ``query`` and gather its answer and metadata.

        Args:
            query: Question text.

        Returns:
            dict: The question, the answer, the rewritten question, the
            classification, the slot info, the retrieved titles, the raw
            retrieval results and the ``message_id``.

        Exceptions from ``create_chat_message`` or ``get_workflow_info``
        propagate to the caller.
        """
        # create_chat_message already indexes the response by "message_id",
        # so a string "error:" response can never reach this point.
        response, message_id = self.create_chat_message(query)
        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)

        return {
            "问题": query,
            "新流程答案": answer,
            "新问题改写": workflow_info["问题改写"],
            "新问题分类": workflow_info["问题分类"],
            "槽点信息": workflow_info["槽点信息"],
            "新检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict:
        """Extract classification and retrieval data for one message.

        Walks the ``workflow_node_executions_info`` entries returned by
        ``DifyTool.get_message_debug_info_by_id`` and reads two nodes by
        title: the retrieval post-processing node and the query-optimisation
        parsing node. A node that is absent leaves its fields at their empty
        defaults.

        Args:
            query: User question, used to score the retrieved contents.
            message_id: Message id of the new workflow run.

        Returns:
            dict: Rewritten question, retrieved titles (or a placeholder when
            nothing was retrieved), raw retrieval results, classification as
            ``"<vertical> - <sub>"`` and the slot info as a JSON string.

        Any lookup or JSON decoding error propagates to the caller.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""
        vertical_classification = ""
        sub_classification = ""
        slot_info = ""

        message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
        for workflow_node in message_info["workflow_node_executions_info"]:
            title = workflow_node["title"]
            if title == "知识检索结果后处理":
                outputs = json.loads(workflow_node["outputs"])
                # Only the scored titles are used here; the score statistics
                # returned alongside them are discarded.
                retrieve_title, _max, _min, _avg = self.get_retrieve_info(query=query, outputs=outputs)
                retrieve_content = outputs["result"]
            elif title == "问题优化结果解析":
                outputs = json.loads(workflow_node["outputs"])
                rewrite_query = outputs["optimize_query"]
                # The node input carries the LLM result as a JSON string
                # nested inside the JSON-encoded inputs.
                llm_result_json = json.loads(workflow_node["inputs"])["llm_result"]
                json_result = json.loads(llm_result_json)
                vertical_classification = json_result["vertical_classification"]
                sub_classification = json_result["sub_classification"]
                slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
            "问题分类": f"{vertical_classification} - {sub_classification}",
            "槽点信息": slot_info,
        }
|
||||
|
||||
|
||||
class OldWorkFlowChat(BaseWorkflowChat):
    """Chat class for the old workflow.

    Sends a question to the old workflow and parses the rewrite and
    retrieval data out of its node executions.
    """

    def process_question(self, query: str) -> "dict | None":
        """Ask the old workflow ``query`` and gather its answer and metadata.

        Args:
            query: Question text.

        Returns:
            dict | None: The question, the answer, the rewritten question, the
            retrieved titles, the raw retrieval results and the
            ``message_id``; ``None`` when the workflow info could not be read.

        Exceptions from ``create_chat_message`` propagate to the caller.
        """
        # create_chat_message already indexes the response by "message_id",
        # so a string "error:" response can never reach this point.
        response, message_id = self.create_chat_message(query)
        answer = response["answer"]

        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None

        return {
            "问题": query,
            "旧流程答案": answer,
            "旧问题改写": workflow_info["问题改写"],
            "旧检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> "dict | None":
        """Extract the rewritten question and retrieval data for one message.

        Walks the ``workflow_node_executions_info`` entries returned by
        ``DifyTool.get_message_debug_info_by_id`` and reads two nodes by
        title: the retrieval post-processing node and the query-optimisation
        parsing node. A node that is absent leaves its fields at their empty
        defaults.

        Args:
            query: User question, used to score the retrieved contents.
            message_id: Message id of the old workflow run.

        Returns:
            dict | None: Rewritten question, retrieved titles (or a
            placeholder when nothing was retrieved) and raw retrieval
            results; ``None`` when fetching or parsing fails.
        """
        import logging

        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""

        try:
            message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
            for workflow_node in message_info["workflow_node_executions_info"]:
                title = workflow_node["title"]
                if title == "知识检索结果后处理":
                    outputs = json.loads(workflow_node["outputs"])
                    # Only the scored titles are used here; the score
                    # statistics returned alongside them are discarded.
                    retrieve_title, _max, _min, _avg = self.get_retrieve_info(query=query, outputs=outputs)
                    retrieve_content = outputs["result"]
                elif title == "问题优化结果解析":
                    outputs = json.loads(workflow_node["outputs"])
                    rewrite_query = outputs["optimize_query"]
        except Exception:
            # Callers rely on None to signal failure; keep that contract but
            # record why the lookup failed instead of hiding it.
            logging.getLogger(__name__).exception(
                "Failed to read old workflow info for message %s", message_id
            )
            return None

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
            "检索内容": retrieve_content,
        }
|
||||
Reference in New Issue
Block a user