Files
QueryRewrite/rag2_0/demo/intent_recognition_example.py
T

319 lines
13 KiB
Python

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: intent_recognition_example.py
Date: 2025-05-14
Description: 意图识别和问题改写示例
"""
import os
from dotenv import load_dotenv
from regex import F
from rag2_0.intent_recognition import IntentRecognizer
import pandas as pd
import logging
import json
import concurrent.futures
from tqdm import tqdm
import time
import sys
import argparse
from typing import List, Dict
# 加载环境变量
load_dotenv()
# 读取Excel文件中的提问数据
def load_questions_from_excel(file_path=None):
"""
从Excel文件中读取提问数据
Args:
file_path: Excel文件路径,如果为None则使用默认路径
Returns:
提问列表
"""
try:
# 读取Excel文件的第一列数据
df = pd.read_excel(file_path)
questions = df.iloc[:, 0].tolist() # 获取第一列数据
logging.info(f"成功从{file_path}读取了{len(questions)}条提问")
return questions
except Exception as e:
logging.error(f"读取Excel文件时出错: {e}")
return []
def process_query(recognizer: IntentRecognizer, query: str, conversation_context: str = "", chat_history: List[Dict[str, str]] = None, previous_slots: Dict[str, str] = None):
"""
处理单个查询,支持重试机制,并包含槽位填充
Args:
recognizer: 意图识别器实例
query: 查询字符串
Returns:
处理结果字典
"""
max_retries = 3
retry_count = 0
while retry_count <= max_retries:
try:
# 使用新的process_query_with_slots方法处理查询
# result = recognizer.process_query_with_slots(query)
result = recognizer.process_query(query, conversation_context=conversation_context, chat_history=chat_history, previous_slots=previous_slots)
# 提取分类信息
classification = result["classification"]
# 提取关键词信息
keywords = result["keywords"]
keywords_str = ""
if keywords and keywords.get("terms"):
term_details = []
for term in keywords["terms"]:
term_info = {
"名称": term["name"],
"同义词": ";".join(term["synonymous"]) if term["synonymous"] else "",
"描述": term["description"]
}
term_details.append(term_info)
# 将term_details转换为JSON字符串,确保中文正确显示
keywords_str = json.dumps(term_details, ensure_ascii=False, indent=2)
# 提取槽位填充信息
slot_filling = result.get("slot_filling", {})
slot_filling_str = ""
if slot_filling and "filled_data" in slot_filling:
# 格式化槽位填充结果
slot_filling_str = json.dumps({
"是否完整": slot_filling.get("is_complete", False),
"缺失槽位": slot_filling.get("missing_slots", {}),
"填充数据": slot_filling.get("filled_data", {})
}, ensure_ascii=False, indent=2)
# 处理成功,返回结果
return {
"提问": query,
"问题拆解": result["query_keys"],
"一级分类": classification["vertical_classification"],
"二级分类": classification["sub_classification"],
"问题改写": result["rewrite"]["rewrite"],
"检索的关键词": keywords_str,
"槽位填充": slot_filling_str
}
except Exception as e:
retry_count += 1
# 如果已经重试了最大次数,则记录错误并返回错误结果
if retry_count > max_retries:
logging.error(f"处理问题 '{query}' 时出错: {e.__class__}{e}")
return {
"提问": query,
"一级分类": "处理出错",
"二级分类": "处理出错",
"问题改写": "处理出错",
"检索的关键词": f"重试 {max_retries} 次后失败: {str(e)}",
"槽位填充": "处理出错"
}
else:
# 可以在这里添加延迟,避免过快重试
time.sleep(10)
def save_results_to_excel(results, output_file, is_final=False):
"""
将结果保存到Excel文件
Args:
results: 结果列表
output_file: 输出文件路径
is_final: 是否为最终保存,如果是则使用完整文件名,否则添加临时标记
Returns:
None
"""
# 过滤掉None值
valid_results = [r for r in results if r is not None]
if not valid_results:
logging.warning("没有有效结果可保存")
return
# 创建DataFrame
results_df = pd.DataFrame(valid_results)
# 根据是否为最终保存确定文件名
if not is_final:
file_name, file_ext = os.path.splitext(output_file)
temp_output_file = f"{file_name}_temp{file_ext}"
else:
temp_output_file = output_file
# 使用ExcelWriter设置格式
with pd.ExcelWriter(temp_output_file, engine='xlsxwriter') as writer:
results_df.to_excel(writer, index=False, sheet_name='Sheet1')
# 获取工作簿和工作表对象
workbook = writer.book
worksheet = writer.sheets['Sheet1']
# 设置列宽(单位:像素)
# 定义列宽(厘米转为Excel单位,1cm约等于4.7个Excel单位)
worksheet.set_column('A:A', 60) # 提问列 60个Excel单位
worksheet.set_column('B:B', 20) # 问题拆解 20个Excel单位
worksheet.set_column('C:C', 20) # 一级分类 20个Excel单位
worksheet.set_column('D:D', 20) # 二级分类 20个Excel单位
worksheet.set_column('E:E', 60) # 问题改写 60个Excel单位
worksheet.set_column('F:F', 60) # 检索到的关键词 60个Excel单位
worksheet.set_column('G:G', 80) # 槽位填充 80个Excel单位
# 设置所有行高为20磅
for i in range(len(results_df) + 1): # +1 是为了包括表头
worksheet.set_row(i, 20)
logging.info(f"已保存{len(valid_results)}条结果至: {temp_output_file}")
# 示例查询
examples_query = """那储能软件如何操作"""
examples_query = """博微软件如何新建工程啊"""
conversation_context=""
chat_history=[
{
"user": "怎么新建工程啊"
},
{
"assistant": "您好!为了更准确地帮助您解决问题,请提供以下信息:\n\n- **软件名称**:请确认您使用的软件名称,可选值如下:\n - 配网工程计价通D3软件\n - 新型储能电站建设计价通C1软件\n - 西藏电力工程计价通Z1软件\n - 技改检修工程计价通T1软件\n - 技改检修清单计价通T1软件\n - 主网电力建设计价通软件\n\n请补充以上信息,以便我们能够更好地指导您如何新建工程。"
},
{
"user": "配网工程计价通D3软件"
},
{
"assistant": """好的,针对您使用**配网工程计价通D3软件**新建工程的问题,以下是具体的操作步骤:\n\n## 新建工程\n\n### 功能入口\n【主页】界面——"新建"按钮\n\n### 操作步骤\n1. 在"新建窗口"选择对应工程模板。\n2. 设置工程名称。\n3. 确定后根据新建向导完善"电压等级"、"地区类型"等参数。\n4. 点击"确定"即可完成新建工程。\n\n![进入新建工程窗口](https://172.20.0.145/files/362491a3-69c3-45b0-a037-b5f96b91bf73/image-preview)\n\n希望这些步骤能帮助您顺利完成新建工程。如果还有其他问题,欢迎随时提问!\n"""
}
]
previous_slots={
"software_name": "配网工程计价通D3软件",
"function_name": "新建工程",
"operation": "如何新建工程",
"project_type": None,
"software_version": None,
"operation_steps": None
}
def parse_arguments():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='意图识别和问题改写工具')
# 添加数据文件路径参数
parser.add_argument('--input', '-i', type=str,
help='输入Excel文件路径,包含待处理的提问数据(第一列)')
parser.add_argument('--output', '-o', type=str,
help='输出Excel文件路径,用于保存处理结果')
# 添加LLM相关参数
parser.add_argument('--model', '-m', type=str,
help='LLM模型名称,默认使用环境变量中的配置')
parser.add_argument('--api_base', '-a', type=str,
help='API基础URL,默认使用环境变量中的配置')
# 添加处理相关参数
parser.add_argument('--max_workers', '-w', type=int, default=20,
help='并发处理的最大线程数,默认为20')
parser.add_argument('--debug', '-d', action='store_true',
help='启用调试模式,使用示例查询而非从文件读取')
parser.add_argument('--query', '-q', type=str,
help='在调试模式下使用的查询字符串')
return parser.parse_args()
def main():
"""
意图识别和问题改写示例
"""
# 解析命令行参数
args = parse_arguments()
# 从环境变量中获取配置,命令行参数优先
api_key = os.getenv("OPENAI_API_KEY")
base_url = args.api_base if args.api_base else os.getenv("OPENAI_API_BASE")
model_name = args.model if args.model else os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
# 初始化意图识别器
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
# 读取提问数据
current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_槽位(分类)填充结果.xlsx")
# 检测是否为调试模式
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
if is_debug:
# 如果提供了查询参数,使用它;否则使用默认示例
if args.query:
examples = [args.query]
else:
examples = examples_query.strip().split("\n")
else:
examples = load_questions_from_excel(data_file)
if not is_debug:
max_workers = args.max_workers
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
# 创建一个与输入顺序相同的结果列表
results = [None] * len(examples)
batch_size = 100 # 每100条保存一次
# 使用线程池进行并发处理
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
# 提交所有任务并记录它们的索引
future_to_index = {}
for idx, query in enumerate(examples):
future = executor.submit(process_query, recognizer, query)
future_to_index[future] = idx
# 使用tqdm显示进度条
completed = 0
for future in tqdm(concurrent.futures.as_completed(future_to_index), total=len(examples), desc="处理进度"):
idx = future_to_index[future]
result = future.result()
# 将结果放在与输入相同的位置
results[idx] = result
completed += 1
# 每处理batch_size条数据保存一次
# if completed % batch_size == 0:
# logging.info(f"已完成 {completed}/{len(examples)} 条,保存中间结果...")
# save_results_to_excel(results, output_file, is_final=False)
# 处理完所有数据后,保存最终结果
save_results_to_excel(results, output_file, is_final=True)
logging.info(f"所有处理完成,最终结果已保存至: {output_file}")
else:
for idx, query in enumerate(examples):
if query.strip() == "":
continue
# process_query(recognizer, query, conversation_context, chat_history, previous_slots)
print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
def setup_logging():
# 配置日志输出到控制台
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
handlers=[
logging.StreamHandler() # 添加控制台处理器
]
)
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
if __name__ == "__main__":
setup_logging()
logging.info("意图识别示例程序开始运行...")
main()