上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,250 @@
|
||||
"""
|
||||
提问内容补全工具
|
||||
|
||||
此模块用于解析Excel文件中的提问和回答,调用LLM补全提问内容,
|
||||
并将原提问和补全后的提问保存到新的Excel文件中。
|
||||
|
||||
用法示例:
|
||||
completer = QuestionCompleter(input_path="历史提问数据(dislike).xlsx", output_path="补全后的提问数据.xlsx")
|
||||
completer.process()
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
import concurrent.futures
|
||||
from threading import Lock
|
||||
|
||||
class RewriteQuery(BaseModel):
    # Structured result expected back from the LLM. It is handed to
    # PydanticOutputParser, so the field names and descriptions below are
    # runtime data and must stay exactly as they are.
    # NOTE(review): no class docstring on purpose - it would presumably end up
    # in the schema used for the format instructions; confirm before adding one.

    # The question after the missing software name has been filled in.
    rewrite_query: str = Field(description="补全后的提问")

    # The software the question refers to.
    software_name: str = Field(description="软件名称")
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class QuestionCompleter:
    """Complete user questions that omit the software name, using an LLM.

    Loads question/answer pairs from an Excel file, asks the LLM to rewrite
    each question so that it names the software it refers to, and saves the
    original question, the rewritten question and the software name to a new
    Excel file.
    """

    def __init__(self, input_path="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
                 output_path="/data/Rag2_0/data/excel/历史提问数据(dislike)_补全后的提问数据.xlsx",
                 question_column="提问", answer_column="回答", max_workers=10):
        """Store the settings, create the LLM client and load the input sheet.

        Args:
            input_path (str): Path of the Excel file to read.
            output_path (str): Path of the Excel file to write.
            question_column (str): Name of the column holding the questions.
            answer_column (str): Name of the column holding the answers.
            max_workers (int): Maximum number of worker threads.

        Raises:
            ValueError: If OPENAI_API_KEY, OPENAI_API_BASE or LLM_MODEL_NAME
                is not set, or if one of the two columns is missing.
            RuntimeError: If the Excel file cannot be read.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.question_column = question_column
        self.answer_column = answer_column
        self.max_workers = max_workers
        self.rewrite_query_parser = PydanticOutputParser(pydantic_object=RewriteQuery)
        # Kept as a public attribute for existing callers; process() no longer
        # needs it because results are only touched from the calling thread.
        self.lock = Lock()

        # LLM configuration comes from the environment (.env is loaded at import).
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.base_url = os.getenv("OPENAI_API_BASE")
        self.model = os.getenv("LLM_MODEL_NAME")

        if not all([self.api_key, self.base_url, self.model]):
            raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")

        self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)

        try:
            self.df = pd.read_excel(self.input_path)
            print(f"成功读取Excel文件: {self.input_path}")
            print(f"共有 {len(self.df)} 条记录")
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"读取Excel文件失败: {str(e)}") from e

        # Fail early when the sheet does not have the expected columns.
        if self.question_column not in self.df.columns:
            raise ValueError(f"Excel文件中不存在列: {self.question_column}")
        if self.answer_column not in self.df.columns:
            raise ValueError(f"Excel文件中不存在列: {self.answer_column}")

    def create_completion_prompt(self, question, answer):
        """Build the prompt that asks the LLM to complete a question.

        Args:
            question: The original question.
            answer: The answer that was given to that question.

        Returns:
            str: The formatted prompt, including the parser's format
            instructions.
        """
        prompt = f"""
1、判断提问中是否缺少软件名称,如果不缺少,则直接返回原始提问
2、如果缺少软件名称,则根据回答中的软件名称,补全提问
3、补全后的提问需要保持问题原有意图不变

4、软件名称包括:
配网D3软件(配网工程计价通D3)
西藏Z1软件(西藏电力工程计价通Z1)
主网计价通软件(电力建设计价通)
技改检修工程计价通T1软件(技改检修工程计价通T1)
技改检修清单计价通T1软件(技改检修清单计价通T1)
储能C1软件(新型储能电站建设计价通C1)
如果没有包含上述软件名称,则直接返回原始提问,software_name为空字符串
{{
"rewrite_query": "xxx",
"software_name": ""
}}

原始提问:{question}
系统回答:{answer}

输出格式:
{self.rewrite_query_parser.get_format_instructions()}

示例:
例如,如果输入是:
提问:这个软件怎么用?
回答:Photoshop的使用方法是...

那么输出会是:
{{
"rewrite_query": "Photoshop这个软件怎么用?",
"software_name": "Photoshop"
}}

或者如果提问已经包含软件名称:
提问:Photoshop怎么用?
回答:Photoshop的使用方法是...

那么输出会是:
{{
"rewrite_query": "Photoshop怎么用?",
"software_name": "Photoshop"
}}

"""
        return prompt

    @staticmethod
    def _is_blank(value):
        """Return True for None/NaN cells and for whitespace-only text."""
        if value is None:
            return True
        if isinstance(value, str):
            return value.strip() == ""
        # Excel cells may hold numbers or dates; only missing values are blank.
        return bool(pd.isna(value))

    def complete_question(self, question, answer):
        """Ask the LLM to complete one question.

        Args:
            question: The original question (any Excel cell value).
            answer: The answer that was given to that question.

        Returns:
            tuple: ``(completed_question, software_name)``. When the question
            or answer is blank, or the LLM call / parsing fails, the original
            question is returned together with an empty software name.
        """
        # Nothing to complete without both a question and an answer. Non-string
        # cells are handled by _is_blank instead of crashing on .strip().
        if self._is_blank(question) or self._is_blank(answer):
            return question, ""

        try:
            prompt = self.create_completion_prompt(question, answer)
            response = self.llm.invoke(prompt)
            completed_question = self.rewrite_query_parser.parse(response.content)
            return completed_question.rewrite_query, completed_question.software_name
        except Exception as e:
            # Best effort: a failed row keeps its original question.
            print(f"补全提问失败: {str(e)}")
            return question, ""

    def process_row(self, row):
        """Complete the question of a single DataFrame row.

        Args:
            row: One row of the input DataFrame.

        Returns:
            dict: The original question, the completed question and the
            software name, keyed by the output column names.
        """
        original_question = row[self.question_column]
        answer = row[self.answer_column]

        completed_question, software_name = self.complete_question(original_question, answer)

        return {
            "原始提问": original_question,
            "补全后的提问": completed_question,
            "软件名称": software_name
        }

    def _complete_rows(self, on_row_done=None):
        """Run process_row over every row in a thread pool.

        Args:
            on_row_done: Optional zero-argument callback invoked after each
                row finishes (used for progress reporting).

        Returns:
            list: One result dict per input row, in input order.
        """
        total = len(self.df)
        # Pre-sized so each result lands at its input position no matter which
        # worker finishes first.
        results = [None] * total

        with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            future_to_idx = {
                executor.submit(self.process_row, self.df.iloc[idx]): idx
                for idx in range(total)
            }
            for future in concurrent.futures.as_completed(future_to_idx):
                results[future_to_idx[future]] = future.result()
                if on_row_done is not None:
                    on_row_done()

        return results

    def process(self):
        """Complete every question concurrently and save the results.

        Reads the questions and answers loaded in ``__init__``, completes each
        question with the LLM and writes the original question, the completed
        question and the software name to ``output_path``. Output rows keep
        the order of the input rows.
        """
        with tqdm(total=len(self.df), desc="补全提问") as pbar:
            results = self._complete_rows(on_row_done=lambda: pbar.update(1))

        results_df = pd.DataFrame(results)
        results_df.to_excel(self.output_path, index=False)
        print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||
|
||||
def main():
    """Command-line entry point: parse the options and run the completer."""
    import argparse

    cli = argparse.ArgumentParser(description='补全Excel文件中的提问内容')

    # (flags, add_argument settings), listed in the order shown by --help.
    option_specs = [
        (('-i', '--input'),
         dict(type=str, default="/data/Rag2_0/data/excel/历史提问数据(dislike).xlsx",
              help='输入Excel文件路径')),
        (('-o', '--output'),
         dict(type=str, default="/data/Rag2_0/data/excel/补全后的提问数据.xlsx",
              help='输出Excel文件路径')),
        (('-q', '--question'),
         dict(type=str, default="提问", help='提问列的名称')),
        (('-a', '--answer'),
         dict(type=str, default="回答", help='回答列的名称')),
        (('-w', '--workers'),
         dict(type=int, default=50, help='最大线程数')),
    ]
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    options = cli.parse_args()

    # Build the completer from the parsed options and run it right away.
    QuestionCompleter(
        input_path=options.input,
        output_path=options.output,
        question_column=options.question,
        answer_column=options.answer,
        max_workers=options.workers,
    ).process()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,282 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: extract_wikijs_nouns.py
