#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: deduplicate_json.py
Description: Deduplicate the records of a JSON file and save the result.

Records that share the same key field are merged into a single entry by an
LLM; records whose key is unique are written out unchanged.
"""
import os
import json
import time
import argparse
import logging
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
from rag2_0.intent_recognition.DataModels import Term

# Load environment variables (LLM_MODEL_NAME, OPENAI_API_KEY, OPENAI_API_BASE
# are read below) from a .env file, if present.
load_dotenv()


class JsonDeduplicator:
    """Deduplicate a JSON list of records, merging duplicates with an LLM."""

    def __init__(self, input_path=None, output_path=None, key_field="name", max_workers=3):
        """Initialize the JSON deduplicator.

        Args:
            input_path: Path of the input JSON file (expected to hold a list
                of objects).
            output_path: Where to write the deduplicated list. Defaults to the
                input path with ``_deduplicated`` inserted before the extension.
            key_field: Name of the field whose value identifies duplicates.
            max_workers: Maximum number of threads used for LLM merge calls.
        """
        self.INPUT_PATH = input_path
        self.OUTPUT_PATH = output_path or self._default_output_path(input_path)
        self.KEY_FIELD = key_field
        self.MAX_WORKERS = max_workers
        self.item_parser = PydanticOutputParser(pydantic_object=Term)

        self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}

原始条目:
{items}
'''

        # Configure the LLM client; key and base URL are passed only if set.
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        llm_params = {"temperature": 0.3, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Configure logging (no-op if the root logger is already configured).
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    @staticmethod
    def _default_output_path(input_path):
        """Return ``<stem>_deduplicated<ext>`` next to the input, or None.

        Uses ``os.path.splitext`` so only the final extension is touched and a
        non-``.json`` input never maps back onto itself (which would overwrite
        the input file).
        """
        if input_path is None:
            return None
        root, ext = os.path.splitext(input_path)
        return f"{root}_deduplicated{ext}"

    def load_json_data(self):
        """Read the input JSON file.

        Returns:
            The parsed list of records, or ``[]`` if the file cannot be read,
            is not valid JSON, or its top level is not a list.
        """
        try:
            with open(self.INPUT_PATH, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, TypeError, ValueError) as e:
            # OSError: missing/unreadable file; ValueError covers JSON and
            # Unicode decode errors; TypeError: INPUT_PATH is None.
            logging.error(f"读取{self.INPUT_PATH}失败: {e}", exc_info=True)
            return []
        if not isinstance(data, list):
            # Grouping iterates records and calls .get() on each one.
            logging.error(f"{self.INPUT_PATH}的顶层结构不是列表,无法去重")
            return []
        logging.info(f"从{self.INPUT_PATH}加载了{len(data)}条记录")
        return data

    def group_items_by_key(self, items):
        """Group records by the (stripped) value of ``self.KEY_FIELD``.

        Records that are not dicts, or whose key is missing, None, or blank,
        cannot be grouped; they are skipped and counted in a warning.

        Returns:
            A defaultdict mapping each key to its records, in first-seen order.
        """
        key_to_items = defaultdict(list)
        skipped = 0
        for item in items:
            raw_key = item.get(self.KEY_FIELD) if isinstance(item, dict) else None
            key = str(raw_key).strip() if raw_key is not None else ''
            if key:
                key_to_items[key].append(item)
            else:
                skipped += 1
        if skipped:
            logging.warning(f"跳过了{skipped}条缺少有效'{self.KEY_FIELD}'字段的记录")
        return key_to_items

    def merge_items_with_llm(self, key, item_list, max_retries=3, retry_delay=5):
        """Ask the LLM to merge records sharing ``key`` into one entry.

        Args:
            key: Shared key value, substituted as ``{name}`` in the prompt.
            item_list: Records to merge; serialized to JSON for the prompt.
            max_retries: Total number of attempts before giving up.
            retry_delay: Base back-off in seconds; after failed attempt ``n``
                the call sleeps ``retry_delay * n`` seconds.

        Returns:
            A dict with ``name``, ``synonymous`` and ``description`` taken from
            the parsed ``Term``, or ``None`` if every attempt failed.
        """
        items = json.dumps(item_list, ensure_ascii=False)
        prompt = self.MERGE_PROMPT.format(
            name=key,
            items=items,
            output_format=self.item_parser.get_format_instructions(),
        )

        for attempt in range(1, max_retries + 1):
            try:
                response = self.llm.invoke(prompt, False)
                parsed_output = self.item_parser.parse(response.content)
                return {"name": parsed_output.name,
                        "synonymous": parsed_output.synonymous,
                        "description": parsed_output.description}
            except Exception as e:
                # Both the LLM call and output parsing may fail transiently.
                if attempt == max_retries:
                    logging.warning(f"解析LLM合并结果失败: {e}")
                    return None
                time.sleep(retry_delay * attempt)
        return None

    def process_item(self, key_items_tuple):
        """Reduce one ``(key, records)`` group to a single record.

        Used as the thread-pool task. A single-record group is returned as
        is; otherwise the LLM merge is tried. If the merge fails, the first
        record is returned so the key is never lost from the output.
        """
        key, item_list = key_items_tuple
        try:
            if len(item_list) == 1:
                return item_list[0]
            merged = self.merge_items_with_llm(key, item_list)
            return merged if merged else item_list[0]
        except Exception as e:
            logging.error(f"处理键 {key} 时出错: {e}", exc_info=True)
            return item_list[0]

    def _save(self, items):
        """Write ``items`` to ``self.OUTPUT_PATH`` as pretty-printed UTF-8 JSON."""
        output_dir = os.path.dirname(self.OUTPUT_PATH)
        # dirname() is '' for a bare filename; os.makedirs('') would raise.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
            json.dump(items, f, ensure_ascii=False, indent=2)
        logging.info(f"去重后结果已保存到: {self.OUTPUT_PATH}")

    def deduplicate(self):
        """Entry point: load, group, merge duplicates, and save the result.

        Output records keep the first-appearance order of their keys in the
        input file.

        Returns:
            The deduplicated list of records (``[]`` if nothing was loaded, in
            which case no output file is written).
        """
        # 1. Load the JSON data.
        all_items = self.load_json_data()
        if not all_items:
            return []

        # 2. Group records by key field.
        key_to_items = self.group_items_by_key(all_items)
        logging.info(f"共{len(key_to_items)}个唯一键")

        # 3. Only groups with several records need an LLM merge.
        groups = list(key_to_items.items())
        items_to_process = [(key, item_list) for key, item_list in groups if len(item_list) > 1]
        logging.info(f"共{len(groups) - len(items_to_process)}个单一项目,{len(items_to_process)}个需要合并的项目")

        merged_by_key = {}
        if items_to_process:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                # executor.map yields results in submission order, so zip
                # pairs each result with its key; tqdm shows progress.
                results = tqdm(executor.map(self.process_item, items_to_process),
                               total=len(items_to_process))
                for (key, _), result in zip(items_to_process, results):
                    merged_by_key[key] = result

        # Reassemble in first-appearance key order.
        deduplicated_items = [
            merged_by_key[key] if len(item_list) > 1 else item_list[0]
            for key, item_list in groups
        ]

        # 4. Save the deduplicated result.
        self._save(deduplicated_items)
        return deduplicated_items


def main():
    """Parse command-line arguments and run the deduplication."""
    parser = argparse.ArgumentParser(description='对JSON文件进行去重')
    input_path = 'data/wiki_extracted_nouns/技改检修计价通_nouns.json'
    parser.add_argument('-i', '--input', default=input_path, help='输入JSON文件路径')
    parser.add_argument('-o', '--output', help='输出JSON文件路径')
    parser.add_argument('-k', '--key', default='name', help='用于去重的键字段名,默认为"name"')
    # %(default)s keeps the help text in sync with the real default.
    parser.add_argument('-w', '--workers', type=int, default=30, help='线程池最大工作线程数,默认为%(default)s')

    args = parser.parse_args()

    deduplicator = JsonDeduplicator(
        input_path=args.input,
        output_path=args.output,
        key_field=args.key,
        max_workers=args.workers
    )
    deduplicator.deduplicate()


if __name__ == "__main__":
    # Silence per-request INFO logs from the HTTP/OpenAI client libraries.
    logging.getLogger('httpx').setLevel(logging.WARNING)
    logging.getLogger('openai').setLevel(logging.WARNING)
    main()