#!/usr/bin/env python # -*- coding: utf-8 -*- """ File: validate_excel_data_batch.py Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确 """ import os import pandas as pd import json import argparse import logging import concurrent.futures from tqdm import tqdm from dotenv import load_dotenv from langchain_openai import ChatOpenAI from rag2_0.intent_recognition.PromptTemplates import classification from rag2_0.tool.ModelTool import OpenAiLLM class ExcelDataValidator: """Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写""" def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10): """ 初始化验证器 Args: input_file: 输入Excel文件路径 output_file: 输出结果Excel文件路径 workers: 并行工作线程数 batch_size: 每批处理的行数 """ # 加载环境变量 load_dotenv() self.input_file = input_file self.output_file = output_file self.workers = workers self.batch_size = batch_size self.df = None # 设置日志 self.setup_logging() def setup_logging(self): """配置日志输出""" logging.basicConfig( level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s', handlers=[ logging.StreamHandler() ] ) logging.getLogger('httpx').setLevel(logging.WARNING) logging.getLogger('openai').setLevel(logging.WARNING) def load_data_from_excel(self, file_path=None): """ 从Excel文件中读取数据 Args: file_path: Excel文件路径,如不提供则使用初始化时的路径 Returns: DataFrame对象 """ file_path = file_path or self.input_file if not file_path: logging.error("未指定输入文件路径") return None try: df = pd.read_excel(file_path) required_columns = ["提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词"] for col in required_columns: if col not in df.columns: logging.error(f"缺少必要的列: {col}") return None logging.info(f"成功从{file_path}读取了{len(df)}条数据") self.df = df return df except Exception as e: logging.error(f"读取Excel文件时出错: {e}") return None def validate_classification(self, llm, query, vertical_class, sub_class): """ 验证问题分类是否正确 Args: llm: LLM模型 query: 原始问题 vertical_class: 一级分类 sub_class: 二级分类 Returns: (bool, str): 是否正确,错误原因(如果有) """ prompt = f""" 背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。 我目前总共有以下分类: {classification} 问题的分类情况如下: 原始问题: {query} 一级分类: {vertical_class} 二级分类: {sub_class} 请从专业角度分析这个分类是否准确。只需返回"正确"或"错误:原因",不需要其他解释。""" try: response = llm.invoke(prompt) result = response.content.strip() if result.startswith("正确"): return True, "" else: error_reason = result.replace("错误:", "").strip() if "错误:" in result else result return False, error_reason except Exception as e: logging.warning(f"验证问题分类时出错: {e}") return False, f"验证过程出错: {str(e)}" def validate_query_keys(self, llm, query, query_keys): """ 验证问题拆解是否正确 Args: llm: LLM模型 query: 原始问题 query_keys: 问题拆解 Returns: (bool, str): 是否正确,错误原因(如果有) """ prompt = f""" 背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。 原始问题: {query} 问题拆解: {query_keys} 问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确"或"错误:原因",不需要其他解释。""" try: response = llm.invoke(prompt) result = response.content.strip() if result.startswith("正确"): return True, "" else: error_reason = result.replace("错误:", "").strip() if "错误:" in result else result return False, error_reason except Exception as e: logging.warning(f"验证问题拆解时出错: {e}") return False, f"验证过程出错: {str(e)}" def validate_keywords(self, llm, query, query_keys, keywords_str): """ 验证检索关键词是否准确 Args: llm: LLM模型 query: 原始问题 query_keys: 问题拆解 keywords_str: 检索关键词(JSON字符串) Returns: (bool, str): 是否正确,错误原因(如果有) """ # 解析关键词JSON try: if isinstance(keywords_str, str) and keywords_str.strip(): keywords = json.loads(keywords_str) else: keywords = [] except: keywords = keywords_str prompt = f""" 背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关 原始问题: {query} 问题拆解: {query_keys} 检索关键词: {keywords} 检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确"或"错误:原因",不需要其他解释。""" try: response = llm.invoke(prompt) result = response.content.strip() if result.startswith("正确"): return True, "" else: error_reason = result.replace("错误:", "").strip() if "错误:" in result else result return False, error_reason except Exception as e: logging.warning(f"验证检索关键词时出错: {e}") return False, f"验证过程出错: {str(e)}" def validate_rewrite(self, llm, query, rewrite): """ 验证问题改写是否正确 Args: llm: LLM模型 query: 原始问题 rewrite: 问题改写 Returns: (bool, str): 是否正确,错误原因(如果有) """ prompt = f""" 背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。 原始问题: {query} 问题改写: {rewrite} 问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确"或"错误:原因",不需要其他解释。""" try: response = llm.invoke(prompt) result = response.content.strip() if result.startswith("正确"): return True, "" else: error_reason = result.replace("错误:", "").strip() if "错误:" in result else result return False, error_reason except Exception as e: logging.warning(f"验证问题改写时出错: {e}") return False, f"验证过程出错: {str(e)}" def validate_row(self, llm, row_data): """ 按顺序验证一行数据中的各个环节 Args: llm: LLM模型 row_data: (index, row)元组 Returns: (index, is_all_correct, error_phase, error_reason): 行索引,是否全部正确,错误环节,错误原因 """ index, row = row_data query = row["提问"] query_keys = row["问题拆解"] vertical_class = row["一级分类"] sub_class = row["二级分类"] rewrite = row["问题改写"] keywords_str = row["检索的关键词"] try: # 1. 验证问题分类 is_correct, error_reason = self.validate_classification(llm, query, vertical_class, sub_class) if not is_correct: return index, False, "问题分类", error_reason # 2. 验证问题拆解 is_correct, error_reason = self.validate_query_keys(llm, query, query_keys) if not is_correct: return index, False, "问题拆解", error_reason # 3. 验证检索关键词 is_correct, error_reason = self.validate_keywords(llm, query, query_keys, keywords_str) if not is_correct: return index, False, "关键词检索", error_reason # 4. 验证问题改写 is_correct, error_reason = self.validate_rewrite(llm, query, rewrite) if not is_correct: return index, False, "问题改写", error_reason return index, True, "", "" except Exception as e: error_msg = f"处理行 {index} 时发生错误: {str(e)}" logging.error(error_msg) return index, False, "处理错误", error_msg def process_batch(self, llm, batch_data): """处理一批数据""" results = [] for row_data in batch_data: results.append(self.validate_row(llm, row_data)) return results def create_llm_instances(self, count): """创建多个LLM实例""" api_key = os.getenv("OPENAI_API_KEY") base_url = os.getenv("OPENAI_API_BASE") model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo") llm_params = {"temperature": 0.7, "model": model_name} if api_key: llm_params["api_key"] = api_key if base_url: llm_params["base_url"] = base_url return [OpenAiLLM(**llm_params) for _ in range(count)] def validate(self, input_file=None, output_file=None, workers=None, batch_size=None): """ 执行验证过程 Args: input_file: 输入Excel文件路径 output_file: 输出结果Excel文件路径 workers: 并行工作线程数 batch_size: 每批处理的行数 Returns: 验证后的DataFrame """ input_file = input_file or self.input_file output_file = output_file or self.output_file workers = workers or self.workers batch_size = batch_size or self.batch_size # 读取数据 df = self.load_data_from_excel(input_file) if df is None: return None # 添加验证结果列 df["验证结果"] = "" df["错误环节"] = "" df["错误原因"] = "" # 准备数据批次 all_rows = list(df.iterrows()) batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)] # 创建多个LLM实例 llm_instances = self.create_llm_instances(min(workers, len(batches))) # 使用线程池处理数据 all_results = [] with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: # 为每个批次分配一个LLM实例 future_to_batch = { executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch): i for i, batch in enumerate(batches) } # 使用tqdm显示进度条 for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"): batch_results = future.result() all_results.extend(batch_results) # 按行索引排序结果,确保与原始数据顺序一致 all_results.sort(key=lambda x: x[0]) # 将结果填充到DataFrame for index, is_correct, error_phase, error_reason in all_results: df.at[index, "验证结果"] = "通过" if is_correct else "不通过" df.at[index, "错误环节"] = error_phase df.at[index, "错误原因"] = error_reason # 保存结果 if output_file is None: output_file = os.path.join( os.path.dirname(input_file), f"validated_{os.path.basename(input_file)}" ) df.to_excel(output_file, index=False) logging.info(f"验证完成,结果已保存至: {output_file}") # 输出统计信息 self.print_statistics(df) return df def print_statistics(self, df): """打印统计信息""" total = len(df) passed = len(df[df["验证结果"] == "通过"]) error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts() logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%") logging.info("错误环节统计:") for phase, count in error_stats.items(): logging.info(f"- {phase}: {count} 条") def main(): """主函数""" # 解析命令行参数 input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "问题分类重写结果") output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx") parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写") parser.add_argument("--input", "-i", type=str, required=True, help="输入Excel文件路径", default=input_excel) parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel) parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数") parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数") args = parser.parse_args() # 创建验证器实例并执行验证 validator = ExcelDataValidator( input_file=args.input, output_file=args.output, workers=args.workers, batch_size=args.batch_size ) validator.validate() if __name__ == "__main__": main()