|
||||
Author: oyyz
|
||||
Description: 从 Wikijs 文档中提取专业名词
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List
|
||||
from dotenv import load_dotenv
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.intent_recognition.DataModels import Term, TermList
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
import json
|
||||
import datetime
|
||||
import logging
|
||||
import threading
|
||||
import concurrent.futures
|
||||
from threading import Semaphore
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
extract_wiki_nouns_prompt="""
|
||||
我在完善我的专业词库,请从提供的电力行业造价软件相关文本中提取关键词,要求如下:
|
||||
|
||||
一、提取范围
|
||||
1. 核心功能模块
|
||||
(例:多工程批量计价、材机数据反算、变电工程智能组价、架空线路地形系数计算)
|
||||
2、软件功能及界面名称(包括:界面页签、功能按钮、功能名称等)
|
||||
(例:新建工程量清单、导出工程量清单等)
|
||||
3. 业务专用术语
|
||||
(例:装置性材料、甲供材保管费、施工降效补偿、电缆头试验配套费)
|
||||
4. 计价标准体系
|
||||
(例:预规2020版、电网检修定额2015版、配网工程概算定额)
|
||||
|
||||
|
||||
二、提取规则
|
||||
1. 识别核心功能名称(如"多工程批量设置工程量、工程设置密码")
|
||||
2. 提取业务专用名词(如"主材卸车保管费")
|
||||
3. 标注关联术语的对应关系(如"市场价"与"市场价格"互为同义词)
|
||||
4. 包含定额标准相关术语(如"预规2020版")
|
||||
5. 复合型术语需保持完整
|
||||
√ 正确:"地形增加系数批量设置"
|
||||
× 错误:"地形"、"系数"、"设置"
|
||||
6. 总结生成关键词解释
|
||||
关键词:编制依据
|
||||
描述:造价文件编制基准规范
|
||||
|
||||
7. 软件的特定版本号不作为关键词
|
||||
|
||||
三、输出格式:
|
||||
{output_format}
|
||||
|
||||
四、输入内容:
|
||||
{content}
|
||||
"""
|
||||
|
||||
|
||||
class WikijsNounsExtractor:
    """Extract domain terms from Wiki.js documents with an LLM."""

    def __init__(self, api_key: str = None, base_url: str = None, model_name: str = "gpt-3.5-turbo"):
        """Create the LLM client, the output parser and the shared state.

        Args:
            api_key: API key; only passed to the LLM client when truthy.
            base_url: API base URL; only passed to the LLM client when truthy.
            model_name: Name of the model to use.
        """
        self.api_key = api_key
        self.base_url = base_url
        self.model_name = model_name

        llm_params = {
            "temperature": 0.6,
            "model": model_name
        }
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Parser that turns the LLM reply into a TermList.
        self.terms_list_parser = PydanticOutputParser(pydantic_object=TermList)

        # Created in process_all_documents(); limits concurrent requests.
        self.semaphore = None

        # Guards the shared term lists, the counters and the snapshot files.
        self.lock = threading.Lock()

        # Set by process_all_documents().
        self.output_dir = None
        # Number of documents processed so far, per path prefix.
        self._doc_counts = {}

    def _convert_html_to_md(self, content, title):
        """Convert a document's HTML to Markdown under a top-level title.

        Every "hN>" tag fragment is shifted one level down (h1 -> h2, ...,
        h6 -> h7) so that the generated "# {title}" line is the only
        first-level heading.
        """
        options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
        new_content = content
        # Go from h6 down to h1 so a shifted tag is never shifted twice.
        for level in range(6, 0, -1):
            new_content = new_content.replace(f"h{level}>", f"h{level + 1}>")
        markdown_content = convert_html_to_md(new_content, "", **options)
        markdown_content = f"# {title}\n\n{markdown_content}"
        return markdown_content

    def extract_from_document(self, doc_info: dict) -> "List[Term]":
        """Extract the terms of a single document.

        Args:
            doc_info: Mapping with at least the 'content' (HTML) and 'title'
                keys.

        Returns:
            The parsed terms, or an empty list when anything fails (the
            failure is logged, never raised).
        """
        try:
            content = doc_info['content']
            title = doc_info["title"]

            markdown_content = self._convert_html_to_md(content, title)

            # Plain replace() is used because the prompt and the document may
            # contain braces that str.format() would choke on.
            formatted_prompt = extract_wiki_nouns_prompt.replace("{content}", markdown_content)
            formatted_prompt = formatted_prompt.replace("{output_format}", self.terms_list_parser.get_format_instructions())

            # Stays None until the LLM call returns, so the handler below can
            # tell a failed call from a failed parse.
            response = None
            try:
                response = self.llm.invoke(formatted_prompt)
                parsed_output = self.terms_list_parser.parse(response.content)
                return parsed_output.terms
            except Exception as e:
                if response is None:
                    logging.error(f"调用LLM时出错: {str(e)}")
                else:
                    logging.error(f"解析LLM响应时出错: {str(e)}")
                    logging.error(f"原始响应: {response.content}")
                return []
        except Exception as e:
            logging.error(f"提取专业名词时出错: {str(e)}")
            return []

    def _process_document(self, doc, path_terms):
        """Fetch one document, extract its terms and record them.

        Args:
            doc: Mapping with the document's 'id' and 'path'.
            path_terms: Dict mapping each path prefix to its shared list of
                term dicts; the matching list is extended in place.

        Returns:
            The matching path prefix, or None when the document is skipped or
            processing fails.
        """
        try:
            with self.semaphore:
                # First configured prefix the document path starts with.
                path_prefix = next(
                    (prefix for prefix in path_terms if doc['path'].startswith(prefix)),
                    None,
                )
                if not path_prefix:
                    return None

                doc_info = WikijsTool.query_doc_info(doc['id'])
                if not doc_info or not doc_info.get('content'):
                    return None

                terms = self.extract_from_document(doc_info)
                terms_dicts = [
                    {"name": term.name, "synonymous": term.synonymous, "description": term.description}
                    for term in terms
                ]

                with self.lock:
                    path_terms[path_prefix].extend(terms_dicts)
                    logging.info(f"文档 {doc['path']} 处理完成,提取了 {len(terms)} 个专业名词")

                    # Count documents (not terms) so the snapshot really
                    # happens once every 10 processed documents.
                    processed = self._doc_counts.get(path_prefix, 0) + 1
                    self._doc_counts[path_prefix] = processed
                    if processed % 10 == 0:
                        # Written under the lock so snapshots never interleave.
                        self._save_terms_to_file(path_terms[path_prefix], os.path.join(self.output_dir, f"{path_prefix.split('(')[0]}_nouns.json"))
                        logging.info(f"已处理 {path_prefix} 的文档数达到 {processed} 个,已保存中间结果")

                return path_prefix
        except Exception as e:
            logging.error(f"处理文档 {doc['path']} 时出错: {str(e)}")
            return None

    def process_all_documents(self, output_dir: str = "extracted_nouns", max_concurrency: int = 5):
        """Process every matching Wiki.js document with a thread pool.

        Args:
            output_dir: Directory receiving one "<prefix>_nouns.json" file per
                path prefix (created if missing).
            max_concurrency: Maximum number of documents processed at once.
        """
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)

        self.semaphore = Semaphore(max_concurrency)
        # Start counting from zero on every run.
        self._doc_counts = {}

        all_docs = WikijsTool.get_all_documents()

        # Path prefixes to process. Other prefixes listed in the original code:
        #   "技改检修计价通(2020)", "西藏造价软件(2023)",
        #   "新型储能电站建设计价通C1(2024)", "配网造价软件(2022)"
        path_prefixes = [
            "主网电力建设计价通(2018)",
        ]
        # One result list per prefix.
        path_terms = {prefix: [] for prefix in path_prefixes}

        filtered_docs = [
            doc for doc in all_docs
            if any(doc['path'].startswith(prefix) for prefix in path_prefixes)
        ]

        logging.info(f"开始使用线程池处理 {len(filtered_docs)} 个文档...")

        with concurrent.futures.ThreadPoolExecutor(max_workers=max_concurrency) as executor:
            futures = [
                executor.submit(self._process_document, doc, path_terms)
                for doc in filtered_docs
            ]

            for done, future in enumerate(concurrent.futures.as_completed(futures), start=1):
                try:
                    future.result()
                    # Report every 10 finished documents and once at the end.
                    if done % 10 == 0 or done == len(futures):
                        logging.info(f"已完成 {done}/{len(futures)} 个文档的处理")
                except Exception as e:
                    logging.error(f"处理文档时出错: {str(e)}")

        # Final results: one file per path prefix.
        for prefix, terms in path_terms.items():
            output_file = os.path.join(output_dir, f"{prefix.split('(')[0]}_nouns.json")
            self._save_terms_to_file(terms, output_file)
            logging.info(f"{prefix} 处理完成,共提取 {len(terms)} 个专业名词,已保存到 {output_file}")

    def _save_terms_to_file(self, terms, output_file):
        """Write the term list to ``output_file`` as UTF-8 JSON."""
        with open(output_file, 'w', encoding='utf-8') as f:
            json.dump(terms, f, ensure_ascii=False, indent=2)
|
||||
|
||||
|
||||
def main():
    """Extract terms from the Wiki.js documents using settings from the environment.

    Reads OPENAI_API_KEY, OPENAI_API_BASE and LLM_MODEL_NAME, then writes the
    results to "../../data/wiki_extracted_nouns" relative to this file.
    """
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_API_BASE")
    # Fall back to the extractor's own default instead of passing None when
    # LLM_MODEL_NAME is unset or empty. To use another model, export e.g.
    # LLM_MODEL_NAME="Qwen/Qwen2.5-72B-Instruct-128K".
    model_name = os.getenv("LLM_MODEL_NAME") or "gpt-3.5-turbo"

    extractor = WikijsNounsExtractor(api_key=api_key, base_url=base_url, model_name=model_name)
    current_dir = os.path.dirname(os.path.abspath(__file__))
    output_dir = os.path.join(current_dir, "..", "..", "data", "wiki_extracted_nouns")
    extractor.process_all_documents(output_dir=output_dir, max_concurrency=2)
|
||||
|
||||
if __name__ == "__main__":
    # Log to the console only (no file handler is attached), using a
    # timestamped format.
    log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    date_format = '%Y-%m-%d %H:%M:%S'

    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.INFO)
    console_handler.setFormatter(logging.Formatter(log_format, date_format))

    # Attach the handler to the root logger so every module's records show up.
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(console_handler)
    main()
