更新环境变量配置,调整模型名称获取方式,新增Dify API相关配置,删除无用的脚本文件,优化意图识别逻辑,添加LLM提取词条逻辑

This commit is contained in:
2025-07-16 14:24:50 +08:00
parent 5e164882a1
commit a934f2c398
28 changed files with 1834 additions and 1235 deletions
+1 -1
View File
@@ -48,7 +48,7 @@ class JsonDeduplicator:
{items}
'''
# 配置LLM
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
llm_params = {"temperature": 0.3, "model": model_name}
-281
View File
@@ -1,281 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: extract_wikijs_nouns.py
Author: oyyz
Description: 从 Wikijs 文档中提取专业名词
"""
import os
from typing import List
from dotenv import load_dotenv
from langchain.output_parsers import PydanticOutputParser
from rag2_0.tool.WikijsTool import WikijsTool
from rag2_0.intent_recognition.DataModels import Term, TermList
from rag2_0.tool.html_to_md import convert_html_to_md
from rag2_0.tool.ModelTool import OpenAiLLM
import json
import datetime
import logging
import threading
import concurrent.futures
from threading import Semaphore
# 加载环境变量
load_dotenv()
extract_wiki_nouns_prompt="""
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
一、提取范围
1. 核心功能模块
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
(例:新建工程量清单、导出工程量清单等)
3. 业务专用术语
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
4. 计价标准体系
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
二、提取规则
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码"
2. 提取业务专用名词(如"主材卸车保管费"
3. 标注关联术语的对应关系(如"市场价""市场价格"互为同义词)
4. 包含定额标准相关术语(如"预规2020版"
5. 复合型术语需保持完整
√ 正确:"地形增加系数批量设置"
× 错误:"地形""系数""设置"
6. 总结生成关键词解释
关键词:编制依据
描述:造价文件编制基准规范
7. 软件的特定版本号不作为关键词
三、输出格式:
{output_format}
四、输入内容:
{content}
"""
class WikijsNounsExtractor:
"""从 Wikijs 文档中提取专业名词"""
def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
"""
初始化专业名词提取器
Args:
api_key: API密钥,如果为None则从环境变量获取
base_url: API基础URL,如果为None则使用默认URL
model_name: 要使用的模型名称
"""
# 保存参数
self.api_key = api_key
self.base_url = base_url
self.model_name = model_name
# 初始化LLM
llm_params = {
"temperature": 0.6,
"model": model_name
}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
self.llm = OpenAiLLM(**llm_params)
# 准备术语列表解析器
self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)
# 信号量,限制并发请求数量
self.semaphore = None
# 线程锁,用于保护共享资源
self.lock = threading.Lock()
def _convert_html_to_md(self, content, title):
"""HTML转Markdown"""
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
new_content = (content.replace("h6>", "h7>")
.replace("h5>", "h6>")
.replace("h4>", "h5>")
.replace("h3>", "h4>")
.replace("h2>", "h3>")
.replace("h1>", "h2>"))
# 将HTML内容转换为Markdown
markdown_content = convert_html_to_md(new_content, "", **options)
markdown_content = f"# {title}\n\n{markdown_content}"
return markdown_content
def extract_from_document(self, doc_info: dict) -> List[Term]:
"""从单个文档中提取专业名词"""
try:
# 使用LLM调用处理文档
content = doc_info['content']
title = doc_info["title"]
# 转换HTML到Markdown
markdown_content = self._convert_html_to_md(content, title)
# 准备提示词
formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())
try:
# 调用LLM
response = self.llm.invoke(formatted_prompt)
# 使用Pydantic解析器解析结果
parsed_output = self.terms_list_parser.parse(response.content)
return parsed_output.terms
except Exception as e:
logging.error(f"解析LLM响应时出错: {str(e)}", exc_info=True)
return []
except Exception as e:
logging.error(f"提取专业名词时出错: {str(e)}", exc_info=True)
return []
def _process_document(self, doc, path_terms):
"""处理单个文档"""
try:
# 获取信号量
with self.semaphore:
# 检查文档路径是否在我们要处理的路径中
path_prefix = None
for prefix in path_terms.keys():
if doc['path'].startswith(prefix):
path_prefix = prefix
break
# 如果不在要处理的路径中,则跳过
if not path_prefix:
return None
# 获取文档详细信息
doc_info = WikijsTool.query_doc_info(doc['id'])
if not doc_info or not doc_info.get('content'):
return None
# 提取专业名词
terms = self.extract_from_document(doc_info)
# 将提取的术语添加到对应路径的结果列表中
terms_dicts = [{"name": term.name, "synonymous": term.synonymous, "description": term.description} for term in terms]
with self.lock:
path_terms[path_prefix].extend(terms_dicts)
logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")
# 每处理10个文档保存一次中间结果
current_count = len(path_terms[path_prefix])
if current_count % 10 == 0:
# 使用锁保护文件IO
self._save_terms_to_file(path_terms[path_prefix], os.path.join(self.output_dir, f"{path_prefix.split('')[0]}_nouns.json"))
logging.info(f"已处理 {path_prefix} 的文档数达到 {current_count//10*10} 个,已保存中间结果")
return path_prefix
except Exception as e:
logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}", exc_info=True)
return None
def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
"""使用线程池处理所有文档"""
# 保存输出目录
self.output_dir = output_dir
# 创建输出目录
if not os.path.exists(output_dir):
os.makedirs(output_dir)
# 初始化信号量,限制并发请求数
self.semaphore = Semaphore(max_concurrency)
# 获取所有文档
all_docs = WikijsTool.get_all_documents()
# 要处理的路径前缀
# path_prefixes = [
# "技改检修计价通(2020",
# "西藏造价软件(2023",
# "新型储能电站建设计价通C12024",
# "配网造价软件(2022",
# ]
path_prefixes = [
"主网电力建设计价通(2018",
]
# 为每个路径创建单独的结果列表
path_terms = {prefix: [] for prefix in path_prefixes}
# 过滤出符合路径前缀的文档
filtered_docs = []
for doc in all_docs:
for prefix in path_prefixes:
if doc['path'].startswith(prefix):
filtered_docs.append(doc)
break
logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")
# 使用线程池处理所有文档
with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
futures = []
for doc in filtered_docs:
future = executor.submit(self._process_document, doc, path_terms)
futures.append(future)
# 等待所有任务完成
for i, future in enumerate(concurrent.futures.as_completed(futures)):
try:
prefix = future.result()
if i % 10 == 0:
logging.info(f"已完成 {i+1}/{len(futures)} 个文档的处理")
except Exception as e:
logging.error(f"处理文档时出错: {str(e)}", exc_info=True)
# 保存最终结果
for prefix, terms in path_terms.items():
# 为每个路径保存单独的文件
output_file = os.path.join(output_dir, f"{prefix.split('')[0]}_nouns.json")
self._save_terms_to_file(terms, output_file)
logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")
def _save_terms_to_file(self, terms, output_file):
"""保存术语列表到文件"""
with open(output_file, 'w', encoding='utf-8') as f:
json.dump(terms, f, ensure_ascii=False, indent=2)
def main():
# 从环境变量获取配置
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
# os.environ["LLM_MODEL_NAME"] = "Qwen/Qwen2.5-72B-Instruct-128K"
extractor = WikijsNounsExtractor(api_key=api_key, base_url=base_url, model_name=os.getenv("LLM_MODEL_NAME"))
current_dir = os.path.dirname(os.path.abspath(__file__))
output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")
extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)
if __name__ == "__main__":
# 配置日志输出到文件,并设置格式
current_dir = os.path.dirname(os.path.abspath(__file__))
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
date_format = '%Y-%m-%d %H:%M:%S'
# 创建一个控制台处理器
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
console_handler.setFormatter(logging.Formatter(log_format, date_format))
# 获取根日志记录器并添加处理器
root_logger = logging.getLogger()
root_logger.setLevel(logging.INFO)
root_logger.addHandler(console_handler)
main()
+13 -13
View File
@@ -75,15 +75,8 @@ class QueryRewriteProcessor:
dify_base_url: Dify API基础URL
"""
# 初始化意图识别器
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.base_url = base_url or os.getenv("OPENAI_API_BASE")
self.model_name = model_name or os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
# 使用asyncio.run()运行异步create方法
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create(
api_key=self.api_key,
base_url=self.base_url,
model_name=self.model_name
))
self.recognizer_async = asyncio.run(AsyncIntentRecognizer.create())
self.dify_query_retrieval = DifyQueryRetrieval(api_key=dify_api_key, base_url=dify_base_url)
def is_retrieved_doc_relevant(self, query: str, retrieved_doc: List[Dict[str, Any]]) -> Dict[str, Any]:
@@ -174,7 +167,7 @@ class QueryRewriteProcessor:
return []
def process_query(self, query: str,
conversation_context: str = "",
conversation_context: Dict = None,
chat_history: List[Dict[str, str]] = None,
previous_slots: Dict[str, str] = None,
enable_retrieval: bool = False):
@@ -196,12 +189,17 @@ class QueryRewriteProcessor:
while retry_count <= max_retries:
try:
if conversation_context is None:
conversation_context = {}
current_softname = conversation_context.get("current_softname", "")
result = asyncio.run(self.recognizer_async.process_query_async(query,
conversation_context=conversation_context,
chat_history=chat_history,
previous_slots=previous_slots,
enable_query_expansion=True,
use_jieba=True))
use_jieba=True,
cur_soft_name=current_softname))
# 提取分类信息
classification = result["classification"]
@@ -414,7 +412,7 @@ def main():
# 从环境变量中获取配置,命令行参数优先
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
enable_retrieval = args.enable_retrieval
# 初始化查询改写处理器
@@ -441,8 +439,10 @@ def main():
for idx, query in enumerate(examples):
if query.strip() == "":
continue
query="811619150828能看一下这个锁是16的马"
conversation_context="当前使用软件:配网计价通D3软件"
query="怎么把一个批次拆分成多个批次工程"
conversation_context={
"current_softname": "配网计价通D3软件"
}
# 在调试模式下使用完整的参数
print(json.dumps(processor.process_query(
query,
+1 -1
View File
@@ -44,7 +44,7 @@ class TermMerger:
{items}
'''
# 配置LLM
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
llm_params = {"temperature": 0.3, "model": model_name}
-573
View File
@@ -1,573 +0,0 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: validate_excel_data_batch.py
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
"""
import os
import sys
import pandas as pd
import json
import argparse
import logging
import concurrent.futures
from tqdm import tqdm
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from pydantic import BaseModel, Field
from langchain.output_parsers import PydanticOutputParser
sys.path.append(os.getcwd())
from rag2_0.intent_recognition.PromptTemplates import classification_info
from rag2_0.intent_recognition.DataModels import *
from rag2_0.tool.ModelTool import OpenAiLLM
# 定义验证结果的Pydantic模型
class ValidationResult(BaseModel):
is_correct: bool = Field(description="验证是否通过")
confidence_score: float = Field(description="置信度得分")
reason: str = Field(default="", description="得出结论的原因")
class ExcelDataValidator:
"""Excel数据验证类,用于批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写"""
def __init__(self, input_file=None, output_file=None, workers=4, debug=False):
"""
初始化验证器
Args:
input_file: 输入Excel文件路径
output_file: 输出结果Excel文件路径
workers: 并行工作线程数
debug: 是否启用调试模式(串行处理)
"""
# 加载环境变量
load_dotenv()
self.input_file = input_file
self.output_file = output_file
self.workers = workers
self.debug = debug
self.df = None
# 设置日志
self.setup_logging()
def setup_logging(self):
"""配置日志输出"""
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler()
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
def load_data_from_excel(self, file_path=None):
"""
从Excel文件中读取数据
Args:
file_path: Excel文件路径,如不提供则使用初始化时的路径
Returns:
DataFrame对象
"""
file_path = file_path or self.input_file
if not file_path:
logging.error("未指定输入文件路径", exc_info=True)
return None
try:
df = pd.read_excel(file_path)
required_columns = ["问题", "问题分类", "问题改写", "槽位信息", "检索的内容"]
for col in required_columns:
if col not in df.columns:
logging.error(f"缺少必要的列: {col}", exc_info=True)
return None
logging.info(f"成功从{file_path}读取了{len(df)}条数据")
self.df = df
return df
except Exception as e:
logging.error(f"读取Excel文件时出错: {e}", exc_info=True)
return None
def validate_classification(self, llm:OpenAiLLM , query:str, vertical_class:str, sub_class:str):
"""
验证问题分类是否正确
Args:
llm: LLM模型
query: 原始问题
vertical_class: 一级分类
sub_class: 二级分类
Returns:
(bool, str, float): 是否正确,错误原因(如果有),置信度
"""
parser = self.create_validation_parser()
format_instructions = parser.get_format_instructions()
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。
我目前总共有以下分类:
{classification_info}
问题的分类情况如下:
原始问题: {query}
一级分类: {vertical_class}
二级分类: {sub_class}
请从专业角度分析这个分类是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
{format_instructions}
"""
try:
response = llm.invoke(prompt)
result = parser.parse(response.content)
return result.is_correct, result.reason, result.confidence_score
except Exception as e:
logging.warning(f"验证问题分类时出错: {e}")
return False, f"验证过程出错: {str(e)}", 0.0
def _get_slot_model(self, classification: Classification) -> Optional[type]:
"""
根据分类结果获取对应的槽位模型类,用于统一提示词处理
Args:
classification: 意图分类结果
Returns:
对应的槽位模型类
"""
# 软件问题
if classification.vertical_classification == "软件问题":
if classification.sub_classification == "软件功能":
return SoftwareFunctionSlots
elif classification.sub_classification == "故障排查":
return SoftwareTroubleShootingSlots
# 业务问题
elif classification.vertical_classification == "业务问题":
if classification.sub_classification == "专业咨询":
return ProfessionalConsultingSlots
elif classification.sub_classification == "数据问题":
return DataProblemSlots
# 安装下载注册
elif classification.vertical_classification == "安装下载注册":
if classification.sub_classification == "后缀名咨询":
return FileExtensionConsultingSlots
elif classification.sub_classification == "软件锁类":
return SoftwareLockSlots
elif classification.sub_classification == "安装下载类":
return InstallationDownloadSlots
elif classification.sub_classification == "问题排查类":
return ProblemDiagnosisSlots
# 其他
elif classification.vertical_classification == "其他":
return OtherSlots
return None
def validate_slot(self, llm, rewrite, slot_info, vertical_class, sub_class):
"""
验证槽位填充是否正确
Args:
llm: LLM模型
rewrite: 问题改写
slot_info: 槽位信息(JSON字符串)
Returns:
(bool, str, float): 是否正确,错误原因(如果有),置信度
"""
# 解析槽位信息JSON
try:
if isinstance(slot_info, str) and slot_info.strip():
slots = json.loads(slot_info)
else:
slots = slot_info
except:
slots = slot_info
parser = self.create_validation_parser()
format_instructions = parser.get_format_instructions()
slot_info_prompt = self._get_slot_model(Classification(vertical_classification=vertical_class, sub_classification=sub_class)).model_json_schema()
slot_info_prompt = json.dumps(slot_info_prompt, ensure_ascii=False)
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我从用户问题中提取了槽位信息,请评估这些槽位信息是否准确、完整。
问题改写: {rewrite}
槽位模板:{slot_info_prompt}
填充的槽位信息: {slots}
槽位信息应该准确提取问题中的关键实体和属性,如软件名称、功能名称、错误信息等。请分析这些槽位是否准确填充,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
{format_instructions}
"""
try:
response = llm.invoke(prompt)
result = parser.parse(response.content)
return result.is_correct, result.reason, result.confidence_score
except Exception as e:
logging.warning(f"验证槽位填充时出错: {e}")
return False, f"验证过程出错: {str(e)}", 0.0
def validate_retrieve_content(self, llm, rewrite, retrieve_content):
"""
验证检索内容是否正确
Args:
llm: LLM模型
rewrite: 问题改写
retrieve_content: 检索内容(可能是JSON字符串或文本)
Returns:
(bool, str, float): 是否正确,错误原因(如果有),置信度
"""
# 解析检索内容
try:
if isinstance(retrieve_content, str) and retrieve_content.strip():
if retrieve_content.startswith('{') or retrieve_content.startswith('['):
content = json.loads(retrieve_content)
else:
content = retrieve_content
else:
content = retrieve_content
except:
content = retrieve_content
parser = self.create_validation_parser()
format_instructions = parser.get_format_instructions()
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我针对用户问题检索了相关内容,请评估这些检索内容是否能解答提问。
问题改写: {rewrite}
检索内容: {content}
检索内容应该与问题主题相关,能够提供有用的信息来回答问题。请分析检索内容是否能解答提问、准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
{format_instructions}
"""
try:
response = llm.invoke(prompt)
result = parser.parse(response.content)
return result.is_correct, result.reason, result.confidence_score
except Exception as e:
logging.warning(f"验证检索内容时出错: {e}")
return False, f"验证过程出错: {str(e)}", 0.0
def validate_rewrite(self, llm, query, rewrite):
"""
验证问题改写是否正确
Args:
llm: LLM模型
query: 原始问题
rewrite: 问题改写
Returns:
(bool, str, float): 是否正确,错误原因(如果有),置信度
"""
parser = self.create_validation_parser()
format_instructions = parser.get_format_instructions()
prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。
原始问题: {query}
问题改写: {rewrite}
问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确,并以JSON格式返回结果。请提供一个0到1之间的置信度得分,表示你对判断的确信程度。
{format_instructions}
"""
try:
response = llm.invoke(prompt)
result = parser.parse(response.content)
return result.is_correct, result.reason, result.confidence_score
except Exception as e:
logging.warning(f"验证问题改写时出错: {e}")
return False, f"验证过程出错: {str(e)}", 0.0
def validate_row(self, llm, row_data):
"""
按顺序验证一行数据中的各个环节
Args:
llm: LLM模型
row_data: (index, row)元组
Returns:
(index, is_all_correct, error_phase, error_reason, confidence_score): 行索引,是否全部正确,错误环节,错误原因,置信度
"""
index, row = row_data
query = row["问题"]
query_class = row.get("问题分类", "")
rewrite = row.get("问题改写", "")
slot_info = row.get("槽位信息", "")
retrieve_content = row.get("检索的内容", "")
if self.debug:
logging.info(f"开始验证行 {index}:")
logging.info(f" 问题: {query}")
logging.info(f" 问题分类: {query_class}")
logging.info(f" 问题改写: {rewrite}")
try:
confidence_score = 0.0
# 1. 验证问题改写
if rewrite:
if self.debug:
logging.info(f" 验证问题改写...")
result = self.validate_rewrite(llm, query, rewrite)
if isinstance(result, tuple) and len(result) >= 3:
is_correct, error_reason, rewrite_confidence = result[:3]
confidence_score = max(confidence_score, rewrite_confidence)
if self.debug:
logging.info(f" 问题改写验证结果: {'通过' if is_correct else '不通过'}, 置信度: {rewrite_confidence:.2f}")
if not is_correct:
logging.info(f" 错误原因: {error_reason}")
if not is_correct:
return index, False, "问题改写", error_reason, rewrite_confidence
# 2. 验证问题分类
if query_class:
if self.debug:
logging.info(f" 验证问题分类...")
query_class_list = query_class.split(" - ")
if len(query_class_list) >= 2:
result = self.validate_classification(llm, rewrite, query_class_list[0], query_class_list[1])
if isinstance(result, tuple) and len(result) >= 3:
is_correct, error_reason, classification_confidence = result[:3]
confidence_score = max(confidence_score, classification_confidence)
if self.debug:
logging.info(f" 问题分类验证结果: {'通过' if is_correct else '不通过'}, 置信度: {classification_confidence:.2f}")
if not is_correct:
logging.info(f" 错误原因: {error_reason}")
if not is_correct:
return index, False, "问题分类", error_reason, classification_confidence
# 3. 验证槽位填充
if slot_info:
if self.debug:
logging.info(f" 验证槽位填充...")
result = self.validate_slot(llm, rewrite, slot_info, query_class_list[0], query_class_list[1])
if isinstance(result, tuple) and len(result) >= 3:
is_correct, error_reason, slot_confidence = result[:3]
confidence_score = max(confidence_score, slot_confidence)
if self.debug:
logging.info(f" 槽位填充验证结果: {'通过' if is_correct else '不通过'}, 置信度: {slot_confidence:.2f}")
if not is_correct:
logging.info(f" 错误原因: {error_reason}")
if not is_correct:
return index, False, "槽位填充", error_reason, slot_confidence
# 4. 验证检索内容
if retrieve_content and retrieve_content != "" and pd.notna(retrieve_content):
if self.debug:
logging.info(f" 验证检索内容...")
result = self.validate_retrieve_content(llm, query, retrieve_content)
if isinstance(result, tuple) and len(result) >= 3:
is_correct, error_reason, retrieve_confidence = result[:3]
confidence_score = max(confidence_score, retrieve_confidence)
if self.debug:
logging.info(f" 检索内容验证结果: {'通过' if is_correct else '不通过'}, 置信度: {retrieve_confidence:.2f}")
if not is_correct:
logging.info(f" 错误原因: {error_reason}")
if not is_correct:
return index, False, "检索内容", error_reason, retrieve_confidence
if self.debug:
logging.info(f"{index} 验证完成: 通过, 总置信度: {confidence_score:.2f}")
return index, True, "", "", confidence_score
except Exception as e:
error_msg = f"处理行 {index} 时发生错误: {str(e)}"
logging.error(error_msg, exc_info=True)
return index, False, "处理错误", error_msg, 0.0
def create_llm_instances(self, count):
"""创建多个LLM实例"""
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = "deepseek-ai/DeepSeek-R1"
llm_params = {"temperature": 0.7, "model": model_name}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
return [OpenAiLLM(**llm_params) for _ in range(count)]
def validate(self, input_file=None, output_file=None, workers=None, debug=None):
"""
执行验证过程
Args:
input_file: 输入Excel文件路径
output_file: 输出结果Excel文件路径
workers: 并行工作线程数
batch_size: 每批处理的行数(已弃用,保留参数保持兼容)
debug: 是否启用调试模式(串行处理)
Returns:
验证后的DataFrame
"""
input_file = input_file or self.input_file
output_file = output_file or self.output_file
workers = workers or self.workers
debug = debug if debug is not None else self.debug
# 读取数据
df = self.load_data_from_excel(input_file)
if df is None:
return None
# 添加验证结果列
df["验证结果"] = ""
df["错误环节"] = ""
df["错误原因"] = ""
df["置信度"] = 0.0
# 准备数据
all_rows = list(df.iterrows())
# 创建LLM实例
llm = self.create_llm_instances(1)[0]
# 根据模式选择处理方式
all_results = []
if debug:
# 调试模式:串行处理
logging.info("启用调试模式,使用串行处理...")
for i, row_data in enumerate(all_rows):
logging.info(f"处理第 {i+1}/{len(all_rows)} 行...")
result = self.validate_row(llm, row_data)
all_results.append(result)
# 实时更新DataFrame
index, is_correct, error_phase, error_reason, confidence_score = result
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
df.at[index, "错误环节"] = error_phase
df.at[index, "错误原因"] = error_reason
df.at[index, "置信度"] = confidence_score
# 输出当前结果
logging.info(f"{index} 验证结果: {'通过' if is_correct else '不通过'}, 错误环节: {error_phase}, 错误原因: {error_reason}, 置信度: {confidence_score:.2f}")
else:
# 正常模式:并行处理,每行单独处理
llm_instances = self.create_llm_instances(min(workers, len(all_rows)))
with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
# 为每行分配一个LLM实例
future_to_row = {
executor.submit(self.validate_row, llm_instances[i % len(llm_instances)], row_data):
i for i, row_data in enumerate(all_rows)
}
# 使用tqdm显示进度条
for future in tqdm(concurrent.futures.as_completed(future_to_row), total=len(all_rows), desc="处理进度"):
result = future.result()
all_results.append(result)
# 按行索引排序结果,确保与原始数据顺序一致
all_results.sort(key=lambda x: x[0])
# 将结果填充到DataFrame
for result in all_results:
if len(result) >= 5:
index, is_correct, error_phase, error_reason, confidence_score = result
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
df.at[index, "错误环节"] = error_phase
df.at[index, "错误原因"] = error_reason
df.at[index, "置信度"] = confidence_score
else:
index, is_correct, error_phase, error_reason = result
df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
df.at[index, "错误环节"] = error_phase
df.at[index, "错误原因"] = error_reason
# 保存结果
if output_file is None:
output_file = os.path.join(
os.path.dirname(input_file),
f"validated_{os.path.basename(input_file)}"
)
df.to_excel(output_file, index=False)
logging.info(f"验证完成,结果已保存至: {output_file}")
# 输出统计信息
self.print_statistics(df)
return df
def print_statistics(self, df):
"""打印统计信息"""
total = len(df)
passed = len(df[df["验证结果"] == "通过"])
error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()
logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
logging.info("错误环节统计:")
for phase, count in error_stats.items():
logging.info(f"- {phase}: {count}")
def create_validation_parser(self):
"""创建验证结果解析器"""
return PydanticOutputParser(pydantic_object=ValidationResult)
def main():
"""主函数"""
# 解析命令行参数
input_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "1500条点踩软件问题测试_意图分类.xlsx")
output_excel = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel", "自动验证_问题分类重写结果.xlsx")
parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
parser.add_argument("--workers", "-w", type=int, default=20, help="并行工作线程数")
args = parser.parse_args()
logging.info(f"输入文件路径: {args.input}, 输出文件路径: {args.output}, 并行工作线程数: {args.workers}")
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# 创建验证器实例并执行验证
validator = ExcelDataValidator(
input_file=args.input,
output_file=args.output,
workers=args.workers,
debug=is_debug
)
validator.validate()
if __name__ == "__main__":
main()