#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: validate_excel_data_batch.py
Description: Use an LLM to batch-validate the rows of an Excel sheet produced by
the intent-recognition pipeline. For every row the question rewrite, the
question classification, the slot filling and the retrieved content are judged
in that order; the first phase judged incorrect is recorded as the error phase.
"""
import os
import sys
import json
import argparse
import logging
import concurrent.futures
from typing import Optional

import pandas as pd
from tqdm import tqdm
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser

# Make the project root (current working directory) importable before the
# project-local imports below.
sys.path.append(os.getcwd())
from rag2_0.intent_recognition.PromptTemplates import classification_info
# NOTE(review): star import supplies Classification and the *Slots models used
# below; consider replacing it with explicit names.
from rag2_0.intent_recognition.DataModels import *
from rag2_0.tool.ModelTool import OpenAiLLM


class ValidationResult(BaseModel):
    """Structured verdict the LLM is asked to return for one validation phase."""

    is_correct: bool = Field(description="验证是否通过")
    confidence_score: float = Field(description="置信度得分")
    reason: str = Field(default="", description="得出结论的原因")


class ExcelDataValidator:
    """Batch validator for the rewrite / classification / slot / retrieval columns of an Excel sheet."""

    # Columns that must be present in the input sheet.
    REQUIRED_COLUMNS = ("问题", "问题分类", "问题改写", "槽位信息", "检索的内容")

    def __init__(self, input_file=None, output_file=None, workers=4, debug=False):
        """
        Initialize the validator.

        Args:
            input_file: path of the input Excel file.
            output_file: path of the output Excel file (defaults to
                ``validated_<input name>`` next to the input when not given).
            workers: number of worker threads used in parallel mode.
            debug: when True, rows are processed serially with verbose logging.
        """
        # Load OPENAI_API_KEY / OPENAI_API_BASE etc. from a .env file if present.
        load_dotenv()
        self.input_file = input_file
        self.output_file = output_file
        self.workers = workers
        self.debug = debug
        self.df = None

        self.setup_logging()

    def setup_logging(self):
        """Configure root logging to stderr and silence chatty HTTP client loggers."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler()
            ]
        )
        logging.getLogger('httpx').setLevel(logging.WARNING)
        logging.getLogger('openai').setLevel(logging.WARNING)

    def load_data_from_excel(self, file_path=None):
        """
        Read the input sheet and check that all required columns exist.

        Args:
            file_path: Excel file path; falls back to the path given at construction.

        Returns:
            The loaded DataFrame (also stored on ``self.df``), or None when no
            path was given, the file could not be read, or columns are missing.
        """
        file_path = file_path or self.input_file
        if not file_path:
            logging.error("未指定输入文件路径")
            return None

        try:
            df = pd.read_excel(file_path)
        except Exception as e:
            logging.error(f"读取Excel文件时出错: {e}", exc_info=True)
            return None

        missing_columns = [col for col in self.REQUIRED_COLUMNS if col not in df.columns]
        if missing_columns:
            for col in missing_columns:
                logging.error(f"缺少必要的列: {col}")
            return None

        logging.info(f"成功从{file_path}读取了{len(df)}条数据")
        self.df = df
        return df

    def _invoke_validation(self, llm, prompt, parser, action):
        """
        Send ``prompt`` to ``llm`` and parse the reply into a ValidationResult.

        Args:
            llm: chat model exposing ``invoke(prompt)`` returning an object with ``.content``.
            prompt: fully rendered prompt text.
            parser: PydanticOutputParser for ValidationResult.
            action: phase label used in the warning message (e.g. "验证问题分类").

        Returns:
            (is_correct, reason, confidence_score). Any failure (network error,
            unparsable reply) is deliberately reported as a failed verdict with
            confidence 0.0 rather than raised, so one bad row does not stop the batch.
        """
        try:
            response = llm.invoke(prompt)
            result = parser.parse(response.content)
            return result.is_correct, result.reason, result.confidence_score
        except Exception as e:
            logging.warning(f"{action}时出错: {e}")
            return False, f"验证过程出错: {str(e)}", 0.0

    @staticmethod
    def _format_structured(value):
        """
        Normalize a JSON object/array string for inclusion in a prompt.

        Strings that start with ``{`` or ``[`` and parse as JSON are re-serialized
        as JSON (keeping non-ASCII characters readable); every other value is
        returned unchanged.
        """
        if isinstance(value, str):
            text = value.strip()
            if text.startswith(("{", "[")):
                try:
                    return json.dumps(json.loads(text), ensure_ascii=False)
                except ValueError:
                    pass
        return value

    def validate_classification(self, llm: OpenAiLLM, query: str, vertical_class: str, sub_class: str):
        """
        Ask the LLM whether the classification of a question is correct.

        Args:
            llm: chat model.
            query: the question text that was classified.
            vertical_class: first-level class.
            sub_class: second-level class.

        Returns:
            (is_correct, reason, confidence_score)
        """
        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。

