616 lines
23 KiB
Python
616 lines
23 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
完整性问题判断工具
|
|
|
|
此脚本用于读取Excel文件中的问题,调用LLM判断问题是否完整,并将结果保存到Excel文件中。
|
|
|
|
用法示例:
|
|
python judge_query_full.py -i "问题数据.xlsx" -o "完整问题结果.xlsx" -w 50 -c 0
|
|
|
|
命令行参数:
|
|
-i, --input: 输入Excel文件路径
|
|
-o, --output: 输出Excel文件路径
|
|
-w, --workers: 并发处理的最大线程数
|
|
-c, --column: 要处理的问题所在列的索引(从0开始)
|
|
-t, --test: 测试单个问题,不处理Excel文件
|
|
"""
|
|
|
|
import pandas as pd
|
|
import json
|
|
import os
|
|
import time
|
|
import re
|
|
import argparse
|
|
from pathlib import Path
|
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
|
from rag2_0.tool.APIKeyManager import APIKeyManager
|
|
from openpyxl.utils import get_column_letter
|
|
from openpyxl.styles import Alignment, PatternFill, Font, Border, Side
|
|
from tqdm import tqdm
|
|
import concurrent.futures
|
|
import threading
|
|
|
|
# Default settings used when a constructor argument / CLI option is omitted.
DEFAULT_EXCEL_PATH = r"/data/QueryRewrite/data/excel/7000条对话数据.xlsx"
DEFAULT_OUTPUT_PATH = r"/data/QueryRewrite/data/excel/7000条对话数据_完整问题结果.xlsx"
DEFAULT_MAX_WORKERS = 50


class QueryCompletenessJudge:
    """Ask an LLM whether questions are complete and export the complete ones.

    The class can batch-process one column of an Excel workbook, writing the
    questions judged complete (plus summary statistics) to a styled output
    workbook, or it can judge a single question and print the verdict.
    """

    def __init__(self, input_path=DEFAULT_EXCEL_PATH, output_path=DEFAULT_OUTPUT_PATH,
                 max_workers=DEFAULT_MAX_WORKERS, column_index=0):
        """Store the run configuration and create the LLM client.

        Args:
            input_path (str): Path of the Excel workbook to read.
            output_path (str): Path of the Excel workbook to write.
            max_workers (int): Maximum number of worker threads.
            column_index (int): Zero-based index of the question column.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.max_workers = max_workers
        self.column_index = column_index
        self.llm_client = self._create_llm_client()

    def _extract_json_from_response(self, full_answer):
        """Extract the JSON verdict object from an LLM reply.

        A fenced ```json block is preferred.  If there is none, or its content
        is not a JSON object, the text is scanned for the first embedded JSON
        object that has an ``"is_complete"`` key.

        Args:
            full_answer (str): Complete reply text returned by the LLM.

        Returns:
            dict | None: The parsed JSON object, or None when no usable object
            is found (including when ``full_answer`` is not a string).
        """
        if not isinstance(full_answer, str) or not full_answer:
            return None

        fenced = re.search(r'```json\s*(.*?)\s*```', full_answer, re.DOTALL)
        if fenced:
            try:
                parsed = json.loads(fenced.group(1))
            except json.JSONDecodeError:
                parsed = None
            # Callers use ``.get``; anything that is not an object is unusable.
            if isinstance(parsed, dict):
                return parsed

        # Decode object-by-object instead of regex-matching from the first "{"
        # to the last "}", which breaks as soon as the reply holds other braces.
        decoder = json.JSONDecoder()
        position = full_answer.find("{")
        while position != -1:
            try:
                candidate, _ = decoder.raw_decode(full_answer, position)
            except json.JSONDecodeError:
                candidate = None
            if isinstance(candidate, dict) and "is_complete" in candidate:
                return candidate
            position = full_answer.find("{", position + 1)
        return None

    def _create_llm_prompt(self, question):
        """Build the prompt that asks the LLM to judge one question.

        Args:
            question (str): The question whose completeness is to be judged.

        Returns:
            str: The formatted prompt text.
        """
        return f"""你是一个电力造价行业专家,用户正在使用电力造价软件,并提出了相关问题。请分析以下问题是否完整。

问题:{question}

首先,分析这个问题的结构和内容,思考它是否包含足够的信息来表达清晰的意图。
考虑以下几点:
1. 问题是否有明确的核心意图,不需要面面俱到
2. 问题是否缺少必要的上下文
3. **问题如果涉及软件相关,则只需要包含:软件名称、软件功能或软件目的即可**


在你的分析之后,请用JSON格式给出最终结论,格式如下:
```json
{{
"is_complete": true或false,
"reason": "判断原因的简要说明",
"confidence": 0到100之间的数值,表示你对判断的置信度
}}
```

请确保JSON格式正确,以便于程序解析。"""

    def _create_llm_client(self, api_key=None):
        """Create the LLM client used for all judgements.

        Args:
            api_key (str, optional): API key; when None it is obtained from
                ``APIKeyManager.get_api_key()``.

        Returns:
            OpenAiLLM: The configured client instance.
        """
        if api_key is None:
            api_key = APIKeyManager.get_api_key()

        # NOTE(review): the prompt requests a free-text analysis followed by a
        # JSON verdict; max_tokens=100 looks too small for both and would force
        # the text heuristic in is_question_complete -- confirm against the
        # OpenAiLLM wrapper before changing.
        return OpenAiLLM(
            api_key=api_key,
            base_url="https://api.siliconflow.cn/v1",  # adjust to the deployment in use
            model="deepseek-ai/DeepSeek-V3",  # adjust to the deployment in use
            temperature=0.2,
            max_tokens=100
        )

    @staticmethod
    def _coerce_verdict(value):
        """Turn the raw ``is_complete`` JSON value into a real boolean."""
        if isinstance(value, str):
            # A JSON string such as "false" is truthy in Python; map it explicitly.
            return value.strip().lower() in ("true", "yes", "1", "是", "完整")
        return bool(value)

    @staticmethod
    def _guess_verdict_from_text(full_answer):
        """Heuristic verdict from the first 100 characters of a reply without JSON."""
        head = full_answer[:100]
        # "不完整" contains "完整", so negated wordings must be ruled out first.
        if any(marker in head for marker in ("不完整", "不够完整", "不太完整", "不算完整")):
            return False
        return "完整" in head

    def is_question_complete(self, question):
        """Ask the LLM whether ``question`` is complete.

        The call is attempted once and retried up to three more times, sleeping
        2, 4 and 8 seconds between attempts.  Any exception raised while
        building the prompt, calling the client or parsing the reply triggers a
        retry; after the last attempt the error is printed and returned.

        Args:
            question (str): The question to judge.

        Returns:
            tuple: ``(bool, str)`` -- the verdict and the full LLM reply, or
            ``(False, error_message)`` when every attempt failed.
        """
        max_retries = 3
        retry_delay = 2  # seconds; doubled after every failed attempt
        last_error = None

        for attempt in range(max_retries + 1):
            try:
                prompt = self._create_llm_prompt(question)
                response = self.llm_client.invoke(prompt)

                # The client may return a message object or a plain string.
                content = response.content if hasattr(response, 'content') else response
                if content is None:
                    raise ValueError("LLM返回内容为空")
                full_answer = content if isinstance(content, str) else str(content)

                result = self._extract_json_from_response(full_answer)
                if isinstance(result, dict) and result:
                    return self._coerce_verdict(result.get("is_complete", False)), full_answer

                # No usable JSON: fall back to a keyword heuristic.
                return self._guess_verdict_from_text(full_answer), full_answer
            except Exception as e:
                last_error = e
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2

        message = f"错误: 经过 {max_retries} 次重试后仍然失败: {str(last_error)}"
        print(message)
        return False, message

    def _process_question(self, args, complete_questions, progress_counter, progress_lock, complete_questions_lock, pbar):
        """Judge one question, record it if complete, and update progress.

        Args:
            args (tuple): ``(index, question, llm_client, total_questions)``;
                only ``question`` is used here.
            complete_questions (list): Shared list receiving one result dict
                per question judged complete.
            progress_counter (dict): Shared counters with keys ``processed``,
                ``complete`` and ``incomplete``.
            progress_lock (threading.Lock): Guards ``progress_counter`` and ``pbar``.
            complete_questions_lock (threading.Lock): Guards ``complete_questions``.
            pbar (tqdm): Progress bar to advance.

        Returns:
            None
        """
        _index, question, _llm_client, _total_questions = args

        # Missing cells (None / NaN) and blank strings are skipped; other
        # non-string cell values (e.g. numbers) are judged via their text form
        # instead of failing on ``.strip()``.
        is_missing = question is None or (not isinstance(question, str) and pd.isna(question))
        question_text = "" if is_missing else (question if isinstance(question, str) else str(question))
        if question_text.strip() == "":
            with progress_lock:
                progress_counter["processed"] += 1
                pbar.update(1)
            return None

        is_complete, full_answer = self.is_question_complete(question_text)

        if is_complete:
            # ``is_complete`` is the authoritative verdict; the JSON is parsed
            # again only to pick up the optional reason / confidence fields.
            record = {
                "问题": question_text,
                "LLM回复": full_answer,
                "完整性": "完整"
            }
            parsed_json = self._extract_json_from_response(full_answer)
            if isinstance(parsed_json, dict) and parsed_json:
                record["原因"] = parsed_json.get("reason", "未提供")
                record["置信度"] = parsed_json.get("confidence", 0)

            with complete_questions_lock:
                complete_questions.append(record)

        with progress_lock:
            progress_counter["complete" if is_complete else "incomplete"] += 1
            progress_counter["processed"] += 1
            pbar.set_postfix(
                完整=progress_counter["complete"],
                不完整=progress_counter["incomplete"],
                完整率=f"{progress_counter['complete']/max(1, progress_counter['processed']):.1%}"
            )
            pbar.update(1)
        return None

    def _shorten_response(self, response):
        """Shorten an LLM reply for display in an Excel cell.

        When the reply holds a fenced ```json block, the result is the first
        200 characters of the text *before* that block followed by the block
        itself, so the JSON is never emitted twice.  Otherwise the reply is cut
        to 500 characters.  Truncated text gets a trailing ``...``.

        Args:
            response (str): The original LLM reply.

        Returns:
            str: The shortened reply; non-string values are returned unchanged.
        """
        if not isinstance(response, str):
            return response

        json_match = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
        if json_match is None:
            return response[:500] + "..." if len(response) > 500 else response

        json_part = json_match.group(0)
        reasoning = response[:json_match.start()].strip()
        if not reasoning:
            return json_part
        if len(reasoning) > 200:
            reasoning = reasoning[:200] + "..."
        return f"{reasoning}\n\n{json_part}"

    def _prepare_excel_dataframe(self, complete_questions):
        """Convert the result records into a DataFrame ready for export.

        Args:
            complete_questions (list): Result dicts built by ``_process_question``.

        Returns:
            pandas.DataFrame: Frame with the reply column shortened and the
            known columns ordered first, followed by any other columns.
        """
        frame = pd.DataFrame(complete_questions)

        # Keep cells at a manageable size.
        if "LLM回复" in frame.columns:
            frame["LLM回复"] = frame["LLM回复"].apply(self._shorten_response)

        preferred = ["问题", "完整性", "置信度", "原因", "LLM回复"]
        leading = [name for name in preferred if name in frame.columns]
        trailing = [name for name in frame.columns if name not in leading]
        return frame[leading + trailing]

    def _set_excel_column_widths(self, worksheet):
        """Set a width for every column based on its header text.

        Args:
            worksheet (openpyxl.worksheet.worksheet.Worksheet): Sheet to format.
        """
        widths_by_header = {
            "问题": 40,
            "LLM回复": 60,
            "原因": 30,
            "完整性": 10,
            "置信度": 10,
        }
        for col in range(1, worksheet.max_column + 1):
            col_letter = get_column_letter(col)
            header = worksheet[f"{col_letter}1"].value
            # Columns with any other header get a width of 15.
            worksheet.column_dimensions[col_letter].width = widths_by_header.get(header, 15)

    def _apply_excel_cell_styles(self, worksheet):
        """Apply wrapping, borders, header styling and verdict colours.

        Args:
            worksheet (openpyxl.worksheet.worksheet.Worksheet): Sheet to format.

        Returns:
            openpyxl.styles.Border: The thin border used for the cells, so the
            statistics section can reuse it.
        """
        header_fill = PatternFill(start_color="DDEBF7", end_color="DDEBF7", fill_type="solid")
        complete_fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
        incomplete_fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
        header_font = Font(bold=True)
        wrap_alignment = Alignment(wrap_text=True, vertical="top")
        border = Border(
            left=Side(style='thin'),
            right=Side(style='thin'),
            top=Side(style='thin'),
            bottom=Side(style='thin')
        )

        # Row 1 is visited first, so the verdict columns are known before any
        # data row is styled; this avoids a header lookup for every cell.
        verdict_columns = set()
        for row in worksheet.iter_rows(min_row=1, max_row=worksheet.max_row, min_col=1, max_col=worksheet.max_column):
            for cell in row:
                cell.alignment = wrap_alignment
                cell.border = border

                if cell.row == 1:
                    cell.fill = header_fill
                    cell.font = header_font
                    if cell.value == "完整性":
                        verdict_columns.add(cell.column)
                elif cell.column in verdict_columns:
                    cell.fill = complete_fill if cell.value == "完整" else incomplete_fill

        return border

    def _add_statistics_to_excel(self, worksheet, complete_questions, total_rows, total_questions, border):
        """Append a summary section below the data rows.

        The totals are based on ``total_questions`` (all input questions), not
        on ``total_rows``: the sheet only holds complete questions, so counting
        its rows would always report zero incomplete questions and 100%.

        Args:
            worksheet (openpyxl.worksheet.worksheet.Worksheet): Sheet to extend.
            complete_questions (list): Result dicts written to the sheet.
            total_rows (int): Number of data rows written to the sheet.
            total_questions (int): Number of questions in the input.
            border (openpyxl.styles.Border): Border applied to the summary cells.

        Returns:
            int: Number of complete questions.
        """
        complete_count = sum(1 for item in complete_questions if item.get("完整性") == "完整")
        # Never report fewer questions than rows that were actually written.
        overall = max(total_questions, total_rows)
        incomplete_count = overall - complete_count

        worksheet.append([""])  # spacer row

        stat_row = worksheet.max_row + 1
        worksheet.cell(row=stat_row, column=1, value="统计信息")
        worksheet.cell(row=stat_row, column=1).font = Font(bold=True)

        worksheet.cell(row=stat_row+1, column=1, value="总问题数")
        worksheet.cell(row=stat_row+1, column=2, value=overall)

        worksheet.cell(row=stat_row+2, column=1, value="完整问题数")
        worksheet.cell(row=stat_row+2, column=2, value=complete_count)
        worksheet.cell(row=stat_row+2, column=2).fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")

        worksheet.cell(row=stat_row+3, column=1, value="不完整问题数")
        worksheet.cell(row=stat_row+3, column=2, value=incomplete_count)
        worksheet.cell(row=stat_row+3, column=2).fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")

        worksheet.cell(row=stat_row+4, column=1, value="完整问题比例")
        worksheet.cell(row=stat_row+4, column=2, value=f"{complete_count/overall:.2%}" if overall > 0 else "0%")

        for r in range(stat_row, stat_row+5):
            for c in range(1, 3):
                worksheet.cell(row=r, column=c).border = border

        return complete_count

    def save_results_to_excel(self, complete_questions, total_questions):
        """Write the complete questions to the output workbook and style it.

        Nothing is written when ``complete_questions`` is empty.

        Args:
            complete_questions (list): Result dicts to export.
            total_questions (int): Number of questions in the input.
        """
        if not complete_questions:
            print("没有找到完整的问题。")
            return

        result_df = self._prepare_excel_dataframe(complete_questions)
        total_rows = len(result_df)

        # Create the target directory first so a missing folder does not
        # discard the results after all LLM calls have been made.
        Path(self.output_path).parent.mkdir(parents=True, exist_ok=True)
        result_df.to_excel(self.output_path, index=False, engine='openpyxl')

        # Re-open the written file to apply formatting.
        from openpyxl import load_workbook
        wb = load_workbook(self.output_path)
        ws = wb.active

        self._set_excel_column_widths(ws)
        border = self._apply_excel_cell_styles(ws)
        complete_count = self._add_statistics_to_excel(ws, complete_questions, total_rows, total_questions, border)

        wb.save(self.output_path)

        print(f"处理完成。共有{complete_count}/{total_questions}个完整问题被保存到 {self.output_path}")
        print(f"完整问题比例: {complete_count/total_questions:.2%}" if total_questions > 0 else "完整问题比例: 0%")

    def process_excel_file(self):
        """Judge every question in the configured column and export the results.

        Reads ``self.input_path``, judges the column at ``self.column_index``
        with a thread pool of ``self.max_workers`` workers, then calls
        ``save_results_to_excel``.  Problems with the input (missing file,
        unreadable file, too few columns) are printed and the method returns
        without raising.
        """
        if not os.path.exists(self.input_path):
            print(f"错误: 找不到Excel文件 '{self.input_path}'")
            return

        print(f"正在读取Excel文件: {self.input_path}")
        try:
            df = pd.read_excel(self.input_path)
        except Exception as e:
            print(f"读取Excel文件时出错: {e}")
            return

        if len(df.columns) <= self.column_index:
            print(f"错误: Excel文件没有足够的列,请求索引 {self.column_index},但只有 {len(df.columns)} 列")
            return

        target_col = df.columns[self.column_index]
        print(f"目标列名称: {target_col}")

        complete_questions = []
        total_questions = len(df)

        print(f"总共有{total_questions}个问题需要判断")

        # Shared state for the worker threads.
        complete_questions_lock = threading.Lock()
        progress_counter = {"processed": 0, "complete": 0, "incomplete": 0}
        progress_lock = threading.Lock()

        # Keep empty cells as None: str(NaN) would become the literal question
        # "nan" and be sent to the LLM instead of being skipped.
        column_values = df.iloc[:, self.column_index]
        questions = [
            (i, None if pd.isna(value) else str(value), self.llm_client, total_questions)
            for i, value in column_values.items()
        ]

        start_time = time.time()

        print(f"开始处理问题,使用 {self.max_workers} 个并发线程...")
        failures = []
        with tqdm(total=total_questions, desc="处理问题", unit="问题") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [executor.submit(
                    self._process_question,
                    args,
                    complete_questions,
                    progress_counter,
                    progress_lock,
                    complete_questions_lock,
                    pbar
                ) for args in questions]

                # Inspect every future so worker exceptions are not lost.
                for future in concurrent.futures.as_completed(futures):
                    error = future.exception()
                    if error is not None:
                        failures.append(error)

        if failures:
            print(f"警告: 有 {len(failures)} 个问题在处理过程中发生异常,例如: {failures[0]}")

        processing_time = time.time() - start_time
        # Guard against an empty sheet.
        average_time = processing_time / total_questions if total_questions else 0.0
        print(f"处理完成,耗时: {processing_time:.2f}秒,平均每问题: {average_time:.2f}秒")

        self.save_results_to_excel(complete_questions, total_questions)

    def test_single_question(self, question):
        """Judge a single question and print the reply and the verdict.

        Args:
            question (str): The question to test.
        """
        print(f"问题: {question}")
        print("正在调用LLM判断问题是否完整...")

        is_complete, full_answer = self.is_question_complete(question)
        parsed_json = self._extract_json_from_response(full_answer)

        print("\n==== LLM回复 ====")
        print(full_answer)
        print("================\n")

        if isinstance(parsed_json, dict) and parsed_json:
            # Report the normalised verdict so it matches batch processing.
            print(f"判断结果: {'完整' if is_complete else '不完整'}")
            print(f"判断原因: {parsed_json.get('reason', '未提供')}")
            print(f"置信度: {parsed_json.get('confidence', 0)}%")
        else:
            print(f"判断结果: {'完整' if is_complete else '不完整'} (简单判断)")
            print("无法从回复中提取JSON结构化数据")
