187 lines
6.9 KiB
Python
Executable File
187 lines
6.9 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: deduplicate_json.py
|
|
Description: 对指定JSON文件进行去重并重新保存
|
|
"""
|
|
import os
|
|
import json
|
|
import argparse
|
|
import logging
|
|
from collections import defaultdict
|
|
from tqdm import tqdm
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from dotenv import load_dotenv
|
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
|
from langchain.output_parsers import PydanticOutputParser
|
|
from pydantic import BaseModel, Field
|
|
from rag2_0.intent_recognition.DataModels import Term
|
|
|
|
# 加载环境变量
|
|
load_dotenv()
|
|
|
|
|
|
class JsonDeduplicator:
|
|
"""JSON文件去重类"""
|
|
|
|
def __init__(self, input_path=None, output_path=None, key_field="name", max_workers=3):
|
|
"""初始化JSON去重器
|
|
|
|
Args:
|
|
input_path: 输入JSON文件路径
|
|
output_path: 去重后的输出文件路径
|
|
key_field: 用于去重的键字段名
|
|
max_workers: 线程池最大工作线程数
|
|
"""
|
|
self.INPUT_PATH = input_path
|
|
self.OUTPUT_PATH = output_path or input_path.replace('.json', '_deduplicated.json')
|
|
self.KEY_FIELD = key_field
|
|
self.MAX_WORKERS = max_workers
|
|
self.item_parser = PydanticOutputParser(pydantic_object=Term)
|
|
self.MERGE_PROMPT = '''
|
|
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
|
|
- 同义词(synonymous)去重合并
|
|
- 描述(description)合并为更完整、简明的描述
|
|
- 保持输出格式为:
|
|
{output_format}
|
|
原始条目:
|
|
{items}
|
|
'''
|
|
# 配置LLM
|
|
model_name = os.getenv("MODEL_NAME", "gpt-3.5-turbo")
|
|
api_key = os.getenv("OPENAI_API_KEY")
|
|
base_url = os.getenv("OPENAI_API_BASE")
|
|
llm_params = {"temperature": 0.3, "model": model_name}
|
|
if api_key:
|
|
llm_params["api_key"] = api_key
|
|
if base_url:
|
|
llm_params["base_url"] = base_url
|
|
self.llm = OpenAiLLM(**llm_params)
|
|
|
|
# 配置日志
|
|
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
|
def load_json_data(self):
|
|
"""读取JSON文件"""
|
|
try:
|
|
with open(self.INPUT_PATH, 'r', encoding='utf-8') as f:
|
|
data = json.load(f)
|
|
logging.info(f"从{self.INPUT_PATH}加载了{len(data)}条记录")
|
|
return data
|
|
except Exception as e:
|
|
logging.error(f"读取{self.INPUT_PATH}失败: {e}", exc_info=True)
|
|
return []
|
|
|
|
def group_items_by_key(self, items):
|
|
"""按指定键字段聚合项目"""
|
|
key_to_items = defaultdict(list)
|
|
for item in items:
|
|
key = item.get(self.KEY_FIELD, '').strip()
|
|
if key:
|
|
key_to_items[key].append(item)
|
|
return key_to_items
|
|
|
|
def merge_items_with_llm(self, key, item_list):
|
|
"""调用LLM合并具有相同键的项目,失败最多重试三次"""
|
|
items = json.dumps(item_list, ensure_ascii=False)
|
|
prompt = self.MERGE_PROMPT.format(
|
|
name=key,
|
|
items=items,
|
|
output_format=self.item_parser.get_format_instructions()
|
|
)
|
|
|
|
max_retries = 3
|
|
for attempt in range(1, max_retries + 1):
|
|
try:
|
|
response = self.llm.invoke(prompt, False)
|
|
parsed_output = self.item_parser.parse(response.content)
|
|
return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
|
|
except Exception as e:
|
|
if attempt == max_retries:
|
|
logging.warning(f"解析LLM合并结果失败: {e}")
|
|
return None
|
|
else:
|
|
import time
|
|
time.sleep(5*attempt)
|
|
|
|
def process_item(self, key_items_tuple):
|
|
"""处理单个键值对应的项目,用于线程池并行处理"""
|
|
key, item_list = key_items_tuple
|
|
try:
|
|
if len(item_list) == 1:
|
|
return item_list[0]
|
|
|
|
merged = self.merge_items_with_llm(key, item_list)
|
|
if merged:
|
|
return merged
|
|
else:
|
|
# 如果合并失败,返回第一个项目
|
|
return item_list[0]
|
|
except Exception as e:
|
|
logging.error(f"处理键 {key} 时出错: {e}", exc_info=True)
|
|
return item_list[0]
|
|
|
|
def deduplicate(self):
|
|
"""去重所有项目的入口方法"""
|
|
# 1. 读取JSON数据
|
|
all_items = self.load_json_data()
|
|
if not all_items:
|
|
return []
|
|
|
|
# 2. 按键字段聚合
|
|
key_to_items = self.group_items_by_key(all_items)
|
|
logging.info(f"共{len(key_to_items)}个唯一键")
|
|
|
|
# 3. 使用线程池并行处理
|
|
deduplicated_items = []
|
|
items_to_process = []
|
|
|
|
# 先处理只有一个项目的键(不需要合并)
|
|
for key, item_list in key_to_items.items():
|
|
if len(item_list) == 1:
|
|
deduplicated_items.append(item_list[0])
|
|
else:
|
|
items_to_process.append((key, item_list))
|
|
|
|
logging.info(f"共{len(deduplicated_items)}个单一项目,{len(items_to_process)}个需要合并的项目")
|
|
|
|
# 只对需要合并的项目使用线程池处理
|
|
if items_to_process:
|
|
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
|
|
# 使用tqdm显示进度
|
|
for result in tqdm(executor.map(self.process_item, items_to_process), total=len(items_to_process)):
|
|
deduplicated_items.append(result)
|
|
|
|
# 4. 保存去重结果
|
|
os.makedirs(os.path.dirname(self.OUTPUT_PATH), exist_ok=True)
|
|
with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
|
|
json.dump(deduplicated_items, f, ensure_ascii=False, indent=2)
|
|
logging.info(f"去重后结果已保存到: {self.OUTPUT_PATH}")
|
|
|
|
return deduplicated_items
|
|
|
|
|
|
def main():
|
|
"""主函数,解析命令行参数并执行去重"""
|
|
parser = argparse.ArgumentParser(description='对JSON文件进行去重')
|
|
input_path = 'data/wiki_extracted_nouns/技改检修计价通_nouns.json'
|
|
|
|
parser.add_argument('-i', '--input',default=input_path, help='输入JSON文件路径')
|
|
parser.add_argument('-o', '--output', help='输出JSON文件路径')
|
|
parser.add_argument('-k', '--key', default='name', help='用于去重的键字段名,默认为"name"')
|
|
parser.add_argument('-w', '--workers', type=int, default=30, help='线程池最大工作线程数,默认为2')
|
|
|
|
args = parser.parse_args()
|
|
|
|
deduplicator = JsonDeduplicator(
|
|
input_path=args.input,
|
|
output_path=args.output,
|
|
key_field=args.key,
|
|
max_workers=args.workers
|
|
)
|
|
deduplicator.deduplicate()
|
|
|
|
if __name__ == "__main__":
|
|
logging.getLogger('httpx').setLevel(logging.WARNING)
|
|
logging.getLogger('openai').setLevel(logging.WARNING)
|
|
main() |