我目前总共有以下分类:
{classification_info}

问题的分类情况如下:
原始问题: {query}
一级分类: {vertical_class}
二级分类: {sub_class}

请从专业角度分析这个分类是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

{format_instructions}
"""
        return self._invoke_validation(llm, prompt, parser, "验证问题分类")

    def _get_slot_model(self, classification: Classification) -> Optional[type]:
        """
        Map a classification result to its slot model class.

        Args:
            classification: intent classification result.

        Returns:
            The slot model class, or None when the class pair is unknown.
        """
        vertical = classification.vertical_classification
        sub = classification.sub_classification

        # "其他" has a single slot model regardless of its second-level class.
        if vertical == "其他":
            return OtherSlots

        slot_models = {
            ("软件问题", "软件功能"): SoftwareFunctionSlots,
            ("软件问题", "故障排查"): SoftwareTroubleShootingSlots,
            ("业务问题", "专业咨询"): ProfessionalConsultingSlots,
            ("业务问题", "数据问题"): DataProblemSlots,
            ("安装下载注册", "后缀名咨询"): FileExtensionConsultingSlots,
            ("安装下载注册", "软件锁类"): SoftwareLockSlots,
            ("安装下载注册", "安装下载类"): InstallationDownloadSlots,
            ("安装下载注册", "问题排查类"): ProblemDiagnosisSlots,
        }
        return slot_models.get((vertical, sub))

    def _slot_schema_text(self, vertical_class, sub_class):
        """
        Return the JSON schema of the slot model for the given classes as text.

        Falls back to "无" when no first-level class is given, when the class
        pair cannot be turned into a Classification, or when no slot model is
        registered for it, so slot validation never crashes on a missing model.
        """
        if not vertical_class:
            return "无"
        try:
            classification = Classification(
                vertical_classification=vertical_class,
                sub_classification=sub_class,
            )
        except (ValueError, TypeError) as e:
            # pydantic's ValidationError subclasses ValueError.
            logging.warning(f"无法构造分类对象 {vertical_class} - {sub_class}: {e}")
            return "无"

        slot_model = self._get_slot_model(classification)
        if slot_model is None:
            return "无"
        return json.dumps(slot_model.model_json_schema(), ensure_ascii=False)

    def validate_slot(self, llm, rewrite, slot_info, vertical_class, sub_class):
        """
        Ask the LLM whether the slot filling is accurate and complete.

        Args:
            llm: chat model.
            rewrite: rewritten question.
            slot_info: filled slots (JSON string or already-parsed value).
            vertical_class: first-level class, selects the slot template.
            sub_class: second-level class, selects the slot template.

        Returns:
            (is_correct, reason, confidence_score)
        """
        slots = self._format_structured(slot_info)

        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        slot_info_prompt = self._slot_schema_text(vertical_class, sub_class)

        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我从用户问题中提取了槽位信息,请评估这些槽位信息是否准确、完整。

问题改写: {rewrite}
槽位模板:{slot_info_prompt}
填充的槽位信息: {slots}

槽位信息应该准确提取问题中的关键实体和属性,如软件名称、功能名称、错误信息等。请分析这些槽位是否准确填充,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

{format_instructions}
"""
        return self._invoke_validation(llm, prompt, parser, "验证槽位填充")

    def validate_retrieve_content(self, llm, rewrite, retrieve_content):
        """
        Ask the LLM whether the retrieved content can answer the question.

        Args:
            llm: chat model.
            rewrite: question text the content is judged against.
            retrieve_content: retrieved content (JSON string or plain text).

        Returns:
            (is_correct, reason, confidence_score)
        """
        content = self._format_structured(retrieve_content)

        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我针对用户问题检索了相关内容,请评估这些检索内容是否能解答提问。

问题改写: {rewrite}
检索内容: {content}

检索内容应该与问题主题相关,能够提供有用的信息来回答问题。请分析检索内容是否能解答提问、准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

