#!/usr/bin/env python
# -*- coding: utf-8 -*-
import os
import json
import logging
# NOTE: as_completed, DifyClient and Lock are not used in this module; they are
# kept because other modules may import them from here (not verifiable here).
from concurrent.futures import ThreadPoolExecutor, as_completed
from rag2_0.dify.dify_client import ChatClient, DifyClient
from rag2_0.dify.dify_tool import DifyTool
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
from threading import Lock
from typing import Optional

logger = logging.getLogger(__name__)

# Titles of the Dify workflow nodes whose debug output is parsed below.
_RETRIEVE_NODE_TITLE = "知识检索结果后处理"
_REWRITE_NODE_TITLE = "问题优化结果解析"
# Placeholder reported when no retrieval entry was collected.
_NO_RETRIEVAL_PLACEHOLDER = "未检索知识库"
# Value returned by calculate_score when scoring fails.
_SCORE_FAILED = -1

# Prompt sent to the LLM to grade query/content relevance. Literal braces are
# doubled because the template is filled in with str.format().
_SCORE_PROMPT_TEMPLATE = """你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。

【评分标准】
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
8-9分:高度相关,核心要素匹配但存在少量信息缺失
6-7分:部分相关,涉及相同主题但存在重要信息缺失
4-5分:弱相关,仅次要信息点匹配
1-3分:完全不相关或信息冲突

【评估维度】
1. 主题一致性:核心主题/意图的匹配程度
2. 内容覆盖度:是否涵盖query的关键要素
3. 信息准确性:是否存在矛盾/错误信息
4. 细节丰富度:是否提供query要求的详细信息

【输出格式】
{{
    "score": 评分,
    "reason": "简明扼要的评分理由(中文)"
}}

【示例】
query: "新冠疫苗的常见副作用"
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}

现在评估:
query: "{query}"
content: "{content}"
"""


class ContentSource(BaseModel):
    """Structured LLM verdict: a relevance score plus the reason for it."""

    score: int = Field(description="相关性分数")
    reason: str = Field(description="评分理由")


class BaseWorkflowChat:
    """Base class for workflow chats; wraps the basic Dify API interaction."""

    def __init__(self, api_key: str, base_url: str):
        """Create the Dify chat client and the score output parser.

        Args:
            api_key: Dify API key.
            base_url: Dify API base URL.
        """
        self.chat_client = ChatClient(api_key=api_key, base_url=base_url)
        self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)

    def create_chat_message(self, query: str):
        """Send ``query`` as a chat message and return the decoded response.

        Args:
            query: The question text.

        Returns:
            tuple: ``(response, message_id)`` where ``response`` is the decoded
            JSON body and ``message_id`` is ``response["message_id"]``.

        Raises:
            Exception: Anything raised by the client call, the JSON decoding or
                the ``message_id`` lookup propagates unchanged.
        """
        response = self.chat_client.create_chat_message(
            inputs={}, query=query, user="AutoTestDifyChat"
        ).json()
        return response, response["message_id"]

    def calculate_score(self, query: str, content: str) -> int:
        """Ask the LLM to grade how relevant ``content`` is to ``query``.

        Args:
            query: The user question.
            content: One retrieved content entry.

        Returns:
            int: The score parsed from the LLM reply (the prompt asks for 1-10,
            10 meaning fully relevant), or -1 when the call or the parsing
            fails. The failure is logged rather than raised so that one bad
            entry does not abort a whole batch.
        """
        # Imported lazily, as in the original code.
        from rag2_0.tool.ModelTool import OpenAiLLM

        try:
            prompt = _SCORE_PROMPT_TEMPLATE.format(query=query, content=content)
            llm = OpenAiLLM(
                api_key=os.getenv("OPENAI_API_KEY"),
                base_url=os.getenv("OPENAI_API_BASE"),
                model=os.getenv("LLM_MODEL_NAME"),
            )
            response = llm.invoke(user_prompt=prompt, need_retry=True)
            # Parse the JSON reply into a ContentSource instance.
            return self.content_source_parser.parse(response.content).score
        except Exception:
            logger.warning("Relevance scoring failed; returning -1", exc_info=True)
            return _SCORE_FAILED

    def get_retrieve_info(self, query: str, outputs: dict) -> tuple:
        """Score every retrieved entry concurrently and summarise the scores.

        Args:
            query: The user question.
            outputs: Retrieval node output; ``outputs["result"]`` must be an
                iterable of dicts that each carry a ``"content"`` string.

        Returns:
            tuple: ``(entries, max_score, min_score, avg_score)``. ``entries``
            holds one ``"<first line>--相关性得分(<score>分)"`` string per
            retrieved item whose first line is non-empty, in retrieval order
            (failed items show -1). The three statistics cover successfully
            scored items only and are all 0 when there are none.
        """
        contents = [item["content"].strip() for item in outputs["result"]]

        # Score concurrently, but read the futures back in submission order so
        # that the returned entries follow the retrieval order.
        with ThreadPoolExecutor() as executor:
            futures = [
                executor.submit(self.calculate_score, query=query, content=content)
                for content in contents
            ]
            scores = [future.result() for future in futures]

        entries = []
        for content, score in zip(contents, scores):
            title = content.split("\n")[0]
            if title:
                entries.append(title + f"--相关性得分({score}分)")

        valid_scores = [score for score in scores if score != _SCORE_FAILED]
        if not valid_scores:
            return entries, 0, 0, 0
        return (
            entries,
            max(valid_scores),
            min(valid_scores),
            sum(valid_scores) / len(valid_scores),
        )


