新增对话处理功能,优化意图识别逻辑,添加结果保存至Excel的功能,更新依赖项以支持新的数据库驱动和ORM,重构代码以提高可读性和维护性,删除冗余文件以简化项目结构。

This commit is contained in:
2025-06-10 17:00:40 +08:00
parent ea705756fa
commit 2e91063ad1
11 changed files with 653 additions and 493 deletions
+75 -34
View File
@@ -118,10 +118,65 @@ def process_query(recognizer, query):
}
else:
# 可以在这里添加延迟,避免过快重试
time.sleep(10 * retry_count)
time.sleep(10)
def save_results_to_excel(results, output_file, is_final=False):
"""
将结果保存到Excel文件
Args:
results: 结果列表
output_file: 输出文件路径
is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记
Returns:
None
"""
# 过滤掉None值
valid_results = [r for r in results if r is not None]
if not valid_results:
logging.warning("没有有效结果可保存")
return
# 创建DataFrame
results_df = pd.DataFrame(valid_results)
# 根据是否为最终保存确定文件名
if not is_final:
file_name, file_ext = os.path.splitext(output_file)
temp_output_file = f"{file_name}_temp{file_ext}"
else:
temp_output_file = output_file
# 使用ExcelWriter设置格式
with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer:
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作簿和工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 设置列宽(单位:像素)
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
# 设置所有行高为20磅
for i in range(len(results_df) + 1): # +1 是为了包括表头
worksheet.set_row(i, 20)
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
# 示例查询
examples_query = """储能软件组合件界面,点击隐藏空项目划分后界面没有任何变化"""
examples_query = """"锁标签号:811621005858, 注册单位:惠州电力勘察设计院有限公司,软件名称:广东迁改导则2022, 注册号:BW278-83834-58155-58339.迁改导则是要另外下载安装软件吗?"
"""
def main():
"""
@@ -138,10 +193,10 @@ def main():
# 读取提问数据
current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "400条提问意图分类数据-原始.xlsx")
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "测试提问数据_槽位填充结果.xlsx")
# 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
if is_debug:
examples = examples_query.strip().split("\n")
@@ -149,11 +204,13 @@ def main():
examples = load_questions_from_excel(data_file)
if not is_debug:
max_workers = 10 # 减少并发数以避免API限制
max_workers = 20 # 减少并发数以避免API限制
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
# 创建一个与输入顺序相同的结果列表
results = [None] * len(examples)
batch_size = 100 # 每100条保存一次
# 使用线程池进行并发处理
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务并记录它们的索引
@@ -163,43 +220,27 @@ def main():
future_to_index[future] = idx
# 使用tqdm显示进度条
completed = 0
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
idx = future_to_index[future]
result = future.result()
# 将结果放在与输入相同的位置
results[idx] = result
completed += 1
# 每处理batch_size条数据保存一次
if completed % batch_size == 0:
logging.info(f"已完成 {completed}/{len(examples)} 条,保存中间结果...")
save_results_to_excel(results, output_file, is_final=False)
# 将结果保存到Excel文件
results_df = pd.DataFrame(results)
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "测试提问数据_槽位填充结果.xlsx")
# 使用ExcelWriter设置格式
with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作簿和工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 设置列宽(单位:像素)
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
# 设置所有行高为20磅
for i in range(len(results_df) + 1): # +1 是为了包括表头
worksheet.set_row(i, 20)
# 处理完所有数据后,保存最终结果
save_results_to_excel(results, output_file, is_final=True)
logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
else:
for idx, query in enumerate(examples):
if query.strip() == "":
continue
process_query(recognizer, query)
logging.info(f"处理完成,结果已保存至: {output_file}")
def setup_logging():
# 配置日志输出到控制台
+1 -1
View File
@@ -1,4 +1,4 @@
from rag2_0.dify.workflow_chat import NewWorkflowChat
from rag2_0.dify.dify_tool import NewWorkflowChat
import pandas as pd
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
+304 -2
View File
@@ -1,8 +1,17 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import psycopg2
from psycopg2 import sql
import os
import json
from datetime import timezone, timedelta
from concurrent.futures import ThreadPoolExecutor, as_completed
from rag2_0.dify.dify_client import ChatClient
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
class ContentSource(BaseModel):
score: int = Field(description="相关性分数")
reason: str = Field(description="评分理由")
class PgSql:
"""
@@ -219,6 +228,299 @@ class DifyTool:
finally:
dify_pgsql.close_connection()
class BaseWorkflowChat:
"""
工作流对话基类,封装了与Dify API交互的基本功能
"""
def __init__(self, api_key: str, base_url: str):
"""
初始化工作流对话基类
Args:
api_key: Dify API的密钥
base_url: Dify API的基础URL
"""
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
def create_chat_message(self, query: str):
"""
创建聊天消息
Args:
query: 问题内容
Returns:
tuple: (聊天响应, 消息ID)
"""
try:
response = self.chat_client.create_chat_message(inputs={}, query=query, user="AutoTestDifyChat").json()
return response, response["message_id"]
except Exception as e:
raise e
def calculate_score(self, query: str, content: str) -> int:
"""
使用LLM判断query与content之间的相关性分数
Args:
query (str): 用户问题
content (str): 检索内容
Returns:
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
"""
from rag2_0.tool.ModelTool import OpenAiLLM
try:
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突
【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息
【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}
【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
现在评估:
query: "{query}"
content: "{content}"
"""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model = os.getenv("LLM_MODEL_NAME")
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
response = llm.invoke(user_prompt=prompt, need_retry=True)
# 解析JSON响应
try:
parsed_output = self.content_source_parser.parse(response.content)
return parsed_output.score
except Exception as e:
return -1
except Exception as e:
return -1
def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
"""
获取检索信息并计算分数
Args:
query (str): 用户问题
outputs (dict): 检索输出结果
Returns:
tuple: (检索内容列表, 最高分, 最低分, 平均分)
"""
max_score = 0
min_score = 10
total_score = 0
valid_scores = 0
retrieve_content = []
# 使用线程池并发计算分数
with ThreadPoolExecutor() as executor:
# 创建任务列表
future_to_content = {}
for result in outputs["result"]:
content = result["content"].strip()
future = executor.submit(self.calculate_score, query=query, content=content)
future_to_content[future] = content
# 收集结果
for future in as_completed(future_to_content):
content = future_to_content[future]
score = future.result()
content_title = content.split("\n")[0]
if score != -1:
max_score = max(max_score, score)
min_score = min(min_score, score)
total_score += score
valid_scores += 1
if content_title:
retrieve_content.append(content_title + f"--相关性得分({score}分)")
avg_score = total_score / valid_scores if valid_scores > 0 else 0
return retrieve_content, max_score, min_score, avg_score
class NewWorkflowChat(BaseWorkflowChat):
"""
新工作流对话类,用于调用新工作流发送对话并解析获取相关数据
"""
def process_question(self, query: str) -> dict:
"""
处理问题,获取新工作流的回答和相关信息
Args:
query: 问题内容
Returns:
dict: 包含问题、回答和相关信息的字典
"""
response, message_id = self.create_chat_message(query)
if isinstance(response, str) and response.startswith("error:"):
raise RuntimeError(f"create_chat_message 出错:{response}")
answer = response["answer"]
workflow_info = self.get_workflow_info(query, message_id)
if workflow_info is None:
return None
result = {
"问题": query,
"新流程答案": answer,
"新问题改写": workflow_info["问题改写"],
"新问题分类": workflow_info["问题分类"],
"槽点信息": workflow_info["槽点信息"],
"新检索词条": workflow_info["检索词条"],
"检索内容": workflow_info["检索内容"],
"message_id":message_id
}
return result
def get_workflow_info(self, query: str, message_id: str) -> dict:
"""
获取新工作流的问题分类和检索信息
Args:
query (str): 用户问题
message_id (str): 新工作流的消息ID
Returns:
dict: 包含问题分类结果的字典
"""
retrieve_title = []
retrieve_content = []
max_score = 0
min_score = 0
avg_score = 0
rewrite_query = ""
vertical_classification = ""
sub_classification = ""
slot_info = ""
try:
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
for workflow_node in message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
json_result = json.loads(llm_result_json)
vertical_classification = json_result['vertical_classification']
sub_classification = json_result['sub_classification']
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
except Exception as e:
raise e
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
"检索内容": retrieve_content,
"问题分类": f"{vertical_classification} - {sub_classification}",
"槽点信息": slot_info,
}
class OldWorkFlowChat(BaseWorkflowChat):
"""
旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据
"""
def process_question(self, query: str) -> dict:
"""
处理问题,获取旧工作流的回答和相关信息
Args:
query: 问题内容
Returns:
dict: 包含问题、回答和相关信息的字典
"""
response, message_id = self.create_chat_message(query)
if isinstance(response, str) and response.startswith("error:"):
return None
answer = response["answer"]
workflow_info = self.get_workflow_info(query, message_id)
if workflow_info is None:
return None
result = {
"问题": query,
"旧流程答案": answer,
"旧问题改写": workflow_info["问题改写"],
"旧检索词条": workflow_info["检索词条"],
"检索内容": workflow_info["检索内容"],
"message_id":message_id
}
return result
def get_workflow_info(self, query: str, message_id: str) -> dict:
"""
获取旧工作流的问题改写和检索信息
Args:
query (str): 用户问题
message_id (str): 旧工作流的消息ID
Returns:
dict: 包含问题改写和检索信息的字典
"""
retrieve_title = []
retrieve_content = []
max_score = 0
min_score = 0
avg_score = 0
rewrite_query = ""
try:
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
for workflow_node in message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
except Exception as e:
return None
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
"检索内容": retrieve_content,
}
if __name__ == "__main__":
try:
+57 -56
View File
@@ -2,7 +2,8 @@
# -*- coding: utf-8 -*-
import os
from rag2_0.dify.dify_client import ChatClient, DifyClient
from rag2_0.dify.dify_client import DifyClient
from rag2_0.dify.dify_tool import NewWorkflowChat, OldWorkFlowChat
import pandas as pd
# 使用线程池并发执行
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -44,8 +45,9 @@ class DifyComparisonTester:
max_workers: 最大工作线程数
"""
self.excel_path = excel_path
self.old_chat = ChatClient(api_key=old_workflow_api_key, base_url=baseurl)
self.new_chat = ChatClient(api_key=new_workflow_api_key, base_url=baseurl)
# 使用NewWorkflowChat和OldWorkFlowChat代替ChatClient
self.old_chat = OldWorkFlowChat(api_key=old_workflow_api_key, base_url=baseurl)
self.new_chat = NewWorkflowChat(api_key=new_workflow_api_key, base_url=baseurl)
# 评判相关参数
self.output_path = output_path or os.path.join(os.path.dirname(self.excel_path), "dify问答_综合评判结果.xlsx")
@@ -78,13 +80,13 @@ class DifyComparisonTester:
"""
def get_old_answer():
try:
return self.old_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
return self.old_chat.process_question(query=q)
except Exception as e:
return f"error: {str(e)}"
def get_new_answer():
try:
return self.new_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
return self.new_chat.process_question(query=q)
except Exception as e:
return f"error: {str(e)}"
@@ -95,14 +97,15 @@ class DifyComparisonTester:
try:
old_result = future_old.result()
new_result = future_new.result()
old_message_id = old_result["message_id"]
new_message_id = new_result["message_id"]
if isinstance(old_result, str) and old_result.startswith("error:"):
return None, None
if isinstance(new_result, str) and new_result.startswith("error:"):
return None, None
old_answer = old_result["answer"]
new_answer = new_result["answer"]
except Exception as e:
return None, None, None
return {"问题": q, "旧流程答案": old_answer, "新流程答案": new_answer}, old_message_id, new_message_id
return future_old, future_new
def find_wiki_link(self, query) -> str | None:
"""
@@ -407,22 +410,24 @@ content: "{content}"
Returns:
dict: 包含问题分类结果的字典
"""
retrieve_title=[]
retrieve_content=[]
max_score=0
min_score=0
avg_score=0
rewrite_query=""
vertical_classification=""
sub_classification=""
slot_info=""
try:
# 使用DifyTool直接获取消息信息
new_message_info = DifyTool.get_message_debug_info_by_id(message_id=new_message_id)
# 初始化变量
retrieve_title = []
retrieve_content = []
rewrite_query = ""
vertical_classification = ""
sub_classification = ""
slot_info = ""
# 解析工作流节点信息
for workflow_node in new_message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content=outputs["result"]
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
@@ -430,20 +435,21 @@ content: "{content}"
json_result = json.loads(llm_result_json)
vertical_classification = json_result['vertical_classification']
sub_classification = json_result['sub_classification']
slot_info=json.dumps(json_result["slot_filling"],ensure_ascii=False,indent=2)
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
except Exception as e:
return None
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
"检索内容": retrieve_content,
"问题分类": f"{vertical_classification} - {sub_classification}",
"槽点信息":slot_info
"槽点信息": slot_info
}
def get_old_workflow_info(self, query:str, old_message_id:str) -> dict:
"""
获取流程的问题分类
获取流程的问题分类
Args:
query (str): 用户问题
@@ -452,24 +458,27 @@ content: "{content}"
Returns:
dict: 包含问题分类结果的字典
"""
retrieve_title=[]
retrieve_content=[]
max_score=0
min_score=0
avg_score=0
rewrite_query=""
try:
# 使用DifyTool直接获取消息信息
old_message_info = DifyTool.get_message_debug_info_by_id(message_id=old_message_id)
# 初始化变量
retrieve_title = []
retrieve_content = []
rewrite_query = ""
# 解析工作流节点信息
for workflow_node in old_message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content=outputs["result"]
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
except Exception as e:
return None
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
@@ -512,13 +521,13 @@ content: "{content}"
dict: 包含问题、回答和评判结果的字典
"""
# 获取基本的问题和回答
basic_result, old_message_id, new_message_id = self.process_question(q)
if basic_result is None:
future_old, future_new = self.process_question(q)
if future_old is None or future_new is None:
return None
query = basic_result["问题"]
old_answer = basic_result["旧流程答案"]
new_answer = basic_result["新流程答案"]
query = future_old["问题"]
old_answer = future_old["旧流程答案"]
new_answer = future_new["新流程答案"]
# 获取词条链接和标准答案
wiki_url = self.find_wiki_link(query)
@@ -540,33 +549,23 @@ content: "{content}"
if judge_result is None:
judge_result = ""
# retrieve_title_score = self.get_retrieve_title_similarity(old_retrieve_content=old_workflow_info["检索内容"], new_retrieve_content=new_workflow_info["检索内容"])
# 并行获取新旧流程信息
with ThreadPoolExecutor(max_workers=2) as executor:
future_new = executor.submit(self.get_new_workflow_info, query=query, new_message_id=new_message_id)
future_old = executor.submit(self.get_old_workflow_info, query=query, old_message_id=old_message_id)
try:
new_workflow_info = future_new.result()
old_workflow_info = future_old.result()
except Exception as e:
print(f"处理问题 '{query}' 获取工作流信息时发生错误: {str(e)}")
return None
retrieve_title_score=self.get_retrieve_title_similarity(old_retrieve_content=old_workflow_info["检索内容"], new_retrieve_content=new_workflow_info["检索内容"])
# 返回结果
return {
"问题": query,
"新问题改写": new_workflow_info["问题改写"],
"旧问题改写": old_workflow_info["问题改写"],
"新问题分类": new_workflow_info["问题分类"],
"槽点信息":new_workflow_info["槽点信息"],
"新问题改写": future_new["问题改写"],
"旧问题改写": future_old["问题改写"],
"新问题分类": future_new["问题分类"],
"槽点信息": future_new["槽点信息"],
"新流程答案": new_answer,
"旧流程答案": old_answer,
"回答判断": judge_result,
"词条检索相似度": retrieve_title_score,
# "词条检索相似度": retrieve_title_score,
"答案词条": answer_title if answer_title else "",
"新检索词条": new_workflow_info["检索词条"],
"旧检索词条": old_workflow_info["检索词条"],
"新检索词条": future_new["检索词条"],
"旧检索词条": future_old["检索词条"],
}
def run_comparison(self, with_judge=False):
@@ -670,5 +669,7 @@ if __name__ == "__main__":
print(f"对比评判结果已保存至: {output_file}")
# 单个问题测试示例
# c = DifyChat(baseurl="http://172.20.0.145/v1", api_key="app-LjJaeLoAfqa6aoGzqU9UvxSf")
# c.chat("如何新建配电线路工程")
# 使用新的工作流类进行测试
# new_chat = NewWorkflowChat(api_key="app-qxsSybCs7ABiKlC1JabTYVn6", base_url="http://172.20.0.145/v1")
# result = new_chat.process_question("如何新建配电线路工程")
# print(json.dumps(result, ensure_ascii=False, indent=2))
-310
View File
@@ -1,310 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
from concurrent.futures import ThreadPoolExecutor, as_completed
from rag2_0.dify.dify_client import ChatClient, DifyClient
from rag2_0.dify.dify_tool import DifyTool
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from threading import Lock
class ContentSource(BaseModel):
score: int = Field(description="相关性分数")
reason: str = Field(description="评分理由")
class BaseWorkflowChat:
"""
工作流对话基类,封装了与Dify API交互的基本功能
"""
def __init__(self, api_key: str, base_url: str):
"""
初始化工作流对话基类
Args:
api_key: Dify API的密钥
base_url: Dify API的基础URL
"""
self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
def create_chat_message(self, query: str):
"""
创建聊天消息
Args:
query: 问题内容
Returns:
tuple: (聊天响应, 消息ID)
"""
try:
response = self.chat_client.create_chat_message(inputs={}, query=query, user="AutoTestDifyChat").json()
return response, response["message_id"]
except Exception as e:
raise e
def calculate_score(self, query: str, content: str) -> int:
"""
使用LLM判断query与content之间的相关性分数
Args:
query (str): 用户问题
content (str): 检索内容
Returns:
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
"""
from rag2_0.tool.ModelTool import OpenAiLLM
try:
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突
【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息
【输出格式】
{{
"score": 评分,
"reason": "简明扼要的评分理由(中文)"
}}
【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
现在评估:
query: "{query}"
content: "{content}"
"""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model = os.getenv("LLM_MODEL_NAME")
llm = OpenAiLLM(api_key=api_key, base_url=base_url, model=model)
response = llm.invoke(user_prompt=prompt, need_retry=True)
# 解析JSON响应
try:
parsed_output = self.content_source_parser.parse(response.content)
return parsed_output.score
except Exception as e:
return -1
except Exception as e:
return -1
def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
"""
获取检索信息并计算分数
Args:
query (str): 用户问题
outputs (dict): 检索输出结果
Returns:
tuple: (检索内容列表, 最高分, 最低分, 平均分)
"""
max_score = 0
min_score = 10
total_score = 0
valid_scores = 0
retrieve_content = []
# 使用线程池并发计算分数
with ThreadPoolExecutor() as executor:
# 创建任务列表
future_to_content = {}
for result in outputs["result"]:
content = result["content"].strip()
future = executor.submit(self.calculate_score, query=query, content=content)
future_to_content[future] = content
# 收集结果
for future in as_completed(future_to_content):
content = future_to_content[future]
score = future.result()
content_title = content.split("\n")[0]
if score != -1:
max_score = max(max_score, score)
min_score = min(min_score, score)
total_score += score
valid_scores += 1
if content_title:
retrieve_content.append(content_title + f"--相关性得分({score}分)")
avg_score = total_score / valid_scores if valid_scores > 0 else 0
return retrieve_content, max_score, min_score, avg_score
class NewWorkflowChat(BaseWorkflowChat):
"""
新工作流对话类,用于调用新工作流发送对话并解析获取相关数据
"""
def process_question(self, query: str) -> dict:
"""
处理问题,获取新工作流的回答和相关信息
Args:
query: 问题内容
Returns:
dict: 包含问题、回答和相关信息的字典
"""
response, message_id = self.create_chat_message(query)
if isinstance(response, str) and response.startswith("error:"):
raise RuntimeError(f"create_chat_message 出错:{response}")
answer = response["answer"]
workflow_info = self.get_workflow_info(query, message_id)
if workflow_info is None:
return None
result = {
"问题": query,
"新流程答案": answer,
"新问题改写": workflow_info["问题改写"],
"新问题分类": workflow_info["问题分类"],
"槽点信息": workflow_info["槽点信息"],
"新检索词条": workflow_info["检索词条"],
"检索内容": workflow_info["检索内容"],
"message_id":message_id
}
return result
def get_workflow_info(self, query: str, message_id: str) -> dict:
"""
获取新工作流的问题分类和检索信息
Args:
query (str): 用户问题
message_id (str): 新工作流的消息ID
Returns:
dict: 包含问题分类结果的字典
"""
retrieve_title = []
retrieve_content = []
max_score = 0
min_score = 0
avg_score = 0
rewrite_query = ""
vertical_classification = ""
sub_classification = ""
slot_info = ""
try:
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
for workflow_node in message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
llm_result_json = json.loads(workflow_node['inputs'])["llm_result"]
json_result = json.loads(llm_result_json)
vertical_classification = json_result['vertical_classification']
sub_classification = json_result['sub_classification']
slot_info = json.dumps(json_result["slot_filling"], ensure_ascii=False, indent=2)
except Exception as e:
raise e
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
"检索内容": retrieve_content,
"问题分类": f"{vertical_classification} - {sub_classification}",
"槽点信息": slot_info,
}
class OldWorkFlowChat(BaseWorkflowChat):
"""
旧工作流对话类,用于调用旧工作流发送对话并解析获取相关数据
"""
def process_question(self, query: str) -> dict:
"""
处理问题,获取旧工作流的回答和相关信息
Args:
query: 问题内容
Returns:
dict: 包含问题、回答和相关信息的字典
"""
response, message_id = self.create_chat_message(query)
if isinstance(response, str) and response.startswith("error:"):
return None
answer = response["answer"]
workflow_info = self.get_workflow_info(query, message_id)
if workflow_info is None:
return None
result = {
"问题": query,
"旧流程答案": answer,
"旧问题改写": workflow_info["问题改写"],
"旧检索词条": workflow_info["检索词条"],
"检索内容": workflow_info["检索内容"],
"message_id":message_id
}
return result
def get_workflow_info(self, query: str, message_id: str) -> dict:
"""
获取旧工作流的问题改写和检索信息
Args:
query (str): 用户问题
message_id (str): 旧工作流的消息ID
Returns:
dict: 包含问题改写和检索信息的字典
"""
retrieve_title = []
retrieve_content = []
max_score = 0
min_score = 0
avg_score = 0
rewrite_query = ""
try:
message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
for workflow_node in message_info["workflow_node_executions_info"]:
if workflow_node["title"] == "知识检索结果后处理":
outputs = json.loads(workflow_node["outputs"])
retrieve_title, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
retrieve_content = outputs["result"]
elif workflow_node["title"] == "问题优化结果解析":
outputs = json.loads(workflow_node["outputs"])
rewrite_query = outputs["optimize_query"]
except Exception as e:
return None
return {
"问题改写": rewrite_query,
"检索词条": "\n".join(retrieve_title) if retrieve_title else "未检索知识库",
"检索内容": retrieve_content,
}
+10 -10
View File
@@ -23,12 +23,12 @@ class SoftwareName(str, Enum):
# 软件别名映射
ALIASES = {
D3: ["配网D3", "D3软件", "配网工程软件"],
C1: ["储能C1", "C1软件", "储能电站软件", "储能软件"],
Z1: ["西藏Z1", "Z1软件", "西藏电力软件"],
T1: ["技改T1", "T1软件", "技改检修软件"],
T1_LIST: ["技改清单T1", "T1清单软件", "技改检修清单软件"],
MAIN: ["主网软件", "电力建设软件", "主网建设软件", "主网软件"]
D3: "别名包括:配网D3、D3软件、配网工程软件等 其他类似称呼",
C1: "别名包括:储能C1、C1软件、储能电站软件、储能软件等 其他类似称呼",
Z1: "别名包括:西藏Z1、Z1软件、西藏电力软件等 其他类似称呼",
T1: "别名包括:技改T1、T1软件、技改检修软件等 其他类似称呼",
T1_LIST: "别名包括:技改清单T1、T1清单软件、技改检修清单软件等 其他类似称呼",
MAIN: "别名包括:主网软件、电力建设软件、主网建设软件、博微电力建设计价通等 其他类似称呼"
}
# 定义输出模型
@@ -58,7 +58,7 @@ class QueryRewrite(BaseModel):
# 1. 软件问题
# 1.1 软件功能
class SoftwareFunction(BaseModel):
software_name: SoftwareName = Field(description="软件名称")
software_name: SoftwareName = Field(description="软件名称,只能从给定的范围中取值")
function_name: str = Field(description="具体功能名称")
operation: str = Field(description="用户操作意图(如何使用功能、功能入口、功能使用场景)")
software_version: Optional[str] = Field(None, description="软件版本")
@@ -77,7 +77,7 @@ class SoftwareFunction(BaseModel):
# 1.2 故障排查
class TroubleShooting(BaseModel):
software_name: SoftwareName = Field(description="软件名称")
software_name: SoftwareName = Field(description="软件名称,只能从给定的范围中取值")
function_name: str = Field(description="具体功能名称/操作描述")
error_message: str = Field(description="报错信息/异常现象")
software_version: Optional[str] = Field(None, description="软件版本")
@@ -162,7 +162,7 @@ class SoftwareLock(BaseModel):
# 3.3 安装下载类
class InstallationDownload(BaseModel):
software_name: SoftwareName = Field(description="软件/插件名称,与file_name二选一")
software_name: str = Field(description="软件/插件名称,与file_name二选一")
file_name: str = Field(description="文件名,与software_name二选一")
operation_stage: str = Field(description="操作阶段")
os_version: Optional[str] = Field(None, description="操作系统版本")
@@ -182,7 +182,7 @@ class InstallationDownload(BaseModel):
# 3.4 问题排查类
class ProblemDiagnosis(BaseModel):
error_message: str = Field(description="报错信息/异常现象")
software_name: Optional[SoftwareName] = Field(None, description="软件名称")
software_name: Optional[SoftwareName] = Field(None, description="软件名称,只能从给定的范围中取值")
os_version: Optional[str] = Field(None, description="操作系统版本")
def check_required_slots(self) -> Tuple[bool, Dict[str, str]]:
@@ -180,7 +180,7 @@ class IntentRecognizer:
return reranked_terms
except Exception as e:
return list(matched_terms)
raise RuntimeError(f"SiliconFlowReRankerModel重排失败:{e}") from e
def match_keywords(self, query: str) -> Tuple[TermList, List[str]]:
"""
+19 -20
View File
@@ -18,26 +18,6 @@ import requests
# sk-dvbaktabkdwdpjgxyoozlwnejosjyhdgqwllfeborqahndxs
API_KEY_LIST=[
"sk-hrojkkkrrkmsajtnizokbcgexsfggdiqavbtvbayuwqbnmom",
"sk-kkdklmnyompoiotzkfqahpayzlkgogfudjkyaebehtsowvid",
"sk-sfxzvllifafbyfduupcdtcrjwhdyiyojnksyopnfslurnhsp",
"sk-faqirxiszukfswqvzqawxnemqfacrkyurbxxkzwbbujqacdp",
"sk-vonaanuueqiczppkntjuphateshrcpqpnvxmwxorkyihjmrb",
"sk-qfpeoodgupcukcdstjcxgegwxnuhtxkkrupkogkcvhavxgny",
"sk-fsvjnbpfgoadixympaabaukupuhjvbturcbxaqfdzjznemtr",
"sk-fltvnbiqntfawjwkfnnhmyfiimzgzxkweqmefcfqkbucwrhi",
"sk-oosswdriwyqkglwdigvcxgmcpyplcyowicbaugpizoscevdl",
"sk-jswtxhkiralnyiukqimtyuurcaepulxdrfijadtxzrgsajyc",
"sk-dcjuhoukdyrbneadtxtnyxzmigkpiqgtqqnreiprxpioftsv",
"sk-yrhezyuxjblpaxzzudbowqmvcoxcammupcubghbodolikbdk",
"sk-dsgvwpfagmarilmnewwbzhfzlqehburoupjaopucdvybpbdo",
"sk-oljjlspuaurtoczyekztiidwtoerugadgepiufclpmrbdfqc",
"sk-crgrimubjesthvxuqwedqqdoetljyrgeahxxpctfefgnkpyo",
"sk-tubqhwgycxrdhwsqzjopxgeaqpsjdfppckckayvzornaluwq",
"sk-amcxlmsdnadptpnehqnkvseolacipztmvovnmxojzohbjjil",
"sk-pdyymhshpzmdduwxsezthnrgarnnhgzvmiflbpisfzxkiayt",
"sk-qhwoorywmejumyudfxbrkegxtqifsbgcdkmpjckezepgyqnz",
"sk-cpoctrgcnstaybeyuieuwjdgeakudhqdnnwdjavjudcbvvem",
"sk-wqdpapdkisovziexgcyxvumpwzbjnhqbxvcqcspzctjhyhjk",
"sk-bbntrnifrtdzhhgrtlrhvwbnaysuszviemshdakxonnnymnb",
"sk-vmpnwjxersrwybmfhfxgsvbmhsmpjldxseiyxovnysrlbuzi",
@@ -98,6 +78,25 @@ API_KEY_LIST=[
"sk-nbksjgcngsayoumnsdbkcpnqivnvxjenwpzuazzrkhnsgeoo",
"sk-iaafvpjyqiocgzchbdldbkgcffqniahkcbgoviuevuogulcm",
"sk-muvjguqeshyimzowqnqgxwpsgujlpkqgrisxsimthtyrpypx",
"sk-jgybgyayxlwoxeijgrjcneqlyusleohgbliuwpsuhocrjsmk",
"sk-wzjsmwxcbbpcrqivqfzjwufqqjtlwejtncnvbpeicznkwiuh",
"sk-izdjicdoyillktsihkiapuvwebisehtlgykozrvzfkgncwsc",
"sk-fcsfmyivfuojsqsditvobfqprdpeunukycpcfnoxkraqevpx",
"sk-szyjgyxrcvyxpvzfwgmbxnflxngxvcplitcctsdvvrqjgftk",
"sk-jzbodthsnvjwbyrnynsxrudtqfnbdbrcxebjwjgajocnzqse",
"sk-fxepossfzpmccibfwqpkluorzqlbtcaplepeugtfzfsctcbl",
"sk-ympnflocrkxjrbubsxqdjqwicuyavvvysctlpfhunkcrzxjx",
"sk-flhqvziknntednkcgjaxlyzzsrfzjhrzrmteqonajpbiinni",
"sk-xfregpbbquqbxpiobjzanydsjivrjrnbokzxcqtnhxhyghhe",
"sk-jrdzerhmvrtvzawkksowbgkggkubwfquplmrxbdhespqgtis",
"sk-jjbpnkbeupsxyclcivbhizcfpfjrppddunbqynyjkqhtmpwu",
"sk-oqehupcveovkjqqtxypqyifidcdissuyehwrkdwgruoyjkpq",
"sk-orhfntzrbpmpavybcjyylofxncdvufdmvlznofmhxmnjymjl",
"sk-kvgfuqeqvpmfsccykyoohheshclcrtvjlnewratvrjpkpbkc",
"sk-zhnbqnpuumuuvegnvbgoggxafpukbzchpgrugpkobiwkzsar",
"sk-kzhxlqvqcxlnbdgnpalqnzumkmspepkttkgbophnkqanainw",
"sk-bzttugqtlskrvguvhckwamdssvgmgnrqpsialpdbskfsyyak",
"sk-tovmogiablsoeabwgqyvevpcfichyjpuzqdymmvksspdrtqt",
]
class APIKeyManager: