589 lines
24 KiB
Python
Executable File
589 lines
24 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: validate_excel_data_batch.py
|
|
Description: 使用LLM批量验证Excel数据中的问题改写、问题分类、槽位填充和检索内容是否正确
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
import pandas as pd
|
|
import json
|
|
import argparse
|
|
import logging
|
|
import concurrent.futures
|
|
from tqdm import tqdm
|
|
from dotenv import load_dotenv
|
|
from langchain_openai import ChatOpenAI
|
|
from pydantic import BaseModel, Field
|
|
from langchain.output_parsers import PydanticOutputParser
|
|
|
|
sys.path.append(os.getcwd())
|
|
from rag2_0.intent_recognition.PromptTemplates import classification_info
|
|
from rag2_0.intent_recognition.DataModels import *
|
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
|
|
|
|
|
# Pydantic model for the verdict the LLM judge must return for one validation phase.
class ValidationResult(BaseModel):
    # NOTE: no class docstring on purpose -- it would become the schema
    # "description" and thereby change the parser's format instructions (prompt text).

    # Whether the checked stage passed validation (required).
    is_correct: bool = Field(..., description="验证是否通过")
    # Confidence reported by the model; the prompts ask for a value in [0, 1] (required).
    confidence_score: float = Field(..., description="置信度得分")
    # Explanation for the verdict; empty string when the model supplies none.
    reason: str = Field("", description="得出结论的原因")