class NewWorkflowChat(BaseWorkflowChat):
    """Runs a question through the new workflow and parses the debug data."""

    def process_question(self, query: str) -> Optional[dict]:
        """Ask the new workflow ``query`` and collect the answer and metadata.

        Args:
            query: The question text.

        Returns:
            Optional[dict]: Question, answer, rewrite, classification, slot
            info, retrieval entries, retrieved content and message id; None
            when ``get_workflow_info`` returns None.

        Raises:
            RuntimeError: If the chat response is a string starting with
                ``"error:"``.
        """
        response, message_id = self.create_chat_message(query)
        if isinstance(response, str) and response.startswith("error:"):
            raise RuntimeError(f"create_chat_message 出错:{response}")

        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None

        return {
            "问题": query,
            "新流程答案": answer,
            "新问题改写": workflow_info["问题改写"],
            "新问题分类": workflow_info["问题分类"],
            "槽点信息": workflow_info["槽点信息"],
            "新检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> dict:
        """Read rewrite, classification, slots and retrieval data of a message.

        Args:
            query: The user question (used to score the retrieved content).
            message_id: Message id returned by the new workflow.

        Returns:
            dict: Keys ``问题改写``, ``检索词条``, ``检索内容``, ``问题分类`` and
            ``槽点信息``.

        Raises:
            Exception: Any lookup or JSON decoding error propagates to the
                caller; unlike the old workflow, nothing is swallowed here.
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""
        vertical_classification = ""
        sub_classification = ""
        slot_info = ""

        message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
        for node in message_info["workflow_node_executions_info"]:
            title = node["title"]
            if title == _RETRIEVE_NODE_TITLE:
                outputs = json.loads(node["outputs"])
                # Only the formatted entries are reported; the score
                # statistics are not part of the result.
                retrieve_title = self.get_retrieve_info(query=query, outputs=outputs)[0]
                retrieve_content = outputs["result"]
            elif title == _REWRITE_NODE_TITLE:
                outputs = json.loads(node["outputs"])
                rewrite_query = outputs["optimize_query"]
                # The raw LLM result is itself a JSON string stored inside the
                # node's JSON-encoded inputs, hence the double decoding.
                llm_result = json.loads(json.loads(node["inputs"])["llm_result"])
                vertical_classification = llm_result["vertical_classification"]
                sub_classification = llm_result["sub_classification"]
                slot_info = json.dumps(
                    llm_result["slot_filling"], ensure_ascii=False, indent=2
                )

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else _NO_RETRIEVAL_PLACEHOLDER,
            "检索内容": retrieve_content,
            "问题分类": f"{vertical_classification} - {sub_classification}",
            "槽点信息": slot_info,
        }


class OldWorkFlowChat(BaseWorkflowChat):
    """Runs a question through the old workflow and parses the debug data."""

    def process_question(self, query: str) -> Optional[dict]:
        """Ask the old workflow ``query`` and collect the answer and metadata.

        Args:
            query: The question text.

        Returns:
            Optional[dict]: Question, answer, rewrite, retrieval entries,
            retrieved content and message id; None when the chat response is
            a string starting with ``"error:"`` or when the workflow debug
            info could not be read.
        """
        response, message_id = self.create_chat_message(query)
        if isinstance(response, str) and response.startswith("error:"):
            return None

        answer = response["answer"]
        workflow_info = self.get_workflow_info(query, message_id)
        if workflow_info is None:
            return None

        return {
            "问题": query,
            "旧流程答案": answer,
            "旧问题改写": workflow_info["问题改写"],
            "旧检索词条": workflow_info["检索词条"],
            "检索内容": workflow_info["检索内容"],
            "message_id": message_id,
        }

    def get_workflow_info(self, query: str, message_id: str) -> Optional[dict]:
        """Read the query rewrite and retrieval data of an old-workflow message.

        Args:
            query: The user question (used to score the retrieved content).
            message_id: Message id returned by the old workflow.

        Returns:
            Optional[dict]: Keys ``问题改写``, ``检索词条`` and ``检索内容``, or
            None when fetching or parsing the debug info fails (the failure is
            logged).
        """
        retrieve_title = []
        retrieve_content = []
        rewrite_query = ""

        try:
            message_info = DifyTool.get_message_debug_info_by_id(message_id=message_id)
            for node in message_info["workflow_node_executions_info"]:
                title = node["title"]
                if title == _RETRIEVE_NODE_TITLE:
                    outputs = json.loads(node["outputs"])
                    # Only the formatted entries are reported; the score
                    # statistics are not part of the result.
                    retrieve_title = self.get_retrieve_info(query=query, outputs=outputs)[0]
                    retrieve_content = outputs["result"]
                elif title == _REWRITE_NODE_TITLE:
                    outputs = json.loads(node["outputs"])
                    rewrite_query = outputs["optimize_query"]
        except Exception:
            # Best-effort by design: callers treat None as "skip this
            # question", so log the cause instead of raising.
            logger.exception("Failed to read workflow info for message %s", message_id)
            return None

        return {
            "问题改写": rewrite_query,
            "检索词条": "\n".join(retrieve_title) if retrieve_title else _NO_RETRIEVAL_PLACEHOLDER,
            "检索内容": retrieve_content,
        }