|
||||
@@ -0,0 +1,189 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: intent_recognition_example.py
|
||||
Date: 2025-05-14
|
||||
Description: 意图识别和问题改写示例
|
||||
"""
|
||||
|
||||
import os
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import IntentRecognizer
|
||||
import pandas as pd
|
||||
import logging
|
||||
import json
|
||||
import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# 读取Excel文件中的提问数据
|
||||
def load_questions_from_excel(file_path=None):
    """Read the questions stored in the first column of an Excel file.

    Args:
        file_path: Path of the Excel file. When None, the sample file
            "../../data/excel/200条提问数据.xlsx" relative to this module is
            used.

    Returns:
        list: The values of the first column, or an empty list when the file
        cannot be read (the error is logged).
    """
    if file_path is None:
        # Honour the documented default instead of handing None to pandas.
        file_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            "..", "..", "data", "excel", "200条提问数据.xlsx",
        )

    try:
        df = pd.read_excel(file_path)
        questions = df.iloc[:, 0].tolist()  # first column only
        logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
        return questions
    except Exception as e:
        # Best effort: callers get an empty list rather than an exception.
        logging.error(f"读取Excel文件时出错: {e}")
        return []
|
||||
|
||||
def process_query(recognizer, query, max_retries=3, retry_delay=10):
    """Run one query through the recognizer, retrying on failure.

    Args:
        recognizer: Object whose ``process_query(query)`` returns the tuple
            ``(classification, keywords, rewrite, query_keys)``.
        query: The question text.
        max_retries: Number of retries after the first failed attempt.
        retry_delay: Base delay in seconds; the wait before retry ``k`` is
            ``retry_delay * k``.

    Returns:
        dict: One output row. Successful and failed rows share the same keys
        in the same order, so the resulting table layout does not depend on
        which query failed.
    """
    retry_count = 0

    while True:
        try:
            classification, keywords, rewrite, query_keys = recognizer.process_query(query)

            # Serialise the matched terms as a readable JSON string.
            keywords_str = ""
            if keywords and keywords.terms:
                term_details = [
                    {
                        "名称": term.name,
                        "同义词": ";".join(term.synonymous) if term.synonymous else "",
                        "描述": term.description
                    }
                    for term in keywords.terms
                ]
                # ensure_ascii=False keeps the Chinese text readable.
                keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)

            return {
                "提问": query,
                "问题拆解": query_keys,
                "一级分类": classification.vertical_classification,
                "二级分类": classification.sub_classification,
                "问题改写": rewrite.rewrite,
                "检索的关键词": keywords_str
            }

        except Exception as e:
            retry_count += 1

            if retry_count > max_retries:
                # Out of retries: log and return a placeholder row.
                logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
                return {
                    "提问": query,
                    "问题拆解": "",
                    "一级分类": "处理出错",
                    "二级分类": "处理出错",
                    "问题改写": "处理出错",
                    "检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}"
                }

            logging.warning(f"处理问题 '{query}' 失败,第 {retry_count}/{max_retries} 次重试: {e}")
            # Back off a little longer on every retry.
            time.sleep(retry_delay * retry_count)
|
||||
|
||||
examples_query = """下载软件在哪下载?"""
|
||||
|
||||
def main():
    """Run intent recognition and question rewriting over the sample questions.

    Loads the questions from "../../data/excel/200条提问数据.xlsx" (relative
    to this file), processes them concurrently and writes the results, in
    input order, to "200条提问数据_重写结果.xlsx" in the same directory.
    """
    # Configuration comes from the environment.
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_API_BASE")
    model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")

    recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据.xlsx")
    examples = load_questions_from_excel(data_file)
    # examples = examples_query.split("\n")
    if not examples:
        # The loader returns [] when reading fails; do not write an empty
        # result file over a previous run.
        logging.warning(f"未从 {data_file} 读取到任何提问,已跳过处理")
        return

    max_workers = 20
    logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")

    # Pre-sized so each result lands at the position of its input question.
    results = [None] * len(examples)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_query, recognizer, query): idx
            for idx, query in enumerate(examples)
        }
        for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
            results[future_to_index[future]] = future.result()

    # (column name, Excel column range, width) in output order. Pinning the
    # column order keeps it aligned with the widths even when a row lacks a key.
    column_layout = [
        ("提问", 'A:A', 60),
        ("问题拆解", 'B:B', 20),
        ("一级分类", 'C:C', 20),
        ("二级分类", 'D:D', 20),
        ("问题改写", 'E:E', 60),
        ("检索的关键词", 'F:F', 60),
    ]
    results_df = pd.DataFrame(results, columns=[name for name, _, _ in column_layout])

    output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据_重写结果.xlsx")

    with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
        results_df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # Widths are given in Excel column-width units.
        for _, column_range, width in column_layout:
            worksheet.set_column(column_range, width)

        # Same height for every row; +1 covers the header row.
        for i in range(len(results_df) + 1):
            worksheet.set_row(i, 20)

    logging.info(f"处理完成,结果已保存至: {output_file}")
|
||||
|
||||
def setup_logging():
    """Send INFO-level logs to the console and quieten the HTTP client libraries."""
    console = logging.StreamHandler()
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        handlers=[console],
    )
    # These libraries log every request at INFO; only keep their warnings.
    for noisy_logger in ('httpx', 'openai'):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
|
||||
|
||||
if __name__ == "__main__":
|
||||
setup_logging()
|
||||
logging.info("意图识别示例程序开始运行...")
|
||||
main()
|
||||
@@ -0,0 +1,293 @@
|
||||
"""
|
||||
答案正确性评判工具
|
||||
|
||||
此模块用于评判问题的新旧回答是否正确,通过与标准答案(Wiki内容)进行比较,
|
||||
或者在没有标准答案的情况下比较新旧回答的差异。
|
||||
|
||||
用法示例:
|
||||
judge = AnswerCorrectnessJudge()
|
||||
judge.process()
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
from urllib.parse import unquote
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
load_dotenv()
|
||||
|
||||
class AnswerCorrectnessJudge:
|
||||
"""
|
||||
答案正确性评判工具类
|
||||
|
||||
用于评估问题的新旧回答是否正确,可以通过与标准答案(Wiki内容)进行比较,
|
||||
或者在没有标准答案的情况下比较新旧回答的差异。
|
||||
"""
|
||||
|
||||
def __init__(self, wiki_excel_path="/data/Rag2_0/data/excel/部分提问_软件名称明确.xlsx",
|
||||
answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_对比结果.xlsx",
|
||||
output_path="/data/Rag2_0/data/excel/主网软件提问回答_判断结果.xlsx"):
|
||||
"""
|
||||
初始化答案正确性评判工具
|
||||
|
||||
参数:
|
||||
wiki_excel_path (str): Wiki Excel文件路径
|
||||
answer_excel_path (str): 答案对比Excel文件路径
|
||||
output_path (str): 输出Excel文件路径
|
||||
"""
|
||||
self.wiki_excel_path = wiki_excel_path
|
||||
self.answer_excel_path = answer_excel_path
|
||||
self.output_path = output_path
|
||||
|
||||
# 读取Excel文件
|
||||
self.wiki_excel = pd.read_excel(self.wiki_excel_path)
|
||||
self.answer_excel = pd.read_excel(self.answer_excel_path)
|
||||
|
||||
# 初始化LLM
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_API_BASE")
|
||||
self.model = os.getenv("LLM_MODEL_NAME")
|
||||
|
||||
if not all([self.api_key, self.base_url, self.model]):
|
||||
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
||||
|
||||
self.openai_llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model)
|
||||
|
||||
def find_wiki_link(self, query) -> str | None:
|
||||
"""
|
||||
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
|
||||
|
||||
参数:
|
||||
query (str): 查询内容,对应wiki_excel中的新提问列
|
||||
|
||||
返回:
|
||||
str: 对应的词条链接,如果没有找到则返回None
|
||||
"""
|
||||
# 确保query不为空
|
||||
if not query or pd.isna(query):
|
||||
return None
|
||||
|
||||
# 在"新提问"列中查找匹配的行
|
||||
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
|
||||
|
||||
# 如果找到了匹配的行,返回对应的词条链接
|
||||
if not matched_rows.empty:
|
||||
return matched_rows.iloc[0]['对应词条链接']
|
||||
|
||||
# 如果没有完全匹配,尝试部分匹配
|
||||
# 去除软件名称部分(如果有)
|
||||
query_parts = query.split(',', 1)
|
||||
if len(query_parts) > 1:
|
||||
clean_query = query_parts[1].strip()
|
||||
|
||||
# 在"提问"列中查找包含清理后查询的行
|
||||
for idx, row in self.wiki_excel.iterrows():
|
||||
if pd.notna(row['提问']) and clean_query in row['提问']:
|
||||
return row['对应词条链接']
|
||||
|
||||
return None
|
||||
|
||||
def get_wiki_content(self, link) -> str:
|
||||
"""
|
||||
获取词条链接的内容
|
||||
|
||||
参数:
|
||||
link (str): 词条链接
|
||||
|
||||
返回:
|
||||
str: 链接内容,如果获取失败则返回错误信息
|
||||
"""
|
||||
try:
|
||||
if not link or pd.isna(link):
|
||||
return "链接为空或无效"
|
||||
# 移除域名部分,只保留路径
|
||||
path = link.split('/', 3)[-1]
|
||||
decoded_path = unquote(path)
|
||||
path_parts = decoded_path.split('/')
|
||||
doc_path = "/".join(path_parts[1:])
|
||||
wiki_doc = WikijsTool.get_all_doc_by_path(path=doc_path, path_is_dir=False)
|
||||
html_content = WikijsTool.query_doc_info(wiki_doc[0]["id"]).get('content')
|
||||
if not html_content:
|
||||
return "获取内容失败"
|
||||
|
||||
options = {"heading_style": '', "keep_inline_images_in": ["figure", "img"], "escape_asterisks": True}
|
||||
new_content = (html_content.replace("h6>", "h7>")
|
||||
.replace("h5>", "h6>")
|
||||
.replace("h4>", "h5>")
|
||||
.replace("h3>", "h4>")
|
||||
.replace("h2>", "h3>")
|
||||
.replace("h1>", "h2>"))
|
||||
# 将HTML内容转换为Markdown
|
||||
markdown_content = convert_html_to_md(new_content, "", **options)
|
||||
markdown_content = f"# {path_parts[-1]}\n\n{markdown_content}"
|
||||
return markdown_content
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||||
|
||||
def create_prompt(self, standard_answer: str, answer_to_check: str) -> str:
|
||||
"""
|
||||
创建用于评判答案的prompt
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案
|
||||
answer_to_check (str): 需要检查的答案
|
||||
|
||||
返回:
|
||||
str: 格式化的prompt
|
||||
"""
|
||||
return f"""请作为一个专业的答案评判专家,评估以下回答与标准答案的匹配程度。
|
||||
|
||||
标准答案:
|
||||
{standard_answer}
|
||||
|
||||
待评估的回答:
|
||||
{answer_to_check}
|
||||
|
||||
请仔细分析两个答案的内容,并给出你的判断。只需要回答"正确"或"错误",不需要其他解释。
|
||||
如果待评估的回答与标准答案在核心内容和关键信息(步骤)上一致,即使表达方式不同,也应判定为"正确"。
|
||||
如果待评估的回答存在明显的错误信息或重要信息缺失,应判定为"错误"。
|
||||
|
||||
请严格按以下格式输出:【正确】或【错误】:"""
|
||||
|
||||
def judge_old_answer(self, standard_answer: str, old_answer: str) -> bool | None:
|
||||
"""
|
||||
调用LLM判断旧回答是否正确
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案(来自Wiki)
|
||||
old_answer (str): 旧流程的回答
|
||||
|
||||
返回:
|
||||
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||||
"""
|
||||
prompt = self.create_prompt(standard_answer, old_answer)
|
||||
try:
|
||||
response = self.openai_llm.invoke(prompt)
|
||||
return "正确" in response.content
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def judge_new_answer(self, standard_answer: str, new_answer: str) -> bool | None:
|
||||
"""
|
||||
调用LLM判断新回答是否正确
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案(来自Wiki)
|
||||
new_answer (str): 新流程的回答
|
||||
|
||||
返回:
|
||||
bool | None: 判断结果,True表示正确,False表示错误,None表示判断失败
|
||||
"""
|
||||
prompt = self.create_prompt(standard_answer, new_answer)
|
||||
try:
|
||||
response = self.openai_llm.invoke(prompt)
|
||||
return "正确" in response.content
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def judge_by_standard_answer(self, standard_answer: str, old_answer: str, new_answer: str) -> str | None:
|
||||
"""
|
||||
综合判断新旧回答的正确性
|
||||
|
||||
参数:
|
||||
standard_answer (str): 标准答案(来自Wiki)
|
||||
old_answer (str): 旧流程的回答
|
||||
new_answer (str): 新流程的回答
|
||||
|
||||
返回:
|
||||
str | None: 包含新旧回答判断结果的字符串,None表示判断失败
|
||||
"""
|
||||
old_result = self.judge_old_answer(standard_answer, old_answer)
|
||||
new_result = self.judge_new_answer(standard_answer, new_answer)
|
||||
if old_result is None or new_result is None:
|
||||
return None
|
||||
if new_result and old_result:
|
||||
return "新旧答案均正确"
|
||||
elif new_result and not old_result:
|
||||
return "新答案正确"
|
||||
elif not new_result and old_result:
|
||||
return "旧答案正确"
|
||||
else:
|
||||
return "新旧答案均错误"
|
||||
|
||||
def judge_answer_diff(self, old_answer: str, new_answer: str) -> str | None:
|
||||
"""
|
||||
判断新旧回答是否存在较大差异
|
||||
|
||||
参数:
|
||||
old_answer (str): 旧流程的回答
|
||||
new_answer (str): 新流程的回答
|
||||
|
||||
返回:
|
||||
str | None: 差异判断结果,None表示判断失败
|
||||
"""
|
||||
|
||||
prompt = f"""请判断以下两个回答是否存在较大差异:
|
||||
|
||||
旧回答: {old_answer}
|
||||
|
||||
新回答: {new_answer}
|
||||
|
||||
主要是关键步骤、关键信息、或者关键主体的差异
|
||||
请仅回答"存在较大差异"或"差异较小"。"""
|
||||
|
||||
try:
|
||||
response = self.openai_llm.invoke(prompt)
|
||||
return "无法判断,新老答案差异较大" if "存在较大差异" in response.content else "无法判断,新老答案基本相同"
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
处理所有问题并评判答案正确性
|
||||
|
||||
读取Excel文件中的问题和答案,进行评判,并将结果保存到输出Excel文件
|
||||
"""
|
||||
# 创建结果列表
|
||||
results = []
|
||||
|
||||
# 读取Excel文件
|
||||
for idx, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题"):
|
||||
query = row["问题"]
|
||||
old_answer = row["旧流程答案"]
|
||||
new_answer = row["新流程答案"]
|
||||
standard_answer = ""
|
||||
|
||||
try:
|
||||
wiki_url = self.find_wiki_link(query)
|
||||
if wiki_url and not pd.isna(wiki_url):
|
||||
standard_answer = self.get_wiki_content(wiki_url)
|
||||
except Exception as e:
|
||||
print(f"处理问题 '{query}' 时发生错误: {str(e)}")
|
||||
|
||||
if standard_answer:
|
||||
# 判断答案正确性
|
||||
judge_result = self.judge_by_standard_answer(standard_answer, old_answer, new_answer)
|
||||
else:
|
||||
judge_result = self.judge_answer_diff(old_answer, new_answer)
|
||||
|
||||
if judge_result is None:
|
||||
judge_result = ""
|
||||
|
||||
results.append({
|
||||
"问题": query,
|
||||
"旧流程答案": old_answer,
|
||||
"新流程答案": new_answer,
|
||||
"判断结果": judge_result
|
||||
})
|
||||
|
||||
# 将结果转换为DataFrame并保存
|
||||
results_df = pd.DataFrame(results)
|
||||
results_df.to_excel(self.output_path, index=False)
|
||||
print(f"处理完成,共处理 {len(results)} 条记录,结果已保存至 {self.output_path}")
|
||||
|
||||
# 测试函数
|
||||
if __name__ == "__main__":
    # Run the whole judging pipeline with the default Excel paths.
    AnswerCorrectnessJudge().process()
|
||||
@@ -0,0 +1,615 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
完整性问题判断工具
|
||||
|
||||
此脚本用于读取Excel文件中的问题,调用LLM判断问题是否完整,并将结果保存到Excel文件中。
|
||||
|
||||
用法示例:
|
||||
python judge_query_full.py -i "问题数据.xlsx" -o "完整问题结果.xlsx" -w 50 -c 0
|
||||
|
||||
命令行参数:
|
||||
-i, --input: 输入Excel文件路径
|
||||
-o, --output: 输出Excel文件路径
|
||||
-w, --workers: 并发处理的最大线程数
|
||||
-c, --column: 要处理的问题所在列的索引(从0开始)
|
||||
-t, --test: 测试单个问题,不处理Excel文件
|
||||
"""
|
||||
|
||||
import pandas as pd
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
import re
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from rag2_0.tool.APIKeyManager import APIKeyManager
|
||||
from openpyxl.utils import get_column_letter
|
||||
from openpyxl.styles import Alignment, PatternFill, Font, Border, Side
|
||||
from tqdm import tqdm
|
||||
import concurrent.futures
|
||||
import threading
|
||||
|
||||
# Default settings
DEFAULT_EXCEL_PATH = r"/data/Rag2_0/data/excel/7000条对话数据.xlsx"
DEFAULT_OUTPUT_PATH = r"/data/Rag2_0/data/excel/7000条对话数据_完整问题结果.xlsx"
DEFAULT_MAX_WORKERS = 50


