Files
QueryRewrite/rag2_0/demo/deduplicate_nouns_json.py
T

187 lines
6.9 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: deduplicate_json.py
Description: 对指定JSON文件进行去重并重新保存
"""
import os
import json
import argparse
import logging
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
from rag2_0.intent_recognition.DataModels import Term
# 加载环境变量
load_dotenv()
class JsonDeduplicator:
"""JSON文件去重类"""
def __init__(self, input_path=None, output_path=None, key_field="name", max_workers=3):
"""初始化JSON去重器
Args:
input_path: 输入JSON文件路径
output_path: 去重后的输出文件路径
key_field: 用于去重的键字段名
max_workers: 线程池最大工作线程数
"""
self.INPUT_PATH = input_path
self.OUTPUT_PATH = output_path or input_path.replace('.json', '_deduplicated.json')
self.KEY_FIELD = key_field
self.MAX_WORKERS = max_workers
self.item_parser = PydanticOutputParser(pydantic_object=Term)
self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
# 配置LLM
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
llm_params = {"temperature": 0.3, "model": model_name}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
self.llm = OpenAiLLM(**llm_params)
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_json_data(self):
"""读取JSON文件"""
try:
with open(self.INPUT_PATH, 'r', encoding='utf-8') as f:
data = json.load(f)
logging.info(f"从{self.INPUT_PATH}加载了{len(data)}条记录")
return data
except Exception as e:
logging.error(f"读取{self.INPUT_PATH}失败: {e}", exc_info=True)
return []
def group_items_by_key(self, items):
"""按指定键字段聚合项目"""
key_to_items = defaultdict(list)
for item in items:
key = item.get(self.KEY_FIELD, '').strip()
if key:
key_to_items[key].append(item)
return key_to_items
def merge_items_with_llm(self, key, item_list):
"""调用LLM合并具有相同键的项目,失败最多重试三次"""
items = json.dumps(item_list, ensure_ascii=False)
prompt = self.MERGE_PROMPT.format(
name=key,
items=items,
output_format=self.item_parser.get_format_instructions()
)
max_retries = 3
for attempt in range(1, max_retries + 1):
try:
response = self.llm.invoke(prompt, False)
parsed_output = self.item_parser.parse(response.content)
return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
except Exception as e:
if attempt == max_retries:
logging.warning(f"解析LLM合并结果失败: {e}")
return None
else:
import time
time.sleep(5*attempt)
def process_item(self, key_items_tuple):
"""处理单个键值对应的项目,用于线程池并行处理"""
key, item_list = key_items_tuple
try:
if len(item_list) == 1:
return item_list[0]
merged = self.merge_items_with_llm(key, item_list)
if merged:
return merged
else:
# 如果合并失败,返回第一个项目
return item_list[0]
except Exception as e:
logging.error(f"处理键 {key} 时出错: {e}", exc_info=True)
return item_list[0]
def deduplicate(self):
"""去重所有项目的入口方法"""
# 1. 读取JSON数据
all_items = self.load_json_data()
if not all_items:
return []
# 2. 按键字段聚合
key_to_items = self.group_items_by_key(all_items)
logging.info(f"共{len(key_to_items)}个唯一键")
# 3. 使用线程池并行处理
deduplicated_items = []
items_to_process = []
# 先处理只有一个项目的键(不需要合并)
for key, item_list in key_to_items.items():
if len(item_list) == 1:
deduplicated_items.append(item_list[0])
else:
items_to_process.append((key, item_list))
logging.info(f"共{len(deduplicated_items)}个单一项目,{len(items_to_process)}个需要合并的项目")
# 只对需要合并的项目使用线程池处理
if items_to_process:
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
# 使用tqdm显示进度
for result in tqdm(executor.map(self.process_item, items_to_process), total=len(items_to_process)):
deduplicated_items.append(result)
# 4. 保存去重结果
os.makedirs(os.path.dirname(self.OUTPUT_PATH), exist_ok=True)
with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
json.dump(deduplicated_items, f, ensure_ascii=False, indent=2)
logging.info(f"去重后结果已保存到: {self.OUTPUT_PATH}")
return deduplicated_items
def main():
"""主函数,解析命令行参数并执行去重"""
parser = argparse.ArgumentParser(description='对JSON文件进行去重')
input_path = 'data/wiki_extracted_nouns/技改检修计价通_nouns.json'
parser.add_argument('-i', '--input',default=input_path, help='输入JSON文件路径')
parser.add_argument('-o', '--output', help='输出JSON文件路径')
parser.add_argument('-k', '--key', default='name', help='用于去重的键字段名,默认为"name"')
parser.add_argument('-w', '--workers', type=int, default=30, help='线程池最大工作线程数,默认为2')
args = parser.parse_args()
deduplicator = JsonDeduplicator(
input_path=args.input,
output_path=args.output,
key_field=args.key,
max_workers=args.workers
)
deduplicator.deduplicate()
if __name__ == "__main__":
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
main()