|
|
|
|
|
|
def parse_arguments():
    """Build the command-line interface and parse the process arguments.

    Returns:
        argparse.Namespace: Parsed options ``input``, ``output``, ``workers``,
        ``column`` and ``test``.
    """
    # One (flags, keyword-arguments) entry per supported option.
    option_table = (
        (('-i', '--input'), dict(type=str, default=DEFAULT_EXCEL_PATH,
                                 help=f'输入Excel文件路径 (默认: {DEFAULT_EXCEL_PATH})')),
        (('-o', '--output'), dict(type=str, default=DEFAULT_OUTPUT_PATH,
                                  help=f'输出Excel文件路径 (默认: {DEFAULT_OUTPUT_PATH})')),
        (('-w', '--workers'), dict(type=int, default=DEFAULT_MAX_WORKERS,
                                   help=f'并发处理的最大线程数 (默认: {DEFAULT_MAX_WORKERS})')),
        (('-c', '--column'), dict(type=int, default=0,
                                  help='要处理的问题所在列的索引 (默认: 0,即第一列)')),
        (('-t', '--test'), dict(type=str,
                                help='测试单个问题,不处理Excel文件')),
    )

    cli = argparse.ArgumentParser(description='判断Excel文件中的问题是否完整')
    for flags, settings in option_table:
        cli.add_argument(*flags, **settings)
    return cli.parse_args()
|
|
|
|
|
|
def main():
    """Entry point: judge one question with ``--test``, else process the Excel file."""
    options = parse_arguments()

    # The judge is created in both modes; its constructor builds the LLM client.
    judge = QueryCompletenessJudge(
        input_path=options.input,
        output_path=options.output,
        max_workers=options.workers,
        column_index=options.column
    )

    if not options.test:
        # Batch mode: no single question was supplied.
        judge.process_excel_file()
        return

    # Single-question mode: the Excel file is not touched.
    judge.test_single_question(options.test)


if __name__ == "__main__":
    main()
|
|
|
|
|