class QueryCompletenessJudge:
    """Judge whether questions are complete by asking an LLM.

    The tool can process one column of an Excel file in bulk (writing the
    questions judged complete to a styled output workbook) or evaluate a
    single question and print the verdict.
    """

    # Column widths applied to the output sheet, keyed by header text.
    _COLUMN_WIDTHS = {"问题": 40, "LLM回复": 60, "原因": 30, "完整性": 10, "置信度": 10}
    _DEFAULT_COLUMN_WIDTH = 15

    def __init__(self, input_path=DEFAULT_EXCEL_PATH, output_path=DEFAULT_OUTPUT_PATH,
                 max_workers=DEFAULT_MAX_WORKERS, column_index=0):
        """Store the run configuration and create the LLM client.

        Args:
            input_path (str): Path of the Excel file to read.
            output_path (str): Path of the Excel file to write.
            max_workers (int): Maximum number of worker threads.
            column_index (int): Zero-based index of the question column.
        """
        self.input_path = input_path
        self.output_path = output_path
        self.max_workers = max_workers
        self.column_index = column_index
        self.llm_client = self._create_llm_client()

    @staticmethod
    def _as_bool(value):
        """Coerce a JSON flag to bool, treating strings such as "false" as False."""
        if isinstance(value, str):
            return value.strip().lower() in ("true", "yes", "1")
        return bool(value)

    def _extract_json_from_response(self, full_answer):
        """Extract the JSON verdict from an LLM response.

        Args:
            full_answer (str): Complete response text of the LLM.

        Returns:
            dict | None: The parsed JSON object, or None when no JSON object
            could be found or parsed.
        """
        if not isinstance(full_answer, str):
            return None

        # Prefer a fenced ```json block; otherwise fall back to any brace
        # span that mentions "is_complete".
        fenced = re.search(r'```json\s*(.*?)\s*```', full_answer, re.DOTALL)
        if fenced:
            json_str = fenced.group(1)
        else:
            bare = re.search(r'({[\s\S]*"is_complete"[\s\S]*})', full_answer)
            if not bare:
                return None
            json_str = bare.group(1)

        try:
            parsed = json.loads(json_str)
        except json.JSONDecodeError:
            return None
        # Callers use .get(), so anything other than a JSON object is unusable.
        return parsed if isinstance(parsed, dict) else None

    def _create_llm_prompt(self, question):
        """Build the prompt that asks the LLM to judge *question*.

        Args:
            question (str): The question whose completeness is judged.

        Returns:
            str: The formatted prompt.
        """
        return f"""你是一个电力造价行业专家,用户正在使用电力造价软件,并提出了相关问题。请分析以下问题是否完整。

问题:{question}

首先,分析这个问题的结构和内容,思考它是否包含足够的信息来表达清晰的意图。
考虑以下几点:
1. 问题是否有明确的核心意图,不需要面面俱到
2. 问题是否缺少必要的上下文
3. **问题如果涉及软件相关,则只需要包含:软件名称、软件功能或软件目的即可**


在你的分析之后,请用JSON格式给出最终结论,格式如下:
```json
{{
"is_complete": true或false,
"reason": "判断原因的简要说明",
"confidence": 0到100之间的数值,表示你对判断的置信度
}}
```

请确保JSON格式正确,以便于程序解析。"""

    def _create_llm_client(self, api_key=None):
        """Create the LLM client.

        Args:
            api_key (str, optional): API key; when None it is obtained from
                APIKeyManager.

        Returns:
            OpenAiLLM: The configured client instance.
        """
        if api_key is None:
            api_key = APIKeyManager.get_api_key()

        # NOTE(review): the prompt asks for an analysis followed by a JSON
        # block, which may not fit into max_tokens=100 -- confirm the limit.
        return OpenAiLLM(
            api_key=api_key,
            base_url="https://api.siliconflow.cn/v1",  # adjust to the deployment
            model="deepseek-ai/DeepSeek-V3",  # adjust to the deployment
            temperature=0.2,
            max_tokens=100
        )

    def is_question_complete(self, question):
        """Ask the LLM whether *question* is complete.

        The call is attempted once and retried up to three times, sleeping
        2, 4 and 8 seconds between attempts.

        Args:
            question (str): The question to judge.

        Returns:
            tuple: (bool, str) -- the verdict and the full LLM reply (or an
            error message when every attempt failed).
        """
        max_retries = 3
        retry_delay = 2  # seconds; doubled after every failed attempt

        for attempt in range(max_retries + 1):
            try:
                prompt = self._create_llm_prompt(question)
                response = self.llm_client.invoke(prompt)

                # The client may return an object with .content or a plain value.
                if hasattr(response, 'content'):
                    full_answer = response.content
                else:
                    full_answer = str(response)

                result = self._extract_json_from_response(full_answer)
                if result:
                    return self._as_bool(result.get("is_complete", False)), full_answer

                # No usable JSON: fall back to a keyword heuristic. "不完整"
                # contains "完整", so the negative form must be excluded.
                head = full_answer[:100]
                return ("完整" in head and "不完整" not in head), full_answer

            except Exception as e:
                if attempt < max_retries:
                    time.sleep(retry_delay)
                    retry_delay *= 2
                else:
                    message = f"错误: 经过 {max_retries} 次重试后仍然失败: {str(e)}"
                    print(message)
                    return False, message

        # Unreachable; kept as a defensive fallback.
        return False, "未知错误:重试机制逻辑错误"

    def _process_question(self, args, complete_questions, progress_counter, progress_lock, complete_questions_lock, pbar):
        """Judge one question and update the shared progress state.

        Args:
            args (tuple): (index, question, llm_client, total_questions);
                only the question is used here.
            complete_questions (list): Shared list collecting result rows of
                complete questions.
            progress_counter (dict): Shared counters with the keys
                "processed", "complete" and "incomplete".
            progress_lock (threading.Lock): Guards *progress_counter*.
            complete_questions_lock (threading.Lock): Guards
                *complete_questions*.
            pbar (tqdm): Progress bar to advance.
        """
        _index, question, _llm_client, _total_questions = args

        # Skip empty cells (NaN/None) and blank text without calling the LLM.
        if pd.isna(question) or str(question).strip() == "":
            with progress_lock:
                progress_counter["processed"] += 1
                pbar.update(1)
            return None
        question = str(question)

        is_complete, full_answer = self.is_question_complete(question)

        if is_complete:
            parsed_json = self._extract_json_from_response(full_answer)

            if parsed_json:
                verdict = self._as_bool(parsed_json.get("is_complete", False))
                result = {
                    "问题": question,
                    "LLM回复": full_answer,
                    "完整性": "完整" if verdict else "不完整",
                    "原因": parsed_json.get("reason", "未提供"),
                    "置信度": parsed_json.get("confidence", 0)
                }
                with progress_lock:
                    if result["完整性"] == "完整":
                        progress_counter["complete"] += 1
                    else:
                        progress_counter["incomplete"] += 1
            else:
                # The verdict came from the keyword heuristic; keep the raw reply only.
                result = {
                    "问题": question,
                    "LLM回复": full_answer,
                    "完整性": "完整"
                }
                with progress_lock:
                    progress_counter["complete"] += 1

            with complete_questions_lock:
                complete_questions.append(result)
        else:
            with progress_lock:
                progress_counter["incomplete"] += 1

        # Advance the progress bar and refresh its summary.
        with progress_lock:
            progress_counter["processed"] += 1
            pbar.set_postfix(
                完整=progress_counter["complete"],
                不完整=progress_counter["incomplete"],
                完整率=f"{progress_counter['complete']/max(1, progress_counter['processed']):.1%}"
            )
            pbar.update(1)

    def _shorten_response(self, response):
        """Shorten an LLM reply so that the Excel cell stays small.

        Args:
            response (str): Original LLM reply.

        Returns:
            str: With a fenced JSON block: the first 200 characters plus the
            block; otherwise the first 500 characters.
        """
        fenced = re.search(r'```json\s*(.*?)\s*```', response, re.DOTALL)
        if fenced:
            json_part = fenced.group(0)
            prefix = response[:200] + "..." if len(response) > 200 else response
            return f"{prefix}\n\n{json_part}"
        return response[:500] + "..." if len(response) > 500 else response

    def _prepare_excel_dataframe(self, complete_questions):
        """Turn the result rows into a DataFrame ready for Excel output.

        Args:
            complete_questions (list): Result rows (dicts).

        Returns:
            pandas.DataFrame: Rows with shortened replies and the important
            columns placed first.
        """
        result_df = pd.DataFrame(complete_questions)

        if "LLM回复" in result_df.columns:
            result_df["LLM回复"] = result_df["LLM回复"].apply(self._shorten_response)

        # Preferred columns first (when present), then whatever else exists.
        preferred = ["问题", "完整性", "置信度", "原因", "LLM回复"]
        column_order = [col for col in preferred if col in result_df.columns]
        column_order += [col for col in result_df.columns if col not in column_order]
        return result_df[column_order]

    def _set_excel_column_widths(self, worksheet):
        """Set the column widths of *worksheet* based on the header texts.

        Args:
            worksheet (openpyxl.worksheet.worksheet.Worksheet): Target sheet.
        """
        for col in range(1, worksheet.max_column + 1):
            col_letter = get_column_letter(col)
            column_name = worksheet[f"{col_letter}1"].value
            worksheet.column_dimensions[col_letter].width = self._COLUMN_WIDTHS.get(
                column_name, self._DEFAULT_COLUMN_WIDTH
            )

    def _apply_excel_cell_styles(self, worksheet):
        """Apply borders, wrapping, header and verdict colours to every cell.

        Args:
            worksheet (openpyxl.worksheet.worksheet.Worksheet): Target sheet.

        Returns:
            openpyxl.styles.Border: The border style, reused for the
            statistics rows.
        """
        header_fill = PatternFill(start_color="DDEBF7", end_color="DDEBF7", fill_type="solid")
        complete_fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")
        incomplete_fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")
        header_font = Font(bold=True)
        wrap_alignment = Alignment(wrap_text=True, vertical="top")
        border = Border(
            left=Side(style='thin'),
            right=Side(style='thin'),
            top=Side(style='thin'),
            bottom=Side(style='thin')
        )

        for row in worksheet.iter_rows(min_row=1, max_row=worksheet.max_row, min_col=1, max_col=worksheet.max_column):
            for cell in row:
                cell.alignment = wrap_alignment
                cell.border = border

                if cell.row == 1:
                    cell.fill = header_fill
                    cell.font = header_font
                    continue

                # Colour the verdict column: green for complete, red otherwise.
                column_name = worksheet.cell(row=1, column=cell.column).value
                if column_name == "完整性":
                    cell.fill = complete_fill if cell.value == "完整" else incomplete_fill

        return border

    def _add_statistics_to_excel(self, worksheet, complete_questions, total_rows, total_questions, border):
        """Append a statistics section below the data rows.

        Args:
            worksheet (openpyxl.worksheet.worksheet.Worksheet): Target sheet.
            complete_questions (list): Result rows written to the sheet.
            total_rows (int): Number of data rows written to the sheet.
            total_questions (int): Number of questions that were judged.
            border (openpyxl.styles.Border): Border style for the section.

        Returns:
            int: Number of complete questions.
        """
        complete_count = sum(1 for item in complete_questions if item.get("完整性") == "完整")
        # The sheet holds only the complete questions, so the totals must be
        # based on the number of judged questions; fall back to the row
        # count when that number is not available.
        overall = total_questions if total_questions else total_rows
        incomplete_count = max(overall - complete_count, 0)

        worksheet.append([""])  # blank separator row

        stat_row = worksheet.max_row + 1
        worksheet.cell(row=stat_row, column=1, value="统计信息")
        worksheet.cell(row=stat_row, column=1).font = Font(bold=True)

        worksheet.cell(row=stat_row+1, column=1, value="总问题数")
        worksheet.cell(row=stat_row+1, column=2, value=overall)

        worksheet.cell(row=stat_row+2, column=1, value="完整问题数")
        worksheet.cell(row=stat_row+2, column=2, value=complete_count)
        worksheet.cell(row=stat_row+2, column=2).fill = PatternFill(start_color="C6EFCE", end_color="C6EFCE", fill_type="solid")

        worksheet.cell(row=stat_row+3, column=1, value="不完整问题数")
        worksheet.cell(row=stat_row+3, column=2, value=incomplete_count)
        worksheet.cell(row=stat_row+3, column=2).fill = PatternFill(start_color="FFC7CE", end_color="FFC7CE", fill_type="solid")

        worksheet.cell(row=stat_row+4, column=1, value="完整问题比例")
        worksheet.cell(row=stat_row+4, column=2, value=f"{complete_count/overall:.2%}" if overall > 0 else "0%")

        for r in range(stat_row, stat_row+5):
            for c in range(1, 3):
                worksheet.cell(row=r, column=c).border = border

        return complete_count

    def save_results_to_excel(self, complete_questions, total_questions):
        """Write the complete questions to the styled output workbook.

        Args:
            complete_questions (list): Result rows of complete questions.
            total_questions (int): Number of questions that were judged.
        """
        if not complete_questions:
            print("没有找到完整的问题。")
            return

        result_df = self._prepare_excel_dataframe(complete_questions)
        total_rows = len(result_df)

        result_df.to_excel(self.output_path, index=False, engine='openpyxl')

        # Re-open the file with openpyxl to apply the styling.
        from openpyxl import load_workbook
        wb = load_workbook(self.output_path)
        ws = wb.active

        self._set_excel_column_widths(ws)
        border = self._apply_excel_cell_styles(ws)
        complete_count = self._add_statistics_to_excel(ws, complete_questions, total_rows, total_questions, border)

        wb.save(self.output_path)

        print(f"处理完成。共有{complete_count}/{total_questions}个完整问题被保存到 {self.output_path}")
        print(f"完整问题比例: {complete_count/total_questions:.2%}" if total_questions > 0 else "完整问题比例: 0%")

    def process_excel_file(self):
        """Judge every question of the input Excel file and save the results.

        Reads the configured column, judges the questions concurrently and
        writes the complete ones to the output workbook.
        """
        if not os.path.exists(self.input_path):
            print(f"错误: 找不到Excel文件 '{self.input_path}'")
            return

        print(f"正在读取Excel文件: {self.input_path}")
        try:
            df = pd.read_excel(self.input_path)
        except Exception as e:
            print(f"读取Excel文件时出错: {e}")
            return

        if len(df.columns) <= self.column_index:
            print(f"错误: Excel文件没有足够的列,请求索引 {self.column_index},但只有 {len(df.columns)} 列")
            return

        target_col = df.columns[self.column_index]
        print(f"目标列名称: {target_col}")

        complete_questions = []
        total_questions = len(df)
        print(f"总共有{total_questions}个问题需要判断")

        # Shared state for the worker threads.
        complete_questions_lock = threading.Lock()
        progress_counter = {"processed": 0, "complete": 0, "incomplete": 0}
        progress_lock = threading.Lock()

        # Pass the raw cell values: converting to str here would turn empty
        # cells into the text "nan" and defeat the empty-question check.
        questions = [(i, value, self.llm_client, total_questions)
                     for i, value in df[target_col].items()]

        start_time = time.time()

        print(f"开始处理问题,使用 {self.max_workers} 个并发线程...")
        with tqdm(total=total_questions, desc="处理问题", unit="问题") as pbar:
            with concurrent.futures.ThreadPoolExecutor(max_workers=self.max_workers) as executor:
                futures = [executor.submit(
                    self._process_question,
                    args,
                    complete_questions,
                    progress_counter,
                    progress_lock,
                    complete_questions_lock,
                    pbar
                ) for args in questions]

                concurrent.futures.wait(futures)

                # Surface worker failures instead of dropping them silently.
                for future in futures:
                    error = future.exception()
                    if error is not None:
                        print(f"处理问题时出错: {error}")

        processing_time = time.time() - start_time
        # Guard against an empty sheet (no rows -> no average).
        average_time = processing_time / total_questions if total_questions else 0.0
        print(f"处理完成,耗时: {processing_time:.2f}秒,平均每问题: {average_time:.2f}秒")

        self.save_results_to_excel(complete_questions, total_questions)

    def test_single_question(self, question):
        """Judge a single question and print the verdict.

        Args:
            question (str): The question to test.
        """
        print(f"问题: {question}")
        print("正在调用LLM判断问题是否完整...")

        is_complete, full_answer = self.is_question_complete(question)
        parsed_json = self._extract_json_from_response(full_answer)

        print("\n==== LLM回复 ====")
        print(full_answer)
        print("================\n")

        if parsed_json:
            print(f"判断结果: {'完整' if parsed_json.get('is_complete', False) else '不完整'}")
            print(f"判断原因: {parsed_json.get('reason', '未提供')}")
            print(f"置信度: {parsed_json.get('confidence', 0)}%")
        else:
            print(f"判断结果: {'完整' if is_complete else '不完整'} (简单判断)")
            print("无法从回复中提取JSON结构化数据")
