优化意图识别示例,新增命令行参数解析功能,支持输入输出文件路径和调试模式,增强代码可读性和灵活性。同时更新Dify工具,调整检索信息获取逻辑,确保重排得分信息的正确传递。
This commit is contained in:
@@ -17,6 +17,7 @@ import concurrent.futures
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
import sys
|
||||
import argparse
|
||||
from typing import List, Dict
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
@@ -176,6 +177,7 @@ def save_results_to_excel(results, output_file, is_final=False):
|
||||
|
||||
# 示例查询
|
||||
examples_query = """那储能软件如何操作"""
|
||||
examples_query = """博微软件如何新建工程啊"""
|
||||
conversation_context=""
|
||||
chat_history=[
|
||||
{
|
||||
@@ -199,34 +201,68 @@ previous_slots={
|
||||
"software_version": None,
|
||||
"operation_steps": None
|
||||
}
|
||||
|
||||
def parse_arguments():
|
||||
"""解析命令行参数"""
|
||||
parser = argparse.ArgumentParser(description='意图识别和问题改写工具')
|
||||
|
||||
# 添加数据文件路径参数
|
||||
parser.add_argument('--input', '-i', type=str,
|
||||
help='输入Excel文件路径,包含待处理的提问数据(第一列)')
|
||||
parser.add_argument('--output', '-o', type=str,
|
||||
help='输出Excel文件路径,用于保存处理结果')
|
||||
|
||||
# 添加LLM相关参数
|
||||
parser.add_argument('--model', '-m', type=str,
|
||||
help='LLM模型名称,默认使用环境变量中的配置')
|
||||
parser.add_argument('--api_base', '-a', type=str,
|
||||
help='API基础URL,默认使用环境变量中的配置')
|
||||
|
||||
# 添加处理相关参数
|
||||
parser.add_argument('--max_workers', '-w', type=int, default=20,
|
||||
help='并发处理的最大线程数,默认为20')
|
||||
parser.add_argument('--debug', '-d', action='store_true',
|
||||
help='启用调试模式,使用示例查询而非从文件读取')
|
||||
parser.add_argument('--query', '-q', type=str,
|
||||
help='在调试模式下使用的查询字符串')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def main():
|
||||
"""
|
||||
意图识别和问题改写示例
|
||||
"""
|
||||
|
||||
# 从环境变量中获取配置
|
||||
# 解析命令行参数
|
||||
args = parse_arguments()
|
||||
|
||||
# 从环境变量中获取配置,命令行参数优先
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
base_url = args.api_base if args.api_base else os.getenv("OPENAI_API_BASE")
|
||||
model_name = args.model if args.model else os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
|
||||
# 初始化意图识别器
|
||||
recognizer = IntentRecognizer(api_key=api_key, base_url=base_url, model_name=model_name)
|
||||
|
||||
# 读取提问数据
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
data_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试.xlsx")
|
||||
output_file = os.path.join(current_dir, "..", "..", "data", "excel", "200条点踩数据测试_槽位填充结果.xlsx")
|
||||
data_file = args.input if args.input else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_提问明确.xlsx")
|
||||
output_file = args.output if args.output else os.path.join(current_dir, "..", "..", "data", "excel", "历史提问数据(dislike)_槽位(分类)填充结果.xlsx")
|
||||
|
||||
# 检测是否为调试模式
|
||||
is_debug = args.debug or (hasattr(sys, 'gettrace') and sys.gettrace() is not None)
|
||||
|
||||
# 检测是否为调试模式,调试模式下使用examples_query,否则从Excel读取
|
||||
is_debug = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
# is_debug = False
|
||||
if is_debug:
|
||||
examples = examples_query.strip().split("\n")
|
||||
# 如果提供了查询参数,使用它;否则使用默认示例
|
||||
if args.query:
|
||||
examples = [args.query]
|
||||
else:
|
||||
examples = examples_query.strip().split("\n")
|
||||
else:
|
||||
examples = load_questions_from_excel(data_file)
|
||||
|
||||
if not is_debug:
|
||||
max_workers = 20 # 减少并发数以避免API限制
|
||||
max_workers = args.max_workers
|
||||
logging.info(f"共有 {len(examples)} 个问题需要处理,使用 {max_workers} 个并发线程")
|
||||
|
||||
# 创建一个与输入顺序相同的结果列表
|
||||
@@ -262,8 +298,8 @@ def main():
|
||||
for idx, query in enumerate(examples):
|
||||
if query.strip() == "":
|
||||
continue
|
||||
process_query(recognizer, query, conversation_context, chat_history, previous_slots)
|
||||
# print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
|
||||
# process_query(recognizer, query, conversation_context, chat_history, previous_slots)
|
||||
print(json.dumps(process_query(recognizer, query), ensure_ascii=False, indent=2))
|
||||
|
||||
def setup_logging():
|
||||
# 配置日志输出到控制台
|
||||
|
||||
Reference in New Issue
Block a user