189 lines
6.7 KiB
Python
189 lines
6.7 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: intent_recognition_example.py
|
|
Date: 2025-05-14
|
|
Description: 意图识别和问题改写示例
|
|
"""
|
|
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from rag2_0.intent_recognition import IntentRecognizer
|
|
import pandas as pd
|
|
import logging
|
|
import json
|
|
import concurrent.futures
|
|
from tqdm import tqdm
|
|
import time
|
|
# Load environment variables via python-dotenv at import time, before main()
# reads OPENAI_API_KEY / OPENAI_API_BASE / LLM_MODEL_NAME with os.getenv.
|
|
load_dotenv()
|
|
|
|
|
|
# 读取Excel文件中的提问数据
|
|
def load_questions_from_excel(file_path=None):
    """
    Read the questions stored in the first column of an Excel sheet.

    Blank cells are skipped and every remaining cell is converted to a
    stripped string, so callers never receive NaN or empty queries.

    NOTE: ``pd.read_excel`` is called with its defaults, so the first row of
    the sheet is consumed as the header and is not returned as a question.

    Args:
        file_path: Path of the Excel file to read. There is no built-in
            default path: when it is None an error is logged and an empty
            list is returned.

    Returns:
        A list of question strings, or an empty list when no path was given
        or the file could not be read.
    """
    # Fail fast with a clear message instead of relying on pandas rejecting None.
    if file_path is None:
        logging.error("读取Excel文件时出错: 未提供文件路径")
        return []

    try:
        df = pd.read_excel(file_path)
        # Blank cells come back as NaN; drop them before they reach the recognizer.
        cells = df.iloc[:, 0].dropna()
        questions = [text for text in (str(cell).strip() for cell in cells) if text]
        logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
        return questions
    except Exception as e:
        # Best-effort loader: any read/parse problem is logged and yields an empty list.
        logging.error(f"读取Excel文件时出错: {e}")
        return []
|
|
|
|
def process_query(recognizer, query, max_retries=3, retry_delay=10):
    """
    Run intent recognition on a single query, retrying when it fails.

    Args:
        recognizer: Object whose ``process_query(query)`` returns a
            ``(classification, keywords, rewrite, query_keys)`` tuple.
        query: The query string to process.
        max_retries: Number of extra attempts made after the first failure.
        retry_delay: Base delay in seconds; the wait before the n-th retry
            is ``retry_delay * n``.

    Returns:
        A dict with the keys 提问, 问题拆解, 一级分类, 二级分类, 问题改写 and
        检索的关键词. The same keys, in the same order, are returned on
        failure so that a DataFrame built from mixed results keeps a stable
        column layout; in that case the fields hold "处理出错" and the keyword
        field carries the last error message. The function never raises.
    """
    max_retries = max(max_retries, 0)
    failure_marker = "处理出错"
    last_error = None

    for attempt in range(max_retries + 1):
        if attempt:
            # Linear back-off so repeated failures do not hammer the service.
            logging.warning(f"处理问题 '{query}' 失败,第 {attempt} 次重试: {last_error}")
            time.sleep(retry_delay * attempt)

        try:
            classification, keywords, rewrite, query_keys = recognizer.process_query(query)
            return {
                "提问": query,
                "问题拆解": query_keys,
                "一级分类": classification.vertical_classification,
                "二级分类": classification.sub_classification,
                "问题改写": rewrite.rewrite,
                "检索的关键词": _keywords_to_json(keywords),
            }
        except Exception as e:
            # Keep the broad catch: this runs inside a worker thread and must
            # always hand back a result row instead of raising.
            last_error = e

    logging.error(f"处理问题 '{query}' 时出错: {last_error.__class__}{last_error}")
    return {
        "提问": query,
        "问题拆解": failure_marker,
        "一级分类": failure_marker,
        "二级分类": failure_marker,
        "问题改写": failure_marker,
        "检索的关键词": f"重试 {max_retries} 次后失败: {str(last_error)}",
    }


def _keywords_to_json(keywords):
    """Serialize the terms of a keywords object to a JSON string ("" when there are none)."""
    if not (keywords and keywords.terms):
        return ""

    term_details = [
        {
            "名称": term.name,
            "同义词": ";".join(term.synonymous) if term.synonymous else "",
            "描述": term.description,
        }
        for term in keywords.terms
    ]
    # ensure_ascii=False keeps Chinese text readable in the output.
    return json.dumps(term_details, ensure_ascii=False, indent=2)
|
|
|
|
# Sample query text for quick manual runs; one query per line.
examples_query = "下载软件在哪下载?"
|
|
|
|
def main():
    """
    Run intent recognition and question rewriting over the Excel question set.

    Reads the model configuration from environment variables, loads the
    questions from ``data/excel/200条提问数据.xlsx`` (resolved relative to this
    file), processes them concurrently with ``process_query`` and writes the
    results, in input order, to ``200条提问数据_重写结果.xlsx`` in the same
    directory. When no questions could be loaded, nothing is processed and no
    output file is written.
    """
    # Model configuration comes from the environment.
    api_key = os.getenv("OPENAI_API_KEY")
    base_url = os.getenv("OPENAI_API_BASE")
    model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")

    recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)

    # Load the questions to process.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据.xlsx")
    examples = load_questions_from_excel(data_file)
    # For a quick manual check, swap in: examples = examples_query.split("\n")

    # Do not overwrite an earlier results file with an empty sheet when the
    # input could not be read.
    if not examples:
        logging.warning("未读取到任何提问,跳过处理")
        return

    max_workers = 20
    logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")

    # Pre-sized so each result can be stored at its input position even though
    # the futures complete out of order.
    results = [None] * len(examples)

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        future_to_index = {
            executor.submit(process_query, recognizer, query): idx
            for idx, query in enumerate(examples)
        }

        # as_completed drives the progress bar as soon as any task finishes.
        for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
            results[future_to_index[future]] = future.result()

    results_df = pd.DataFrame(results)
    output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条提问数据_重写结果.xlsx")

    with pd.ExcelWriter(output_file, engine='xlsxwriter') as writer:
        results_df.to_excel(writer, index=False, sheet_name='Sheet1')
        worksheet = writer.sheets['Sheet1']

        # Column widths are given in Excel character-width units.
        worksheet.set_column('A:A', 60)  # question
        worksheet.set_column('B:B', 20)  # question breakdown
        worksheet.set_column('C:C', 20)  # top-level category
        worksheet.set_column('D:D', 20)  # sub-category
        worksheet.set_column('E:E', 60)  # rewritten question
        worksheet.set_column('F:F', 60)  # retrieval keywords

        # Row height of 20 for every data row plus the header row.
        for row in range(len(results_df) + 1):
            worksheet.set_row(row, 20)

    logging.info(f"处理完成,结果已保存至: {output_file}")
|
|
|
|
def setup_logging():
    """
    Configure root logging for console output.

    Installs a stream handler at INFO level with a timestamped format, then
    raises the 'httpx' and 'openai' loggers to WARNING so their INFO records
    are suppressed.
    """
    console_handler = logging.StreamHandler()
    logging.basicConfig(
        format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        level=logging.INFO,
        handlers=[console_handler],
    )

    for logger_name in ('httpx', 'openai'):
        logging.getLogger(logger_name).setLevel(logging.WARNING)
|
|
|
|
if __name__ == "__main__":
    # Logging must be configured before the first message is emitted.
    setup_logging()
    logging.info("意图识别示例程序开始运行...")
    main()