|
||||
|
||||
|
||||
def parse_arguments():
    """Parse the command-line options of the completeness-judging script.

    Returns:
        argparse.Namespace: Parsed values for input, output, workers,
        column and test.
    """
    cli = argparse.ArgumentParser(description='判断Excel文件中的问题是否完整')

    # (flags, keyword arguments) for every supported option, in help order.
    option_specs = [
        (('-i', '--input'), dict(type=str, default=DEFAULT_EXCEL_PATH,
                                 help=f'输入Excel文件路径 (默认: {DEFAULT_EXCEL_PATH})')),
        (('-o', '--output'), dict(type=str, default=DEFAULT_OUTPUT_PATH,
                                  help=f'输出Excel文件路径 (默认: {DEFAULT_OUTPUT_PATH})')),
        (('-w', '--workers'), dict(type=int, default=DEFAULT_MAX_WORKERS,
                                   help=f'并发处理的最大线程数 (默认: {DEFAULT_MAX_WORKERS})')),
        (('-c', '--column'), dict(type=int, default=0,
                                  help='要处理的问题所在列的索引 (默认: 0,即第一列)')),
        (('-t', '--test'), dict(type=str,
                                help='测试单个问题,不处理Excel文件')),
    ]
    for flags, settings in option_specs:
        cli.add_argument(*flags, **settings)

    return cli.parse_args()