|
|
|
|
class ExcelDataValidator:
    """Batch-validate pipeline outputs stored in an Excel sheet with an LLM judge.

    For every row the validator checks, in order, the question rewrite (问题改写),
    the classification (问题分类), the slot filling (槽点信息) and the retrieved
    content (检索的内容), and records the first phase that fails.
    """

    def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10, debug=False):
        """Initialize the validator.

        Args:
            input_file: Path of the input Excel file.
            output_file: Path of the output Excel file with validation results.
            workers: Number of worker threads used in parallel mode.
            batch_size: Number of rows handled per submitted batch.
            debug: If True, rows are processed serially with verbose logging.
        """
        # Load variables from a .env file; create_llm_instances reads
        # OPENAI_API_KEY / OPENAI_API_BASE from the environment.
        load_dotenv()

        self.input_file = input_file
        self.output_file = output_file
        self.workers = workers
        self.batch_size = batch_size
        self.debug = debug
        self.df = None

        self.setup_logging()

    def setup_logging(self):
        """Configure root logging to a stream handler and quiet HTTP client loggers."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[logging.StreamHandler()],
        )
        # Raise these libraries to WARNING to keep per-request noise out of the progress log.
        logging.getLogger('httpx').setLevel(logging.WARNING)
        logging.getLogger('openai').setLevel(logging.WARNING)

    def load_data_from_excel(self, file_path=None):
        """Read the input sheet and check that all required columns exist.

        Args:
            file_path: Excel file path; defaults to the path given at construction.

        Returns:
            The loaded DataFrame (also stored on ``self.df``), or None if no path
            was given, the file could not be read, or required columns are missing.
        """
        file_path = file_path or self.input_file
        if not file_path:
            logging.error("未指定输入文件路径")
            return None

        try:
            df = pd.read_excel(file_path)
        except Exception as e:
            logging.error(f"读取Excel文件时出错: {e}")
            return None

        required_columns = ["问题", "问题分类", "问题改写", "槽点信息", "检索的内容"]
        # Report every missing column at once instead of stopping at the first one.
        missing_columns = [col for col in required_columns if col not in df.columns]
        if missing_columns:
            logging.error(f"缺少必要的列: {', '.join(missing_columns)}")
            return None

        logging.info(f"成功从{file_path}读取了{len(df)}条数据")
        self.df = df
        return df

    @staticmethod
    def _normalize_cell(value):
        """Convert an Excel cell value to a stripped string ("" for empty cells).

        pandas fills empty cells with NaN, which is truthy. Without this, an empty
        cell would be validated as the literal text "nan" or crash on ``.split``.
        """
        if pd.api.types.is_scalar(value) and pd.isna(value):
            return ""
        return value.strip() if isinstance(value, str) else str(value)

    @staticmethod
    def _try_parse_json(value, containers_only=False):
        """Best-effort JSON decode of a string; return ``value`` unchanged on failure.

        Args:
            value: Cell content, usually a string.
            containers_only: Only attempt decoding when the text starts with ``{``
                or ``[``, so plain text such as "123" stays a string.
        """
        if not (isinstance(value, str) and value.strip()):
            return value
        if containers_only and not value.lstrip().startswith(("{", "[")):
            return value
        try:
            return json.loads(value)
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return value

    def _run_validation(self, llm, prompt, parser, phase):
        """Send ``prompt`` to ``llm`` and parse the reply into a verdict tuple.

        Args:
            llm: LLM instance exposing ``invoke(prompt)`` whose result has ``.content``.
            prompt: Fully rendered prompt text.
            parser: Parser from ``create_validation_parser``.
            phase: Phase label used in the warning message.

        Returns:
            (is_correct, reason, confidence_score). Any invocation or parsing error
            is logged and reported as a failed verdict with confidence 0.0.
        """
        try:
            response = llm.invoke(prompt)
            result = parser.parse(response.content)
        except Exception as e:
            logging.warning(f"验证{phase}时出错: {e}")
            return False, f"验证过程出错: {str(e)}", 0.0
        return result.is_correct, result.reason, result.confidence_score

    def validate_classification(self, llm, query, vertical_class, sub_class):
        """Ask the LLM whether a question's two-level classification is correct.

        Args:
            llm: LLM instance.
            query: Question text that was classified.
            vertical_class: First-level classification.
            sub_class: Second-level classification.

        Returns:
            (is_correct, reason, confidence_score).
        """
        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        prompt = f"""
        背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。

        我目前总共有以下分类:
        {classification_info}

        问题的分类情况如下:
        原始问题: {query}
        一级分类: {vertical_class}
        二级分类: {sub_class}

        请从专业角度分析这个分类是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

        {format_instructions}
        """

        return self._run_validation(llm, prompt, parser, "问题分类")

    def _get_slot_model(self, classification: "Classification") -> "Optional[type]":
        """Return the slot model class for a classification, or None if there is none.

        The returned model's JSON schema is used as the slot template in the
        slot-validation prompt.

        Args:
            classification: Intent classification result.
        """
        # Software questions
        if classification.vertical_classification == "软件问题":
            if classification.sub_classification == "软件功能":
                return SoftwareFunctionSlots
            elif classification.sub_classification == "故障排查":
                return SoftwareTroubleShootingSlots

        # Business questions
        elif classification.vertical_classification == "业务问题":
            if classification.sub_classification == "专业咨询":
                return ProfessionalConsultingSlots
            elif classification.sub_classification == "数据问题":
                return DataProblemSlots

        # Installation / download / registration
        elif classification.vertical_classification == "安装下载注册":
            if classification.sub_classification == "后缀名咨询":
                return FileExtensionConsultingSlots
            elif classification.sub_classification == "软件锁类":
                return SoftwareLockSlots
            elif classification.sub_classification == "安装下载类":
                return InstallationDownloadSlots
            elif classification.sub_classification == "问题排查类":
                return ProblemDiagnosisSlots

        # "Other": every sub-class shares one slot model
        elif classification.vertical_classification == "其他":
            return OtherSlots

        return None

    def validate_slot(self, llm, rewrite, slot_info, vertical_class, sub_class):
        """Ask the LLM whether the filled slots are accurate for the question.

        Args:
            llm: LLM instance.
            rewrite: Question text the slots were extracted from.
            slot_info: Filled slot values (JSON string or already-decoded value).
            vertical_class: First-level classification; selects the slot template.
            sub_class: Second-level classification; selects the slot template.

        Returns:
            (is_correct, reason, confidence_score). A classification without a
            known slot template is reported as a failed verdict instead of raising.
        """
        slots = self._try_parse_json(slot_info)

        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        # Building the template can fail: the classification may be rejected by the
        # Classification model, or it may have no slot model. Report that as this
        # phase's failure, matching how LLM errors are reported.
        try:
            slot_model = self._get_slot_model(
                Classification(vertical_classification=vertical_class, sub_classification=sub_class)
            )
            if slot_model is None:
                raise ValueError(f"未找到分类 {vertical_class} - {sub_class} 对应的槽位模板")
            slot_info_prompt = json.dumps(slot_model.model_json_schema(), ensure_ascii=False)
        except Exception as e:
            logging.warning(f"验证槽位填充时出错: {e}")
            return False, f"验证过程出错: {str(e)}", 0.0

        prompt = f"""
        背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我从用户问题中提取了槽位信息,请评估这些槽位信息是否准确、完整。

        问题改写: {rewrite}
        槽位模板:{slot_info_prompt}

        填充的槽位信息: {slots}

        槽位信息应该准确提取问题中的关键实体和属性,如软件名称、功能名称、错误信息等。请分析这些槽位是否准确填充,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

        {format_instructions}
        """

        return self._run_validation(llm, prompt, parser, "槽位填充")

    def validate_retrieve_content(self, llm, rewrite, retrieve_content):
        """Ask the LLM whether the retrieved content is relevant and accurate.

        Args:
            llm: LLM instance.
            rewrite: Question text the content was retrieved for.
            retrieve_content: Retrieved content (JSON object/array string or text).

        Returns:
            (is_correct, reason, confidence_score).
        """
        content = self._try_parse_json(retrieve_content, containers_only=True)

        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        prompt = f"""
        背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我针对用户问题检索了相关内容,请评估这些检索内容是否与问题相关、是否准确。

        问题改写: {rewrite}
        检索内容: {content}

        检索内容应该与问题主题相关,能够提供有用的信息来回答问题。请分析检索内容是否相关、准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

        {format_instructions}
        """

        return self._run_validation(llm, prompt, parser, "检索内容")

    def validate_rewrite(self, llm, query, rewrite):
        """Ask the LLM whether the rewrite preserves the original question's intent.

        Args:
            llm: LLM instance.
            query: Original question.
            rewrite: Rewritten question.

        Returns:
            (is_correct, reason, confidence_score).
        """
        parser = self.create_validation_parser()
        format_instructions = parser.get_format_instructions()

        prompt = f"""
        背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。

        原始问题: {query}
        问题改写: {rewrite}

        问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。

        {format_instructions}
        """

        return self._run_validation(llm, prompt, parser, "问题改写")

    def validate_row(self, llm, row_data):
        """Run the validation phases for one row, stopping at the first failure.

        Phases run in order: rewrite -> classification -> slot filling ->
        retrieved content. A phase is skipped when its input cell is empty.
        Classification and slot validation both need a "一级 - 二级"
        classification; slot validation is skipped with a warning when that is
        missing.

        Args:
            llm: LLM instance passed to the individual validators.
            row_data: ``(index, row)`` tuple as produced by ``DataFrame.iterrows``.

        Returns:
            (index, is_all_correct, error_phase, error_reason, confidence_score).
            On failure, confidence_score is the failing phase's confidence. On
            success it is the highest confidence of the phases that ran (0.0 if
            none ran).
        """
        index, row = row_data
        try:
            query = self._normalize_cell(row["问题"])
            query_class = self._normalize_cell(row.get("问题分类", ""))
            rewrite = self._normalize_cell(row.get("问题改写", ""))
            slot_info = self._normalize_cell(row.get("槽点信息", ""))
            retrieve_content = self._normalize_cell(row.get("检索的内容", ""))

            if self.debug:
                logging.info(f"开始验证行 {index}:")
                logging.info(f" 问题: {query}")
                logging.info(f" 问题分类: {query_class}")
                logging.info(f" 问题改写: {rewrite}")

            # Later stages operate on the rewrite; fall back to the original
            # question so the prompts never contain an empty question.
            subject = rewrite or query
            class_parts = [part.strip() for part in query_class.split(" - ")] if query_class else []
            has_class = len(class_parts) >= 2

            # (phase label, zero-arg callable returning (is_correct, reason, confidence))
            phases = []
            if rewrite:
                phases.append(("问题改写", lambda: self.validate_rewrite(llm, query, rewrite)))
            if has_class:
                phases.append(("问题分类", lambda: self.validate_classification(
                    llm, subject, class_parts[0], class_parts[1])))
            if slot_info:
                if has_class:
                    phases.append(("槽位填充", lambda: self.validate_slot(
                        llm, subject, slot_info, class_parts[0], class_parts[1])))
                else:
                    logging.warning(f"行 {index} 缺少有效的问题分类,跳过槽位填充验证")
            if retrieve_content:
                phases.append(("检索内容", lambda: self.validate_retrieve_content(
                    llm, subject, retrieve_content)))

            confidence_score = 0.0
            for phase, run_phase in phases:
                if self.debug:
                    logging.info(f" 验证{phase}...")

                is_correct, error_reason, phase_confidence = run_phase()
                confidence_score = max(confidence_score, phase_confidence)

                if self.debug:
                    logging.info(f" {phase}验证结果: {'通过' if is_correct else '不通过'}, 置信度: {phase_confidence:.2f}")
                    if not is_correct:
                        logging.info(f" 错误原因: {error_reason}")

                if not is_correct:
                    return index, False, phase, error_reason, phase_confidence

            if self.debug:
                logging.info(f" 行 {index} 验证完成: 通过, 总置信度: {confidence_score:.2f}")

            return index, True, "", "", confidence_score
        except Exception as e:
            error_msg = f"处理行 {index} 时发生错误: {str(e)}"
            logging.error(error_msg)
            if self.debug:
                import traceback
                logging.error(traceback.format_exc())
            return index, False, "处理错误", error_msg, 0.0

    def process_batch(self, llm, batch_data):
        """Validate a batch of ``(index, row)`` tuples sequentially with one LLM."""
        return [self.validate_row(llm, row_data) for row_data in batch_data]

    def create_llm_instances(self, count):
        """Create ``count`` LLM clients configured from environment variables."""
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        model_name = "deepseek-ai/DeepSeek-R1"

        llm_params = {"temperature": 0.7, "model": model_name}
        # Only forward connection settings that are actually set.
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url

        return [OpenAiLLM(**llm_params) for _ in range(count)]

    def validate(self, input_file=None, output_file=None, workers=None, batch_size=None, debug=None):
        """Validate every row of the input sheet and write the annotated sheet.

        Arguments left as None fall back to the values given at construction.

        Args:
            input_file: Input Excel file path.
            output_file: Output Excel file path; defaults to ``validated_<input name>``
                next to the input file.
            workers: Number of parallel worker threads.
            batch_size: Rows per submitted batch.
            debug: If True, process rows serially with verbose logging.

        Returns:
            The DataFrame with the added result columns, or None if loading failed.
        """
        input_file = input_file or self.input_file
        output_file = output_file or self.output_file
        # Clamp to >= 1: a 0/None batch size breaks range(), and ThreadPoolExecutor
        # rejects 0 workers.
        workers = max(1, workers or self.workers or 1)
        batch_size = max(1, batch_size or self.batch_size or 1)
        debug = debug if debug is not None else self.debug

        df = self.load_data_from_excel(input_file)
        if df is None:
            return None

        # Result columns, filled in after validation.
        df["验证结果"] = ""
        df["错误环节"] = ""
        df["错误原因"] = ""
        df["置信度"] = 0.0

        all_rows = list(df.iterrows())
        all_results = []

        if debug:
            # Debug mode: serial processing with a per-row summary.
            logging.info("启用调试模式,使用串行处理...")
            llm = self.create_llm_instances(1)[0]
            for position, row_data in enumerate(all_rows, start=1):
                logging.info(f"处理第 {position}/{len(all_rows)} 行...")
                result = self.validate_row(llm, row_data)
                all_results.append(result)
                index, is_correct, error_phase, error_reason, confidence_score = result
                logging.info(f"行 {index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
        elif all_rows:
            # Normal mode: batches run in a thread pool. LLM instances are assigned
            # round-robin, one per worker slot.
            batches = [all_rows[i:i + batch_size] for i in range(0, len(all_rows), batch_size)]
            llm_instances = self.create_llm_instances(min(workers, len(batches)))

            with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
                futures = [
                    executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch)
                    for i, batch in enumerate(batches)
                ]
                for future in tqdm(concurrent.futures.as_completed(futures), total=len(batches), desc="批次处理进度"):
                    all_results.extend(future.result())

        # Rows are addressed by index label, so completion order does not matter.
        for index, is_correct, error_phase, error_reason, confidence_score in all_results:
            df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
            df.at[index, "错误环节"] = error_phase
            df.at[index, "错误原因"] = error_reason
            df.at[index, "置信度"] = confidence_score

        if output_file is None:
            output_file = os.path.join(
                os.path.dirname(input_file),
                f"validated_{os.path.basename(input_file)}"
            )
        output_dir = os.path.dirname(output_file)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        df.to_excel(output_file, index=False)
        logging.info(f"验证完成,结果已保存至: {output_file}")

        self.print_statistics(df)

        return df

    def print_statistics(self, df):
        """Log the pass count, pass rate and failure counts per phase."""
        total = len(df)
        passed = int((df["验证结果"] == "通过").sum())
        error_stats = df.loc[df["验证结果"] == "不通过", "错误环节"].value_counts()

        # An empty sheet would otherwise raise ZeroDivisionError.
        pass_rate = passed / total * 100 if total else 0.0
        logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {pass_rate:.2f}%")
        logging.info("错误环节统计:")
        for phase, count in error_stats.items():
            logging.info(f"- {phase}: {count} 条")

    def create_validation_parser(self):
        """Create the output parser that turns LLM replies into ValidationResult."""
        return PydanticOutputParser(pydantic_object=ValidationResult)
|
|
|
|
|
|
def main():
    """Command-line entry point: parse arguments and run the batch validation."""
    # Default input/output files live in ../../data/excel relative to this script.
    data_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel")
    input_excel = os.path.join(data_dir, "1500条点踩软件问题测试_检索结果.xlsx")
    output_excel = os.path.join(data_dir, "自动验证_问题分类重写结果.xlsx")

    parser = argparse.ArgumentParser(description="验证Excel数据中的问题改写、问题分类、槽位填充和检索内容")
    parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
    parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
    parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
    parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
    parser.add_argument("--debug", "-d", action="store_true", help="启用调试模式(串行处理)")

    args = parser.parse_args()

    # Serial debug mode is used when --debug is passed or when a trace function
    # is installed (e.g. the script runs under a debugger).
    under_debugger = hasattr(sys, 'gettrace') and sys.gettrace() is not None

    validator = ExcelDataValidator(
        input_file=args.input,
        output_file=args.output,
        workers=args.workers,
        batch_size=args.batch_size,
        debug=args.debug or under_debugger,
    )
    validator.validate()


if __name__ == "__main__":
    main()