{format_instructions}
"""
        return self._invoke_validation(llm, prompt, parser, "验证检索内容")

    def validate_rewrite(self, llm, query, rewrite):
        """
        Ask the LLM whether the rewrite preserves the original question's intent.

        Args:
            llm: chat model.
            query: original question.
            rewrite: rewritten question.

        Returns:
            (is_correct, reason, confidence_score)
        """
        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。

原始问题: {query}
问题改写: {rewrite}

问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

{format_instructions}
"""
        return self._invoke_validation(llm, prompt, parser, "验证问题改写")

    @staticmethod
    def _cell_value(value):
        """
        Normalize one Excel cell: None/NaN become "", strings are stripped.

        pandas reads empty cells as NaN, which is truthy; normalizing here keeps
        empty cells from being sent to the LLM or split as strings.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        if isinstance(value, str):
            return value.strip()
        return value

    @staticmethod
    def _split_classification(query_class):
        """
        Split a "一级分类 - 二级分类" cell into (vertical, sub).

        Missing parts are returned as "" so callers never index a short list.
        """
        if not query_class:
            return "", ""
        parts = [part.strip() for part in str(query_class).split(" - ")]
        vertical = parts[0]
        sub = parts[1] if len(parts) >= 2 else ""
        return vertical, sub

    def validate_row(self, llm, row_data):
        """
        Validate one row phase by phase, stopping at the first failing phase.

        Phase order: rewrite, classification, slot filling, retrieved content.
        A phase is skipped when its input cell is empty (classification needs
        both levels).

        Args:
            llm: chat model.
            row_data: (index, row) tuple as yielded by ``DataFrame.iterrows``.

        Returns:
            (index, is_all_correct, error_phase, error_reason, confidence_score).
            On success confidence is the maximum over the phases run; on failure
            it is the failing phase's confidence.
        """
        index, row = row_data
        query = self._cell_value(row["问题"])
        query_class = self._cell_value(row.get("问题分类", ""))
        rewrite = self._cell_value(row.get("问题改写", ""))
        slot_info = self._cell_value(row.get("槽位信息", ""))
        retrieve_content = self._cell_value(row.get("检索的内容", ""))

        if self.debug:
            logging.info(f"开始验证行 {index}:")
            logging.info(f"  问题: {query}")
            logging.info(f"  问题分类: {query_class}")
            logging.info(f"  问题改写: {rewrite}")

        try:
            vertical_class, sub_class = self._split_classification(query_class)
            # Classification and slots are judged on the rewrite; fall back to
            # the raw question when the rewrite cell is empty.
            effective_question = rewrite or query

            phases = []
            if rewrite:
                phases.append(("问题改写", self.validate_rewrite, (llm, query, rewrite)))
            if vertical_class and sub_class:
                phases.append(("问题分类", self.validate_classification,
                               (llm, effective_question, vertical_class, sub_class)))
            if slot_info:
                phases.append(("槽位填充", self.validate_slot,
                               (llm, effective_question, slot_info, vertical_class, sub_class)))
            if retrieve_content:
                # NOTE(review): retrieval is judged against the raw question, as in
                # the original code, although the parameter is named "rewrite".
                phases.append(("检索内容", self.validate_retrieve_content,
                               (llm, query, retrieve_content)))

            confidence_score = 0.0
            for phase, check, args in phases:
                if self.debug:
                    logging.info(f"  验证{phase}...")
                is_correct, error_reason, phase_confidence = check(*args)
                confidence_score = max(confidence_score, phase_confidence)
                if self.debug:
                    logging.info(f"    {phase}验证结果: {'通过' if is_correct else '不通过'}, 置信度: {phase_confidence:.2f}")
                    if not is_correct:
                        logging.info(f"    错误原因: {error_reason}")
                if not is_correct:
                    return index, False, phase, error_reason, phase_confidence

            if self.debug:
                logging.info(f"  行 {index} 验证完成: 通过, 总置信度: {confidence_score:.2f}")
            return index, True, "", "", confidence_score

        except Exception as e:
            # Row-level boundary: record the failure and keep the batch going.
            error_msg = f"处理行 {index} 时发生错误: {str(e)}"
            logging.error(error_msg, exc_info=True)
            return index, False, "处理错误", error_msg, 0.0

    def create_llm_instances(self, count, model_name="deepseek-ai/DeepSeek-R1", temperature=0.7):
        """
        Create ``count`` independent LLM clients.

        Args:
            count: number of instances to create.
            model_name: model identifier passed to OpenAiLLM.
            temperature: sampling temperature passed to OpenAiLLM.

        Returns:
            A list of OpenAiLLM instances. OPENAI_API_KEY / OPENAI_API_BASE are
            forwarded only when set in the environment.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")

        llm_params = {"temperature": temperature, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url

        return [OpenAiLLM(**llm_params) for _ in range(count)]

    @staticmethod
    def _apply_result(df, result):
        """Write one validate_row result tuple into the result columns of ``df``."""
        index, is_correct, error_phase, error_reason, confidence_score = result
        df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
        df.at[index, "错误环节"] = error_phase
        df.at[index, "错误原因"] = error_reason
        df.at[index, "置信度"] = confidence_score

    def _validate_serial(self, df, all_rows):
        """Validate rows one by one with a single LLM, updating ``df`` as results arrive."""
        logging.info("启用调试模式,使用串行处理...")
        llm = self.create_llm_instances(1)[0]
        results = []
        total = len(all_rows)
        for i, row_data in enumerate(all_rows, start=1):
            logging.info(f"处理第 {i}/{total} 行...")
            result = self.validate_row(llm, row_data)
            results.append(result)
            self._apply_result(df, result)

            index, is_correct, error_phase, error_reason, confidence_score = result
            logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
        return results

    def _validate_parallel(self, all_rows, workers):
        """Validate rows on a thread pool, round-robining a pool of LLM instances."""
        if not all_rows:
            return []
        # Never create more clients/threads than rows, and always at least one.
        pool_size = max(1, min(workers, len(all_rows)))
        llm_instances = self.create_llm_instances(pool_size)

        results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=pool_size) as executor:
            futures = [
                executor.submit(self.validate_row, llm_instances[i % pool_size], row_data)
                for i, row_data in enumerate(all_rows)
            ]
            for future in tqdm(concurrent.futures.as_completed(futures),
                               total=len(futures), desc="处理进度"):
                results.append(future.result())
        return results

    def validate(self, input_file=None, output_file=None, workers=None, debug=None):
        """
        Run validation over the whole sheet and save the annotated copy.

        Args:
            input_file: input Excel path (defaults to the constructor value).
            output_file: output Excel path (defaults to the constructor value,
                then to ``validated_<input name>`` next to the input).
            workers: number of worker threads (defaults to the constructor value).
            debug: serial verbose mode (defaults to the constructor value).

        Returns:
            The DataFrame with the columns 验证结果/错误环节/错误原因/置信度
            added, or None when the input could not be loaded.
        """
        input_file = input_file or self.input_file
        output_file = output_file or self.output_file
        workers = workers or self.workers
        debug = debug if debug is not None else self.debug

        df = self.load_data_from_excel(input_file)
        if df is None:
            return None

        df["验证结果"] = ""
        df["错误环节"] = ""
        df["错误原因"] = ""
        df["置信度"] = 0.0

        all_rows = list(df.iterrows())

        if debug:
            all_results = self._validate_serial(df, all_rows)
        else:
            all_results = self._validate_parallel(all_rows, workers)

        # Results are written by index label, so completion order does not matter.
        for result in all_results:
            self._apply_result(df, result)

        if output_file is None:
            output_file = os.path.join(
                os.path.dirname(input_file),
                f"validated_{os.path.basename(input_file)}"
            )
        os.makedirs(os.path.dirname(os.path.abspath(output_file)), exist_ok=True)
        df.to_excel(output_file, index=False)
        logging.info(f"验证完成,结果已保存至: {output_file}")

        self.print_statistics(df)

        return df

    def print_statistics(self, df):
        """Log the pass rate and a per-phase count of failures; safe on an empty DataFrame."""
        total = len(df)
        if total == 0:
            logging.info(f"统计信息: 总计 {total} 条")
            return

        passed = int((df["验证结果"] == "通过").sum())
        error_stats = df.loc[df["验证结果"] == "不通过", "错误环节"].value_counts()

        logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
        logging.info("错误环节统计:")
        for phase, count in error_stats.items():
            logging.info(f"- {phase}: {count} 条")

    def create_validation_parser(self):
        """Return a parser that turns the LLM reply into a ValidationResult."""
        return PydanticOutputParser(pydantic_object=ValidationResult)


def main():
    """Parse command-line arguments and run the validator."""
    data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "data", "excel")
    input_excel = os.path.join(data_dir, "1500条点踩软件问题测试_意图分类.xlsx")
    output_excel = os.path.join(data_dir, "自动验证_问题分类重写结果.xlsx")

    parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
    parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
    parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
    parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
    args = parser.parse_args()

    # A trace function is installed when running under a debugger: switch to serial mode.
    is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None

    validator = ExcelDataValidator(
        input_file=args.input,
        output_file=args.output,
        workers=args.workers,
        debug=is_debug
    )
    # Logged after the validator exists so its logging configuration is active.
    logging.info(f"输入文件路径: {args.input}, 输出文件路径: {args.output}, 并行工作线程数: {args.workers}")
    validator.validate()


if __name__ == "__main__":
    main()