|
||||
|
||||
|
||||
def main():
    """Run the tool: judge one question (``--test``) or a whole Excel file."""
    options = parse_arguments()

    checker = QueryCompletenessJudge(
        input_path=options.input,
        output_path=options.output,
        max_workers=options.workers,
        column_index=options.column
    )

    # Without a test question the Excel file is processed in bulk.
    if not options.test:
        checker.process_excel_file()
        return

    checker.test_single_question(options.test)
|
||||
|
||||
|
||||
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()
|
||||
|
||||
|
||||
@@ -0,0 +1,239 @@
|
||||
import pandas as pd
|
||||
from urllib.parse import unquote
|
||||
from rag2_0.tool.WikijsTool import WikijsTool
|
||||
from rag2_0.tool.html_to_md import convert_html_to_md
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
from rag2_0.dify.dify_tool import DifyTool
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
load_dotenv()
|
||||
|
||||
class ContentSource(BaseModel):
    """Structured LLM verdict on how relevant a retrieved passage is."""

    # Relevance score returned by the LLM.
    score: int = Field(description="相关性分数")
    # Short justification for the score.
    reason: str = Field(description="评分理由")
|
||||
|
||||
class RetrieveContentScoreJudge:
|
||||
"""
|
||||
检索内容相关性评分工具类
|
||||
|
||||
用于评估检索内容与问题之间的相关性,并计算相关性分数
|
||||
"""
|
||||
|
||||
def __init__(self, wiki_excel_path, answer_excel_path, output_path=None):
|
||||
"""
|
||||
初始化评分工具类
|
||||
|
||||
参数:
|
||||
wiki_excel_path (str): Wiki Excel文件路径
|
||||
answer_excel_path (str): 回答Excel文件路径
|
||||
output_path (str, optional): 输出Excel文件路径,默认为None
|
||||
"""
|
||||
self.content_source_parser = PydanticOutputParser(pydantic_object=ContentSource)
|
||||
if os.path.exists(wiki_excel_path):
|
||||
self.wiki_excel = pd.read_excel(wiki_excel_path)
|
||||
else:
|
||||
self.wiki_excel = None
|
||||
self.answer_excel = pd.read_excel(answer_excel_path)
|
||||
self.output_path = output_path or "/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"
|
||||
|
||||
# 从环境变量中获取OpenAI的配置
|
||||
self.api_key = os.getenv("OPENAI_API_KEY")
|
||||
self.base_url = os.getenv("OPENAI_API_BASE")
|
||||
self.model_name = os.getenv("LLM_MODEL_NAME")
|
||||
|
||||
if not all([self.api_key, self.base_url, self.model_name]):
|
||||
raise ValueError("请设置 OPENAI_API_KEY, OPENAI_API_BASE, 和 LLM_MODEL_NAME 环境变量")
|
||||
|
||||
self.llm = OpenAiLLM(api_key=self.api_key, base_url=self.base_url, model=self.model_name)
|
||||
|
||||
def find_wiki_link(self, query) -> str | None:
|
||||
"""
|
||||
根据查询(对应wiki_excel中的新提问列)找出对应的词条链接
|
||||
|
||||
参数:
|
||||
query (str): 查询内容,对应wiki_excel中的新提问列
|
||||
|
||||
返回:
|
||||
str: 对应的词条链接,如果没有找到则返回None
|
||||
"""
|
||||
# 确保query不为空
|
||||
if not query or pd.isna(query):
|
||||
return None
|
||||
if self.wiki_excel is None:
|
||||
return None
|
||||
# 在"新提问"列中查找匹配的行
|
||||
matched_rows = self.wiki_excel[self.wiki_excel['新提问'] == query]
|
||||
|
||||
# 如果找到了匹配的行,返回对应的词条链接
|
||||
if not matched_rows.empty:
|
||||
return matched_rows.iloc[0]['对应词条链接']
|
||||
|
||||
# 如果没有完全匹配,尝试部分匹配
|
||||
# 去除软件名称部分(如果有)
|
||||
query_parts = query.split(',', 1)
|
||||
if len(query_parts) > 1:
|
||||
clean_query = query_parts[1].strip()
|
||||
|
||||
# 在"提问"列中查找包含清理后查询的行
|
||||
for idx, row in self.wiki_excel.iterrows():
|
||||
if pd.notna(row['提问']) and clean_query in row['提问']:
|
||||
return row['对应词条链接']
|
||||
|
||||
return None
|
||||
|
||||
def get_wiki_title(self, link) -> str | None:
|
||||
"""
|
||||
获取词条标题
|
||||
|
||||
参数:
|
||||
link (str): 词条链接
|
||||
|
||||
返回:
|
||||
str: 词条标题,如果获取失败则返回None
|
||||
"""
|
||||
try:
|
||||
if not link or pd.isna(link):
|
||||
return None
|
||||
# 移除域名部分,只保留路径
|
||||
path = link.split('/', 3)[-1]
|
||||
decoded_path = unquote(path)
|
||||
path_parts = decoded_path.split('/')
|
||||
return path_parts[-1]
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"获取词条内容失败: {str(e)}") from e
|
||||
|
||||
def calculate_score(self, answer:str, content:str) -> int:
|
||||
"""
|
||||
使用OpenAiLLM通过LLM判断answer与content之间的相关性分数
|
||||
|
||||
参数:
|
||||
answer (str): 用户问题
|
||||
content (str): 检索内容
|
||||
|
||||
返回:
|
||||
int: 相关性分数,1-10分,10代表完全相关,1代表完全不相关;-1表示评分失败
|
||||
"""
|
||||
try:
|
||||
prompt = f"""你是一个专业的信息相关性评估助手。请根据以下标准对用户query和检索内容的相关性进行1-10评分(10=完全相关,1=完全不相关),并按指定格式输出JSON结果。
|
||||
|
||||
【评分标准】
|
||||
10分:完全契合,主题/意图完全一致且涵盖所有关键信息
|
||||
8-9分:高度相关,核心要素匹配但存在少量信息缺失
|
||||
6-7分:部分相关,涉及相同主题但存在重要信息缺失
|
||||
4-5分:弱相关,仅次要信息点匹配
|
||||
1-3分:完全不相关或信息冲突
|
||||
|
||||
【评估维度】
|
||||
1. 主题一致性:核心主题/意图的匹配程度
|
||||
2. 内容覆盖度:是否涵盖query的关键要素
|
||||
3. 信息准确性:是否存在矛盾/错误信息
|
||||
4. 细节丰富度:是否提供query要求的详细信息
|
||||
|
||||
【输出格式】
|
||||
{{
|
||||
"score": 评分,
|
||||
"reason": "简明扼要的评分理由(中文)"
|
||||
}}
|
||||
|
||||
【示例】
|
||||
query: "新冠疫苗的常见副作用"
|
||||
内容: "辉瑞疫苗常见反应包括注射部位疼痛(84.1%)、疲劳(62.9%)"
|
||||
输出: {{"score":8,"reason":"主题完全匹配,涵盖主要副作用但未提及发热等常见反应"}}
|
||||
|
||||
现在评估:
|
||||
query: "{answer}"
|
||||
content: "{content}"
|
||||
"""
|
||||
|
||||
response = self.llm.invoke(user_prompt=prompt, need_retry=True)
|
||||
|
||||
# 解析JSON响应
|
||||
try:
|
||||
parsed_output = self.content_source_parser.parse(response.content)
|
||||
return parsed_output.score
|
||||
except Exception as e:
|
||||
return -1
|
||||
except Exception as e:
|
||||
return -1
|
||||
|
||||
def get_retrieve_info(self, query:str, outputs:dict) -> tuple:
|
||||
"""
|
||||
获取检索信息并计算分数
|
||||
|
||||
参数:
|
||||
query (str): 用户问题
|
||||
outputs (dict): 检索输出结果
|
||||
|
||||
返回:
|
||||
tuple: (检索内容列表, 最高分, 最低分, 平均分)
|
||||
"""
|
||||
max_score = 0
|
||||
min_score = 10
|
||||
total_score = 0
|
||||
valid_scores = 0
|
||||
retrieve_content = []
|
||||
for result in outputs["result"]:
|
||||
content = result["content"].strip()
|
||||
score = self.calculate_score(answer=query, content=content)
|
||||
if score != -1:
|
||||
max_score = max(max_score, score)
|
||||
min_score = min(min_score, score)
|
||||
total_score += score
|
||||
valid_scores += 1
|
||||
content_title = content.split("\n")[0]
|
||||
if content_title:
|
||||
retrieve_content.append(content_title + f"--得分({score}分)")
|
||||
avg_score = total_score / valid_scores if valid_scores > 0 else 0
|
||||
return retrieve_content, max_score, min_score, avg_score
|
||||
|
||||
def process(self):
|
||||
"""
|
||||
处理所有问题并评估检索内容相关性
|
||||
|
||||
遍历answer_excel中的所有问题,计算检索内容与问题的相关性分数,
|
||||
并更新Excel文件
|
||||
"""
|
||||
for idx, row in tqdm(self.answer_excel.iterrows(), total=len(self.answer_excel), desc="处理问题评分中"):
|
||||
query = row["问题"]
|
||||
link = self.find_wiki_link(query)
|
||||
answer_title = self.get_wiki_title(link)
|
||||
retrieve_content = []
|
||||
max_score = 0
|
||||
min_score = 0
|
||||
avg_score = 0 # 初始化平均分
|
||||
rewrite_query=""
|
||||
message_info = DifyTool.get_message_debug_info(appid="ccf92b97-2789-4a3f-90e0-135a869a37c5", query=query)
|
||||
for workflow_node in message_info["workflow_node_executions_info"]:
|
||||
if workflow_node["title"] == "知识检索结果后处理":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
retrieve_content, max_score, min_score, avg_score = self.get_retrieve_info(query=query, outputs=outputs)
|
||||
elif workflow_node["title"] == "问题优化结果解析":
|
||||
outputs = json.loads(workflow_node["outputs"])
|
||||
rewrite_query = outputs["optimize_query"]
|
||||
|
||||
# 更新 answer_excel 中的词条内容
|
||||
self.answer_excel.at[idx, "答案词条"] = answer_title if answer_title else ""
|
||||
self.answer_excel.at[idx, "问题改写"] = rewrite_query
|
||||
self.answer_excel.at[idx, "检索得到词条"] = "\n".join(retrieve_content) if retrieve_content else "未检索知识库"
|
||||
self.answer_excel.at[idx, "最大得分"] = max_score
|
||||
self.answer_excel.at[idx, "最小得分"] = min_score
|
||||
self.answer_excel.at[idx, "平均得分"] = avg_score
|
||||
|
||||
# 保存结果到Excel文件
|
||||
self.answer_excel.to_excel(self.output_path, index=False)
|
||||
print(f"结果已保存到 {self.output_path}")
|
||||
|
||||
# Script entry point: score the retrieval results of the main-network
# software questions using the annotated wiki workbook.
if __name__ == "__main__":
    RetrieveContentScoreJudge(
        wiki_excel_path="/data/Rag2_0/data/excel/400条人工标注-部分提问_软件名称明确.xlsx",
        answer_excel_path="/data/Rag2_0/data/excel/主网软件提问_回答内容评判.xlsx",
        output_path="/data/Rag2_0/data/excel/dify问答_检索内容评分.xlsx"
    ).process()
