优化意图识别示例,新增命令行参数解析功能,支持输入输出文件路径和调试模式,增强代码可读性和灵活性。同时更新Dify工具,调整检索信息获取逻辑,确保重排得分信息的正确传递。

This commit is contained in:
2025-06-18 14:53:24 +08:00
parent 08a7a5812a
commit 139d0cffef
8 changed files with 229 additions and 88 deletions
+48 -12
View File
@@ -17,6 +17,7 @@ import concurrent.futures
from tqdm import tqdm
import time
import sys
import argparse
from typing import List, Dict
# 加载环境变量
load_dotenv()
@@ -176,6 +177,7 @@ def save_results_to_excel(results, output_file, is_final=False):
# 示例查询
examples_query = """那储能软件如何操作"""
examples_query = """博微软件如何新建工程啊"""
conversation_context=""
chat_history=[
{
@@ -199,34 +201,68 @@ previous_slots={
"software_version": None,
"operation_steps": None
}
def parse_arguments():
"""解析命令行参数"""
parser = argparse.ArgumentParser(description='意图识别和问题改写工具')
# 添加数据文件路径参数
parser.add_argument('--input', '-i', type=str,
help='输入Excel文件路径,包含待处理的提问数据(第一列)')
parser.add_argument('--output', '-o', type=str,
help='输出Excel文件路径,用于保存处理结果')
# 添加LLM相关参数
parser.add_argument('--model', '-m', type=str,
help='LLM模型名称,默认使用环境变量中的配置')
parser.add_argument('--api_base', '-a', type=str,
help='API基础URL,默认使用环境变量中的配置')
# 添加处理相关参数
parser.add_argument('--max_workers', '-w', type=int, default=20,
help='并发处理的最大线程数,默认为20')
parser.add_argument('--debug', '-d', action='store_true',
help='启用调试模式,使用示例查询而非从文件读取')
parser.add_argument('--query', '-q', type=str,
help='在调试模式下使用的查询字符串')
return parser.parse_args()
def main():
"""
意图识别和问题改写示例
"""
# 从环境变量中获取配置
# 解析命令行参数
args = parse_arguments()
# 从环境变量中获取配置,命令行参数优先
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
base_url = args.api_base if args.api_base else os.getenv("OPENAI_API_BASE")
model_name = args.model if args.model else os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
# 初始化意图识别器
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
# 读取提问数据
current_dir = os.path.dirname(os.path.abspath(__file__))
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试.xlsx")
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试_槽位填充结果.xlsx")
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_槽位(分类)填充结果.xlsx")
# 检测是否为调试模式
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
# 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
# is_debug = False
if is_debug:
examples = examples_query.strip().split("\n")
# 如果提供了查询参数,使用它;否则使用默认示例
if args.query:
examples = [args.query]
else:
examples = examples_query.strip().split("\n")
else:
examples = load_questions_from_excel(data_file)
if not is_debug:
max_workers = 20 # 减少并发数以避免API限制
max_workers = args.max_workers
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
# 创建一个与输入顺序相同的结果列表
@@ -262,8 +298,8 @@ def main():
for idx, query in enumerate(examples):
if query.strip() == "":
continue
process_query(recognizer, query, conversation_context, chat_history, previous_slots)
# print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
# process_query(recognizer, query, conversation_context, chat_history, previous_slots)
print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
def setup_logging():
# 配置日志输出到控制台