删除重复的名词,指定问题改写时的同义词替换
This commit is contained in:
@@ -0,0 +1,187 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: deduplicate_json.py
|
||||
Description: 对指定JSON文件进行去重并重新保存
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from tqdm import tqdm
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from pydantic import BaseModel, Field
|
||||
from rag2_0.intent_recognition.DataModels import Term
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class JsonDeduplicator:
    """Deduplicate the records of a JSON file, merging duplicates with an LLM.

    Records are grouped by the whitespace-stripped value of ``key_field``.
    A group holding a single record is kept unchanged; a group holding
    several records is sent to the LLM to be merged into one record, and
    falls back to the group's first record when the merge fails.
    """

    # Suffix inserted before the file extension when deriving the output path.
    OUTPUT_SUFFIX = "_deduplicated"

    def __init__(self, input_path=None, output_path=None, key_field="name", max_workers=3):
        """Set up paths, the merge prompt, the output parser and the LLM client.

        Args:
            input_path: Path of the JSON file to read.
            output_path: Path of the deduplicated file to write. When omitted
                it is derived from ``input_path`` (see
                ``_default_output_path``).
            key_field: Name of the record field used to detect duplicates.
            max_workers: Maximum number of worker threads used for merging.
        """
        self.INPUT_PATH = input_path
        self.OUTPUT_PATH = output_path or self._default_output_path(input_path)
        self.KEY_FIELD = key_field
        self.MAX_WORKERS = max_workers
        self.item_parser = PydanticOutputParser(pydantic_object=Term)
        # Built from explicit lines so the prompt text does not depend on the
        # indentation of this method.
        self.MERGE_PROMPT = (
            '\n'
            '请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:\n'
            '- 同义词(synonymous)去重合并\n'
            '- 描述(description)合并为更完整、简明的描述\n'
            '- 保持输出格式为:\n'
            '{output_format}\n'
            '原始条目:\n'
            '{items}\n'
        )

        # LLM configuration comes from the environment; the API key and base
        # URL are only passed along when they are actually set.
        llm_params = {
            "temperature": 0.3,
            "model": os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo"),
        }
        api_key = os.getenv("OPENAI_API_KEY")
        if api_key:
            llm_params["api_key"] = api_key
        base_url = os.getenv("OPENAI_API_BASE")
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Logging setup.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    @staticmethod
    def _default_output_path(input_path):
        """Derive the output path from ``input_path``.

        The suffix is inserted before the file extension, so ``a/b.json``
        becomes ``a/b_deduplicated.json``. A path without an extension gets
        ``.json`` appended. Unlike a plain ``str.replace('.json', ...)`` this
        never touches directory names and never returns the input path
        itself, so the input file cannot be overwritten by accident.

        Returns:
            The derived path, or ``None`` when ``input_path`` is empty/None.
        """
        if not input_path:
            return None
        root, ext = os.path.splitext(input_path)
        return f"{root}{JsonDeduplicator.OUTPUT_SUFFIX}{ext or '.json'}"

    def load_json_data(self):
        """Read the input JSON file.

        Returns:
            The parsed list of records, or an empty list when the file cannot
            be read or parsed, or when its top-level value is not a list.
        """
        try:
            with open(self.INPUT_PATH, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (OSError, ValueError, TypeError) as e:
            # OSError: missing/unreadable file; ValueError: invalid JSON or
            # bad encoding; TypeError: INPUT_PATH is None.
            logging.error(f"读取{self.INPUT_PATH}失败: {e}")
            return []

        if not isinstance(data, list):
            # Grouping iterates over records, so anything but a list is unusable.
            logging.error(f"读取{self.INPUT_PATH}失败: 顶层JSON必须是数组,实际为{type(data).__name__}")
            return []

        logging.info(f"从{self.INPUT_PATH}加载了{len(data)}条记录")
        return data

    def group_items_by_key(self, items):
        """Group records by the stripped value of ``self.KEY_FIELD``.

        Non-string key values are converted with ``str()`` before stripping.
        Records that are not dicts, or whose key is missing, ``None`` or
        blank, are skipped; the number of skipped records is logged.

        Args:
            items: Iterable of records (dicts).

        Returns:
            A ``defaultdict(list)`` mapping each key to its records, in
            first-seen order.
        """
        key_to_items = defaultdict(list)
        skipped = 0
        for item in items:
            raw_key = item.get(self.KEY_FIELD) if isinstance(item, dict) else None
            key = "" if raw_key is None else str(raw_key).strip()
            if not key:
                skipped += 1
                continue
            key_to_items[key].append(item)

        if skipped:
            logging.warning(f"跳过{skipped}条缺少有效键字段 {self.KEY_FIELD} 的记录")
        return key_to_items

    def merge_items_with_llm(self, key, item_list):
        """Ask the LLM to merge records sharing ``key``; try at most three times.

        Args:
            key: The shared key value, inserted into the prompt as the name.
            item_list: The records to merge; serialized to JSON for the prompt.

        Returns:
            A dict with ``name``, ``synonymous`` and ``description`` taken
            from the parsed LLM answer, or ``None`` when every attempt failed.
        """
        # Function-scope import: ``time`` is not among the file's top-level imports.
        import time

        prompt = self.MERGE_PROMPT.format(
            name=key,
            items=json.dumps(item_list, ensure_ascii=False),
            output_format=self.item_parser.get_format_instructions(),
        )

        max_retries = 3
        for attempt in range(1, max_retries + 1):
            try:
                # NOTE(review): the meaning of the second positional argument
                # is defined by OpenAiLLM.invoke, which is not visible here.
                response = self.llm.invoke(prompt, False)
                parsed_output = self.item_parser.parse(response.content)
                return {
                    "name": parsed_output.name,
                    "synonymous": parsed_output.synonymous,
                    "description": parsed_output.description,
                }
            except Exception as e:
                # Any failure (call or parse) counts as a failed attempt.
                if attempt == max_retries:
                    logging.warning(f"解析LLM合并结果失败: {e}")
                    return None
                # Wait longer after each failure: 5s, then 10s.
                time.sleep(5 * attempt)
        return None

    def process_item(self, key_items_tuple):
        """Resolve one ``(key, records)`` group to a single record.

        Used as the thread-pool work function.

        Args:
            key_items_tuple: A ``(key, item_list)`` pair.

        Returns:
            The only record of a single-record group, the merged record when
            the LLM merge succeeds, otherwise the group's first record.
        """
        key, item_list = key_items_tuple
        try:
            if len(item_list) == 1:
                return item_list[0]
            # Fall back to the first record when the merge yields nothing.
            return self.merge_items_with_llm(key, item_list) or item_list[0]
        except Exception as e:
            logging.error(f"处理键 {key} 时出错: {e}")
            return item_list[0]

    def deduplicate(self):
        """Run the whole pipeline: load, group, merge and save.

        Returns:
            The deduplicated records: single-record groups first, followed by
            the merged groups. An empty list when nothing could be loaded, in
            which case no output file is written.
        """
        # 1. Load the JSON data.
        all_items = self.load_json_data()
        if not all_items:
            return []

        # 2. Group by the key field.
        key_to_items = self.group_items_by_key(all_items)
        logging.info(f"共{len(key_to_items)}个唯一键")

        # 3. Single-record groups need no merge; only real duplicates are
        #    handed to the thread pool.
        deduplicated_items = []
        items_to_process = []
        for key, item_list in key_to_items.items():
            if len(item_list) == 1:
                deduplicated_items.append(item_list[0])
            else:
                items_to_process.append((key, item_list))

        logging.info(f"共{len(deduplicated_items)}个单一项目,{len(items_to_process)}个需要合并的项目")

        if items_to_process:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                # executor.map yields results in submission order; tqdm shows progress.
                merged_results = executor.map(self.process_item, items_to_process)
                deduplicated_items.extend(tqdm(merged_results, total=len(items_to_process)))

        # 4. Save the result. os.makedirs('') raises, so the directory is only
        #    created when the output path actually has a directory component.
        output_dir = os.path.dirname(self.OUTPUT_PATH)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
            json.dump(deduplicated_items, f, ensure_ascii=False, indent=2)
        logging.info(f"去重后结果已保存到: {self.OUTPUT_PATH}")

        return deduplicated_items
|
||||
|
||||
|
||||
def main(argv=None):
    """Parse the command-line arguments and run the deduplication.

    Args:
        argv: Optional list of argument strings. When ``None`` (the default)
            argparse reads ``sys.argv[1:]``, which is what a plain ``main()``
            call does.
    """
    default_input = 'data/wiki_extracted_nouns/技改检修计价通_nouns.json'

    parser = argparse.ArgumentParser(description='对JSON文件进行去重')
    parser.add_argument('-i', '--input', default=default_input, help='输入JSON文件路径')
    parser.add_argument('-o', '--output', help='输出JSON文件路径')
    parser.add_argument('-k', '--key', default='name', help='用于去重的键字段名,默认为"name"')
    # %(default)s keeps the help text in sync with the real default; the
    # previous text claimed a default of 2 while the value was 30.
    parser.add_argument('-w', '--workers', type=int, default=30,
                        help='线程池最大工作线程数,默认为%(default)s')

    args = parser.parse_args(argv)
    if args.workers < 1:
        # ThreadPoolExecutor rejects a non-positive max_workers; fail early
        # with a usage error instead of after the input has been loaded.
        parser.error('线程池最大工作线程数必须为正整数')

    deduplicator = JsonDeduplicator(
        input_path=args.input,
        output_path=args.output,
        key_field=args.key,
        max_workers=args.workers,
    )
    deduplicator.deduplicate()
|
||||
|
||||
if __name__ == "__main__":
    # Raise the httpx and openai loggers to WARNING before running the CLI.
    for logger_name in ('httpx', 'openai'):
        logging.getLogger(logger_name).setLevel(logging.WARNING)
    main()
|
||||
Reference in New Issue
Block a user