|
||||
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: merge_nouns_with_llm.py
|
||||
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from rag2_0.intent_recognition.DataModels import Term
|
||||
import logging
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class TermMerger:
    """Merge same-named technical terms from several sources via an LLM."""

    def __init__(self, input_dir=None, output_path=None, max_workers=3):
        """Initialise the merger.

        Args:
            input_dir: Directory that contains the ``*_nouns.json`` files.
            output_path: File path the merged result is written to.
            max_workers: Maximum number of worker threads.
        """
        self.EXTRACTED_NOUNS_DIR = input_dir
        self.OUTPUT_PATH = output_path
        self.MAX_WORKERS = max_workers
        self.terms_parser = PydanticOutputParser(pydantic_object=Term)
        self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
        # LLM configuration; key and base URL are passed only when set.
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        llm_params = {"temperature": 0.3, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    def load_all_terms(self):
        """Read every ``*_nouns.json`` file and return all terms as dicts.

        Term names are upper-cased. Terms from the optional
        ``data/nouns/suffix_keywords.json`` file (located two directory
        levels above the input directory) are appended with empty synonyms
        and description. Files that cannot be read are logged and skipped.
        """
        all_terms = []
        # Sorted so that the load order (and therefore the fallback entry
        # chosen by process_term) does not depend on the file system.
        for file in sorted(glob.glob(os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json'))):
            try:
                with open(file, 'r', encoding='utf-8') as f:
                    file_terms = json.load(f)
                new_terms = [{"name": term["name"].upper(), "synonymous": term["synonymous"], "description": term["description"]} for term in file_terms]
                all_terms.extend(new_terms)
                logging.info(f"加载{file},共{len(new_terms)}条")
            except Exception as e:
                logging.warning(f"读取{file}失败: {e}")

        suffix_keywords_path = os.path.join(os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)), 'data', 'nouns', 'suffix_keywords.json')
        if os.path.exists(suffix_keywords_path):
            try:
                with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
                    suffix_terms = json.load(f)
                suffix_terms = [{"name": term["name"].upper(), "synonymous": "", "description": ""} for term in suffix_terms]
                all_terms.extend(suffix_terms)
                logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}条")
            except Exception as e:
                logging.warning(f"读取{suffix_keywords_path}失败: {e}")

        return all_terms

    def group_terms_by_name(self, terms):
        """Group terms by their stripped name; terms without a name are dropped."""
        name2terms = defaultdict(list)
        for term in terms:
            name = term.get('name', '').strip()
            if name:
                name2terms[name].append(term)
        return name2terms

    def merge_terms_with_llm(self, name, term_list):
        """Merge same-named terms with the LLM, trying up to three times.

        Returns:
            dict | None: The merged term, or None when every attempt failed.
        """
        items = json.dumps(term_list, ensure_ascii=False)
        prompt = self.MERGE_PROMPT.format(name=name, items=items, output_format=self.terms_parser.get_format_instructions())

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                response = self.llm.invoke(prompt, False)
                parsed_output = self.terms_parser.parse(response.content)
                return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
            except Exception as e:
                if attempt == max_retries:
                    logging.warning(f"解析LLM合并结果失败: {e}")
                    return None
                # Back off progressively longer before the next attempt.
                time.sleep(10*attempt)
        return None

    def process_term(self, name_terms_tuple):
        """Merge one (name, term_list) pair; used as the thread-pool task.

        Falls back to the first entry of the list when merging fails.
        """
        name, term_list = name_terms_tuple
        try:
            merged = self.merge_terms_with_llm(name, term_list)
        except Exception as e:
            logging.error(f"处理词条 {name} 时出错: {e}")
            return term_list[0]
        return merged if merged else term_list[0]

    def merge(self):
        """Load, group and merge all terms, then save them as JSON.

        Returns:
            list: The merged terms.
        """
        # 1. Load all terms.
        all_terms = self.load_all_terms()
        logging.info(f"共加载{len(all_terms)}条术语")

        # 2. Group by name.
        name2terms = self.group_terms_by_name(all_terms)
        logging.info(f"共{len(name2terms)}个唯一名词")

        # 3. Names with a single entry need no LLM call.
        merged_terms = []
        items_to_process = []
        for name, term_list in name2terms.items():
            if len(term_list) == 1:
                merged_terms.append(term_list[0])
            else:
                items_to_process.append((name, term_list))

        logging.info(f"共{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")

        if items_to_process:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                for result in tqdm(executor.map(self.process_term, items_to_process), total=len(items_to_process)):
                    merged_terms.append(result)

        # 4. Save the result. A bare file name has no directory part, and
        # os.makedirs('') would raise, so the directory is created only
        # when there is one.
        output_dir = os.path.dirname(self.OUTPUT_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
            json.dump(merged_terms, f, ensure_ascii=False, indent=2)
        logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")

        return merged_terms
|
||||
|
||||
|
||||
def main():
    """Create a TermMerger for the project's data directories and run it."""
    here = os.path.dirname(__file__)
    nouns_dir = os.path.abspath(os.path.join(here, '../../data/wiki_extracted_nouns'))
    merged_file = os.path.join(here, "..", "..", "data", "nouns", 'merged_nouns.json')
    TermMerger(input_dir=nouns_dir, output_path=merged_file, max_workers=2).merge()
|
||||
|
||||
# Script entry point: quieten the HTTP client loggers, then run the merge.
if __name__ == "__main__":
    for noisy_logger in ('httpx', 'openai'):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
    main()
|
||||
@@ -0,0 +1,408 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: validate_excel_data_batch.py
|
||||
Description: 使用LLM批量验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写是否正确
|
||||
"""
|
||||
|
||||
import os
|
||||
import pandas as pd
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
from rag2_0.intent_recognition.PromptTemplates import classification
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
|
||||
class ExcelDataValidator:
    """Batch-validate question-processing results stored in an Excel sheet with an LLM.

    For every row, four checks run in order (classification, question
    breakdown, retrieval keywords, question rewrite). The first failing check
    is recorded together with the reason text returned by the LLM.
    """

    # Columns that must be present in the input sheet.
    REQUIRED_COLUMNS = ("提问", "问题拆解", "一级分类", "二级分类", "问题改写", "检索的关键词")

    def __init__(self, input_file=None, output_file=None, workers=4, batch_size=10):
        """
        Initialise the validator.

        Args:
            input_file: Path of the Excel file to read.
            output_file: Path of the Excel file the results are written to.
            workers: Number of worker threads.
            batch_size: Number of rows handled per batch.
        """
        # Load environment variables; create_llm_instances reads them later.
        load_dotenv()

        self.input_file = input_file
        self.output_file = output_file
        self.workers = workers
        self.batch_size = batch_size
        self.df = None

        self.setup_logging()

    def setup_logging(self):
        """Configure root logging to a stream handler and quiet the httpx/openai loggers."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler()
            ]
        )
        logging.getLogger('httpx').setLevel(logging.WARNING)
        logging.getLogger('openai').setLevel(logging.WARNING)

    def load_data_from_excel(self, file_path=None):
        """
        Read the input sheet and check that all required columns exist.

        Args:
            file_path: Excel file path; falls back to the path given at construction.

        Returns:
            The loaded DataFrame (also stored on ``self.df``), or None when no
            path is available, the file cannot be read, or columns are missing.
        """
        file_path = file_path or self.input_file
        if not file_path:
            logging.error("未指定输入文件路径")
            return None

        try:
            df = pd.read_excel(file_path)
        except Exception as e:
            logging.error(f"读取Excel文件时出错: {e}")
            return None

        # Report every missing column at once instead of only the first one.
        missing = [col for col in self.REQUIRED_COLUMNS if col not in df.columns]
        if missing:
            logging.error(f"缺少必要的列: {', '.join(missing)}")
            return None

        logging.info(f"成功从{file_path}读取了{len(df)}条数据")
        self.df = df
        return df

    def _ask_verdict(self, llm, prompt, stage):
        """
        Send ``prompt`` to the LLM and turn its reply into a verdict.

        Args:
            llm: Object exposing ``invoke(prompt)`` whose result has a ``content`` string.
            prompt: Full prompt text.
            stage: Name of the check, used only in the warning log message.

        Returns:
            (bool, str): (True, "") when the reply starts with "正确"; otherwise
            (False, reason). Any exception raised while calling the LLM is
            logged and reported as a failed check.
        """
        try:
            response = llm.invoke(prompt)
            result = response.content.strip()
        except Exception as e:
            logging.warning(f"验证{stage}时出错: {e}")
            return False, f"验证过程出错: {str(e)}"

        if result.startswith("正确"):
            return True, ""
        if "错误:" in result:
            return False, result.replace("错误:", "").strip()
        # Reply did not follow the requested format: keep it verbatim as the reason.
        return False, result

    def validate_classification(self, llm, query, vertical_class, sub_class):
        """
        Ask the LLM whether the row's classification is correct.

        Args:
            llm: LLM client.
            query: Original question.
            vertical_class: First-level category.
            sub_class: Second-level category.

        Returns:
            (bool, str): Whether the classification is correct, and the reason if not.
        """
        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用,也可能涉及电力造价专业知识。我对用户问题进行了分类,请评估以下问题分类是否正确。

