178 lines
7.2 KiB
Python
178 lines
7.2 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: merge_nouns_with_llm.py
|
|
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
|
|
"""
|
|
import os
|
|
import json
|
|
import glob
|
|
from concurrent.futures import ThreadPoolExecutor
|
|
from collections import defaultdict
|
|
from dotenv import load_dotenv
|
|
from rag2_0.tool.ModelTool import OpenAiLLM
|
|
from rag2_0.intent_recognition.DataModels import Term
|
|
import logging
|
|
from langchain.output_parsers import PydanticOutputParser
|
|
from tqdm import tqdm
|
|
import time
|
|
# Load environment variables (the LLM settings read below via os.getenv)
|
|
load_dotenv()
|
|
|
|
class TermMerger:
    """Merge same-named technical terms from several data sources.

    Terms are read from ``*_nouns.json`` files (plus an optional
    ``suffix_keywords.json``), grouped by their normalized name, and every
    group that holds more than one entry is sent to an LLM which returns a
    single merged entry. The result is written to a JSON file.
    """

    def __init__(self, input_dir=None, output_path=None, max_workers=3):
        """Store the paths, build the merge prompt/parser and create the LLM client.

        Args:
            input_dir: Directory that contains the ``*_nouns.json`` files.
            output_path: Path of the JSON file the merged result is written to.
            max_workers: Maximum number of worker threads used for LLM calls.
        """
        self.EXTRACTED_NOUNS_DIR = input_dir
        self.OUTPUT_PATH = output_path
        self.MAX_WORKERS = max_workers
        self.terms_parser = PydanticOutputParser(pydantic_object=Term)
        # Runtime prompt text: sent to the LLM as-is, so it is kept verbatim.
        self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
        # Configure the LLM from environment variables; the API key and base
        # URL are only passed through when they are actually set.
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        llm_params = {"temperature": 0.3, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)

        # Configure logging (a no-op if the root logger already has handlers).
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    def load_all_terms(self):
        """Read every ``*_nouns.json`` in the input directory and return all terms.

        Each term is reduced to the keys ``name`` (stripped and upper-cased),
        ``synonymous`` and ``description``. If ``data/nouns/suffix_keywords.json``
        exists two directory levels above the input directory, its names are
        appended as terms with empty ``synonymous`` and ``description``.

        A file that cannot be opened, parsed or converted is logged and
        skipped as a whole; it never aborts the run.

        Returns:
            A list of term dicts, in sorted file order followed by the suffix
            keywords.
        """
        all_terms = []
        pattern = os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json')
        # glob.glob returns matches in arbitrary order; sorting keeps the output
        # order (and the "first entry" fallback in process_term) reproducible.
        for path in sorted(glob.glob(pattern)):
            try:
                with open(path, 'r', encoding='utf-8') as f:
                    file_terms = json.load(f)
                # Strip as well as upper-case so the stored name always matches
                # the key used later by group_terms_by_name.
                new_terms = [
                    {
                        "name": term["name"].strip().upper(),
                        "synonymous": term["synonymous"],
                        "description": term["description"],
                    }
                    for term in file_terms
                ]
            except Exception as e:
                logging.warning(f"读取{path}失败: {e}")
                continue
            all_terms.extend(new_terms)
            logging.info(f"加载{path},共{len(new_terms)}条")

        # Load the optional suffix_keywords.json file.
        suffix_keywords_path = os.path.join(
            os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)),
            'data', 'nouns', 'suffix_keywords.json',
        )
        if os.path.exists(suffix_keywords_path):
            try:
                with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
                    suffix_terms = json.load(f)
                suffix_terms = [
                    {"name": term["name"].strip().upper(), "synonymous": "", "description": ""}
                    for term in suffix_terms
                ]
                all_terms.extend(suffix_terms)
                logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}条")
            except Exception as e:
                logging.warning(f"读取{suffix_keywords_path}失败: {e}")

        return all_terms

    def group_terms_by_name(self, terms):
        """Group term dicts by their stripped ``name``.

        Terms whose name is missing, ``None`` or blank are dropped.

        Args:
            terms: Iterable of term dicts.

        Returns:
            A ``defaultdict(list)`` mapping each name to its terms, in first-seen
            order.
        """
        name2terms = defaultdict(list)
        for term in terms:
            # `or ''` guards against an explicit `"name": null` in the data.
            name = (term.get('name') or '').strip()
            if name:
                name2terms[name].append(term)
        return name2terms

    def merge_terms_with_llm(self, name, term_list, max_retries=3, retry_delay=10):
        """Ask the LLM to merge all entries of one name into a single entry.

        Args:
            name: The shared (normalized) name of the entries.
            term_list: The term dicts to merge.
            max_retries: Total number of attempts before giving up.
            retry_delay: Base delay in seconds; attempt ``k`` is followed by a
                sleep of ``retry_delay * k`` seconds before the next attempt.

        Returns:
            The merged term dict, or ``None`` when every attempt failed (the
            last error is logged as a warning).
        """
        items = json.dumps(term_list, ensure_ascii=False)
        prompt = self.MERGE_PROMPT.format(
            name=name,
            items=items,
            output_format=self.terms_parser.get_format_instructions(),
        )

        for attempt in range(1, max_retries + 1):
            try:
                # NOTE(review): the meaning of the second positional argument is
                # defined by OpenAiLLM.invoke, which is not visible here.
                response = self.llm.invoke(prompt, False)
                parsed_output = self.terms_parser.parse(response.content)
            except Exception as e:
                if attempt == max_retries:
                    logging.warning(f"解析LLM合并结果失败: {e}")
                    return None
                time.sleep(retry_delay * attempt)
            else:
                return {
                    # Keep the group key as the name: the LLM may alter case or
                    # spelling, which would break the one-entry-per-name result.
                    "name": name,
                    "synonymous": parsed_output.synonymous,
                    "description": parsed_output.description,
                }
        # Only reached when max_retries < 1 (no attempt was made).
        return None

    def process_term(self, name_terms_tuple):
        """Merge one ``(name, term_list)`` group; used as the thread-pool task.

        Falls back to the first entry of the group when the LLM merge fails
        or raises, so a group is never lost.
        """
        name, term_list = name_terms_tuple
        try:
            merged = self.merge_terms_with_llm(name, term_list)
        except Exception as e:
            logging.error(f"处理词条 {name} 时出错: {e}")
            return term_list[0]
        return merged or term_list[0]

    def merge(self):
        """Run the whole pipeline: load, group, merge duplicates, save.

        Returns:
            The list of merged term dicts that was written to ``OUTPUT_PATH``:
            single-entry names first, then the merged groups.
        """
        # 1. Load all terms.
        all_terms = self.load_all_terms()
        logging.info(f"共加载{len(all_terms)}条术语")

        # 2. Group by name.
        name2terms = self.group_terms_by_name(all_terms)
        logging.info(f"共{len(name2terms)}个唯一名词")

        # 3. Names with a single entry need no merging; only the rest go to the LLM.
        merged_terms = []
        items_to_process = []
        for name, term_list in name2terms.items():
            if len(term_list) == 1:
                merged_terms.append(term_list[0])
            else:
                items_to_process.append((name, term_list))

        logging.info(f"共{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")

        if items_to_process:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                # executor.map yields results in input order; tqdm shows progress.
                results = executor.map(self.process_term, items_to_process)
                merged_terms.extend(tqdm(results, total=len(items_to_process)))

        # 4. Save the merged result.
        output_dir = os.path.dirname(self.OUTPUT_PATH)
        # dirname is '' for a bare file name, and os.makedirs('') raises.
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)
        with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
            json.dump(merged_terms, f, ensure_ascii=False, indent=2)
        logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")

        return merged_terms
|
|
|
|
|
|
def main():
    """Entry point: build a TermMerger for the default data paths and run it.

    Both paths are resolved relative to this file: the extracted nouns are
    read from ``../../data/wiki_extracted_nouns`` and the merged result is
    written to ``../../data/nouns/merged_nouns.json``.
    """
    here = os.path.dirname(__file__)
    nouns_dir = os.path.abspath(os.path.join(here, '../../data/wiki_extracted_nouns'))
    merged_file = os.path.join(here, "..", "..", "data", "nouns", 'merged_nouns.json')
    TermMerger(input_dir=nouns_dir, output_path=merged_file, max_workers=2).merge()
|
|
|
|
if __name__ == "__main__":
    # Raise the 'httpx' and 'openai' loggers to WARNING before running the merge.
    for noisy_logger in ('httpx', 'openai'):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
    main()