上传问题改写、意图识别模块代码

This commit is contained in:
2025-05-27 09:48:03 +08:00
commit 99017f0cb0
66 changed files with 111493 additions and 0 deletions
+408
View File
@@ -0,0 +1,408 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: validate_excel_data_batch.py
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
"""
import os
import pandas as pd
import json
import argparse
import logging
import concurrent.futures
from tqdm import tqdm
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from rag2_0.intent_recognition.PromptTemplates import classification
from rag2_0.tool.ModelTool import OpenAiLLM
class ExcelDataValidator:
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10):
"""
初始化验证器
Args:
input_file: 输入Excel文件路径
output_file: 输出结果Excel文件路径
workers: 并行工作线程数
batch_size: 每批处理的行数
"""
# 加载环境变量
load_dotenv()
self.input_file = input_file
self.output_file = output_file
self.workers = workers
self.batch_size = batch_size
self.df = None
# 设置日志
self.setup_logging()
def setup_logging(self):
"""配置日志输出"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
def load_data_from_excel(self, file_path=None):
"""
从Excel文件中读取数据
Args:
file_path: Excel文件路径,如不提供则使用初始化时的路径
Returns:
DataFrame对象
"""
file_path = file_path or self.input_file
if not file_path:
logging.error("未指定输入文件路径")
return None
try:
df = pd.read_excel(file_path)
required_columns = ["提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词"]
for col in required_columns:
if col not in df.columns:
logging.error(f"缺少必要的列: {col}")
return None
logging.info(f"成功从{file_path}读取了{len(df)}条数据")
self.df = df
return df
except Exception as e:
logging.error(f"读取Excel文件时出错: {e}")
return None
def validate_classification(self, llm, query, vertical_class, sub_class):
"""
验证问题分类是否正确
Args:
llm: LLM模型
query: 原始问题
vertical_class: 一级分类
sub_class: 二级分类
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。
我目前总共有以下分类:
{classification}
问题的分类情况如下:
原始问题: {query}
一级分类: {vertical_class}
二级分类: {sub_class}
请从专业角度分析这个分类是否准确。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证问题分类时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_query_keys(self, llm, query, query_keys):
"""
验证问题拆解是否正确
Args:
llm: LLM模型
query: 原始问题
query_keys: 问题拆解
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。
原始问题: {query}
问题拆解: {query_keys}
问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证问题拆解时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_keywords(self, llm, query, query_keys, keywords_str):
"""
验证检索关键词是否准确
Args:
llm: LLM模型
query: 原始问题
query_keys: 问题拆解
keywords_str: 检索关键词(JSON字符串)
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
# 解析关键词JSON
try:
if isinstance(keywords_str, str) and keywords_str.strip():
keywords = json.loads(keywords_str)
else:
keywords = []
except:
keywords = keywords_str
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关
原始问题: {query}
问题拆解: {query_keys}
检索关键词: {keywords}
检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证检索关键词时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_rewrite(self, llm, query, rewrite):
"""
验证问题改写是否正确
Args:
llm: LLM模型
query: 原始问题
rewrite: 问题改写
Returns:
(bool, str): 是否正确,错误原因(如果有)
"""
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。
原始问题: {query}
问题改写: {rewrite}
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确""错误:原因",不需要其他解释。"""
try:
response = llm.invoke(prompt)
result = response.content.strip()
if result.startswith("正确"):
return True, ""
else:
error_reason = result.replace("错误:", "").strip() if "错误:" in result else result
return False, error_reason
except Exception as e:
logging.warning(f"验证问题改写时出错: {e}")
return False, f"验证过程出错: {str(e)}"
def validate_row(self, llm, row_data):
"""
按顺序验证一行数据中的各个环节
Args:
llm: LLM模型
row_data: (index, row)元组
Returns:
(index, is_all_correct, error_phase, error_reason): 行索引,是否全部正确,错误环节,错误原因
"""
index, row = row_data
query = row["提问"]
query_keys = row["问题拆解"]
vertical_class = row["一级分类"]
sub_class = row["二级分类"]
rewrite = row["问题改写"]
keywords_str = row["检索的关键词"]
try:
# 1. 验证问题分类
is_correct, error_reason = self.validate_classification(llm, query, vertical_class, sub_class)
if not is_correct:
return index, False, "问题分类", error_reason
# 2. 验证问题拆解
is_correct, error_reason = self.validate_query_keys(llm, query, query_keys)
if not is_correct:
return index, False, "问题拆解", error_reason
# 3. 验证检索关键词
is_correct, error_reason = self.validate_keywords(llm, query, query_keys, keywords_str)
if not is_correct:
return index, False, "关键词检索", error_reason
# 4. 验证问题改写
is_correct, error_reason = self.validate_rewrite(llm, query, rewrite)
if not is_correct:
return index, False, "问题改写", error_reason
return index, True, "", ""
except Exception as e:
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
logging.error(error_msg)
return index, False, "处理错误", error_msg
def process_batch(self, llm, batch_data):
"""处理一批数据"""
results = []
for row_data in batch_data:
results.append(self.validate_row(llm, row_data))
return results
def create_llm_instances(self, count):
"""创建多个LLM实例"""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
llm_params = {"temperature": 0.7, "model": model_name}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
return [OpenAiLLM(**llm_params) for _ in range(count)]
def validate(self, input_file=None, output_file=None, workers=None, batch_size=None):
"""
执行验证过程
Args:
input_file: 输入Excel文件路径
output_file: 输出结果Excel文件路径
workers: 并行工作线程数
batch_size: 每批处理的行数
Returns:
验证后的DataFrame
"""
input_file = input_file or self.input_file
output_file = output_file or self.output_file
workers = workers or self.workers
batch_size = batch_size or self.batch_size
# 读取数据
df = self.load_data_from_excel(input_file)
if df is None:
return None
# 添加验证结果列
df["验证结果"] = ""
df["错误环节"] = ""
df["错误原因"] = ""
# 准备数据批次
all_rows = list(df.iterrows())
batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]
# 创建多个LLM实例
llm_instances = self.create_llm_instances(min(workers, len(batches)))
# 使用线程池处理数据
all_results = []
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
# 为每个批次分配一个LLM实例
future_to_batch = {
executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch):
i for i, batch in enumerate(batches)
}
# 使用tqdm显示进度条
for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
batch_results = future.result()
all_results.extend(batch_results)
# 按行索引排序结果,确保与原始数据顺序一致
all_results.sort(key=lambda x: x[0])
# 将结果填充到DataFrame
for index, is_correct, error_phase, error_reason in all_results:
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
df.at[index, "错误环节"] = error_phase
df.at[index, "错误原因"] = error_reason
# 保存结果
if output_file is None:
output_file = os.path.join(
os.path.dirname(input_file),
f"validated_{os.path.basename(input_file)}"
)
df.to_excel(output_file, index=False)
logging.info(f"验证完成,结果已保存至: {output_file}")
# 输出统计信息
self.print_statistics(df)
return df
def print_statistics(self, df):
"""打印统计信息"""
total = len(df)
passed = len(df[df["验证结果"] == "通过"])
error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()
logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
logging.info("错误环节统计:")
for phase, count in error_stats.items():
logging.info(f"- {phase}: {count}")
def main():
"""主函数"""
# 解析命令行参数
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "问题分类重写结果")
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
parser.add_argument("--input", "-i", type=str, required=True, help="输入Excel文件路径", default=input_excel)
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数")
parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
args = parser.parse_args()
# 创建验证器实例并执行验证
validator = ExcelDataValidator(
input_file=args.input,
output_file=args.output,
workers=args.workers,
batch_size=args.batch_size
)
validator.validate()
if __name__ == "__main__":
main()