我目前总共有以下分类:
{classification}

问题的分类情况如下:
原始问题: {query}
一级分类: {vertical_class}
二级分类: {sub_class}

请从专业角度分析这个分类是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
        return self._ask_verdict(llm, prompt, "问题分类")

    def validate_query_keys(self, llm, query, query_keys):
        """
        Ask the LLM whether the question breakdown is correct.

        Args:
            llm: LLM client.
            query: Original question.
            query_keys: Question breakdown.

        Returns:
            (bool, str): Whether the breakdown is correct, and the reason if not.
        """
        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了拆解,请评估以下问题拆解是否正确。

原始问题: {query}
问题拆解: {query_keys}

问题拆解应该准确提取原始问题中的关键词和信息。请分析这个拆解是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
        return self._ask_verdict(llm, prompt, "问题拆解")

    def validate_keywords(self, llm, query, query_keys, keywords_str):
        """
        Ask the LLM whether the retrieval keywords are accurate.

        Args:
            llm: LLM client.
            query: Original question.
            query_keys: Question breakdown.
            keywords_str: Retrieval keywords, expected to be a JSON string.

        Returns:
            (bool, str): Whether the keywords are correct, and the reason if not.
        """
        # Decode the keyword JSON. A non-string or blank cell becomes an empty
        # list; text that is not valid JSON is passed to the prompt unchanged.
        if isinstance(keywords_str, str) and keywords_str.strip():
            try:
                keywords = json.loads(keywords_str)
            except (ValueError, TypeError):
                keywords = keywords_str
        else:
            keywords = []

        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。通过问题检索出了一些关键词,请评估这些关键词是否准确,是否与问题相关
原始问题: {query}
问题拆解: {query_keys}
检索关键词: {keywords}

检索关键词应该准确反映问题中需要检索的关键概念和术语。请分析这些关键词是否准确、完整。只需返回"正确"或"错误:原因",不需要其他解释。"""
        return self._ask_verdict(llm, prompt, "检索关键词")

    def validate_rewrite(self, llm, query, rewrite):
        """
        Ask the LLM whether the rewritten question is correct.

        Args:
            llm: LLM client.
            query: Original question.
            rewrite: Rewritten question.

        Returns:
            (bool, str): Whether the rewrite is correct, and the reason if not.
        """
        prompt = f"""
背景:用户正在使用电力造价软件,提出的问题可能涉及电力造价软件的使用帮助,也可能涉及电力造价专业知识。我对用户问题进行了改写,请评估以下问题改写是否正确。

原始问题: {query}
问题改写: {rewrite}

问题改写应该保持原问题的核心意图,同时使表达更加清晰、完整。请分析改写是否准确。只需返回"正确"或"错误:原因",不需要其他解释。"""
        return self._ask_verdict(llm, prompt, "问题改写")

    def validate_row(self, llm, row_data):
        """
        Run the four checks on one row, stopping at the first failure.

        Args:
            llm: LLM client.
            row_data: (index, row) tuple as produced by ``DataFrame.iterrows``.

        Returns:
            (index, is_all_correct, error_phase, error_reason): Row index,
            whether every check passed, the failing stage and its reason.
        """
        index, row = row_data
        query = row["提问"]
        query_keys = row["问题拆解"]
        vertical_class = row["一级分类"]
        sub_class = row["二级分类"]
        rewrite = row["问题改写"]
        keywords_str = row["检索的关键词"]

        try:
            # 1. Classification
            is_correct, error_reason = self.validate_classification(llm, query, vertical_class, sub_class)
            if not is_correct:
                return index, False, "问题分类", error_reason

            # 2. Question breakdown
            is_correct, error_reason = self.validate_query_keys(llm, query, query_keys)
            if not is_correct:
                return index, False, "问题拆解", error_reason

            # 3. Retrieval keywords
            is_correct, error_reason = self.validate_keywords(llm, query, query_keys, keywords_str)
            if not is_correct:
                return index, False, "关键词检索", error_reason

            # 4. Question rewrite
            is_correct, error_reason = self.validate_rewrite(llm, query, rewrite)
            if not is_correct:
                return index, False, "问题改写", error_reason

            return index, True, "", ""
        except Exception as e:
            error_msg = f"处理行 {index} 时发生错误: {str(e)}"
            logging.error(error_msg)
            return index, False, "处理错误", error_msg

    def process_batch(self, llm, batch_data):
        """Validate every (index, row) pair in ``batch_data`` and return the result tuples."""
        return [self.validate_row(llm, row_data) for row_data in batch_data]

    def create_llm_instances(self, count):
        """Create ``count`` OpenAiLLM clients configured from environment variables."""
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")

        # NOTE(review): temperature 0.7 makes verdicts non-deterministic between
        # runs; confirm this is intended for a validation task.
        llm_params = {"temperature": 0.7, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url

        return [OpenAiLLM(**llm_params) for _ in range(count)]

    def validate(self, input_file=None, output_file=None, workers=None, batch_size=None):
        """
        Run the whole validation and write the annotated sheet.

        Args:
            input_file: Input Excel path; defaults to the constructor value.
            output_file: Output Excel path; defaults to the constructor value,
                or ``validated_<input name>`` beside the input when neither is set.
            workers: Number of worker threads; defaults to the constructor value.
            batch_size: Rows per batch; defaults to the constructor value.

        Returns:
            The DataFrame with the result columns filled in, or None when the
            input could not be loaded.
        """
        input_file = input_file or self.input_file
        output_file = output_file or self.output_file
        workers = workers or self.workers
        batch_size = batch_size or self.batch_size

        df = self.load_data_from_excel(input_file)
        if df is None:
            return None

        # Result columns
        df["验证结果"] = ""
        df["错误环节"] = ""
        df["错误原因"] = ""

        # Split the rows into batches
        all_rows = list(df.iterrows())
        batches = [all_rows[i:i+batch_size] for i in range(0, len(all_rows), batch_size)]

        # Never create more clients than there are batches
        llm_instances = self.create_llm_instances(min(workers, len(batches)))

        all_results = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor:
            # Batches are assigned to the LLM clients round-robin
            future_to_batch = {
                executor.submit(self.process_batch, llm_instances[i % len(llm_instances)], batch): i
                for i, batch in enumerate(batches)
            }

            for future in tqdm(concurrent.futures.as_completed(future_to_batch), total=len(batches), desc="批次处理进度"):
                all_results.extend(future.result())

        # Batches finish out of order; sort by row index to restore sheet order
        all_results.sort(key=lambda x: x[0])

        for index, is_correct, error_phase, error_reason in all_results:
            df.at[index, "验证结果"] = "通过" if is_correct else "不通过"
            df.at[index, "错误环节"] = error_phase
            df.at[index, "错误原因"] = error_reason

        if output_file is None:
            output_file = os.path.join(
                os.path.dirname(input_file),
                f"validated_{os.path.basename(input_file)}"
            )
        df.to_excel(output_file, index=False)
        logging.info(f"验证完成,结果已保存至: {output_file}")

        self.print_statistics(df)

        return df

    def print_statistics(self, df):
        """Log the pass rate and the number of failures per stage."""
        total = len(df)
        if total == 0:
            # An empty sheet has no pass rate; avoid dividing by zero.
            logging.info("统计信息: 总计 0 条, 无数据可统计")
            return

        passed = len(df[df["验证结果"] == "通过"])
        error_stats = df[df["验证结果"] == "不通过"]["错误环节"].value_counts()

        logging.info(f"统计信息: 总计 {total} 条, 通过 {passed} 条, 通过率 {passed/total*100:.2f}%")
        logging.info("错误环节统计:")
        for phase, count in error_stats.items():
            logging.info(f"- {phase}: {count} 条")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse the arguments and run ExcelDataValidator.validate()."""
    data_dir = os.path.join(os.path.dirname(__file__), "..", "..", "data", "excel")
    # NOTE(review): the default input name has no ".xlsx" suffix while the
    # default output name does; confirm the real file name.
    input_excel = os.path.join(data_dir, "问题分类重写结果")
    output_excel = os.path.join(data_dir, "自动验证_问题分类重写结果.xlsx")

    parser = argparse.ArgumentParser(description="验证Excel数据中的问题分类、问题拆解、检索关键词和问题改写")
    # --input was previously declared required=True, which made argparse reject
    # a missing value before its default could be applied. It is optional now,
    # so the default path is actually used.
    parser.add_argument("--input", "-i", type=str, help="输入Excel文件路径", default=input_excel)
    parser.add_argument("--output", "-o", type=str, help="输出结果Excel文件路径", default=output_excel)
    parser.add_argument("--workers", "-w", type=int, default=2, help="并行工作线程数")
    parser.add_argument("--batch-size", "-b", type=int, default=5, help="每批处理的行数")
    args = parser.parse_args()

    # Build the validator from the parsed options and run it.
    validator = ExcelDataValidator(
        input_file=args.input,
        output_file=args.output,
        workers=args.workers,
        batch_size=args.batch_size
    )
    validator.validate()


if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,47 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: vectorize_save_noun.py
|
||||
Date: 2025-05-15
|
||||
Description: 专业名词向量化和保存的示例程序
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.intent_recognition import ProfessionalNounVectorizer
|
||||
import logging
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
def main():
    """Vectorize the merged professional-noun file and save the resulting index."""
    base_dir = os.path.dirname(os.path.abspath(__file__))
    index_dir = os.path.join(base_dir, "..", "..", "data", "nouns")

    # The vectorizer writes its output under index_dir.
    vectorizer = ProfessionalNounVectorizer(output_dir=index_dir)
    noun_files = [os.path.join(base_dir, "..", "..", "data/nouns/merged_nouns.json")]

    # Vectorization and saving happen in a single call; a falsy result means failure.
    if not vectorizer.vectorize_files_and_save(noun_files):
        logging.error("✗ 索引创建失败")
        return

    logging.info("✓ 索引创建和保存成功")
    logging.info(f" 索引保存路径: {os.path.join(index_dir, 'professional_nouns_index')}")
|
||||
|
||||
if __name__ == "__main__":
    # Root logger at INFO, emitting bare messages without timestamp or level prefix.
    logging.basicConfig(format='%(message)s', level=logging.INFO)
    main()
|
||||
Reference in New Issue
Block a user