Files
QueryRewrite/rag2_0/demo/merge_nouns_with_llm.py
T

201 lines
8.1 KiB
Python
Executable File

#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: merge_nouns_with_llm.py
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
"""
import os
import json
import glob
from concurrent.futures import ThreadPoolExecutor
from collections import defaultdict
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from rag2_0.intent_recognition.DataModels import Term
import logging
from langchain.output_parsers import PydanticOutputParser
from tqdm import tqdm
import time
# Populate os.environ from a .env file (python-dotenv) at import time, so the
# os.getenv lookups in TermMerger.__init__ (LLM_MODEL_NAME, OPENAI_API_KEY,
# OPENAI_API_BASE) can pick up values defined there.
load_dotenv()
class TermMerger:
    """Merge same-named technical terms collected from several ``*_nouns.json`` files.

    Terms are grouped by their stripped, upper-cased ``name``:

    - Groups with a single entry are kept as-is.
    - Groups with several entries are merged into one entry by an LLM.

    Afterwards, repeated names are dropped, and synonyms that are themselves
    term names are removed. The result is written to ``output_path`` as JSON.
    """

    def __init__(self, input_dir=None, output_path=None, max_workers=3):
        """Initialise the merger.

        Args:
            input_dir: Directory containing the ``*_nouns.json`` files.
            output_path: Path of the JSON file the merged result is written to.
            max_workers: Maximum number of worker threads used for LLM merges.
        """
        self.EXTRACTED_NOUNS_DIR = input_dir
        self.OUTPUT_PATH = output_path
        self.MAX_WORKERS = max_workers
        # Parses the LLM reply into a Term model. It also supplies the format
        # instructions that are embedded in the prompt.
        self.terms_parser = PydanticOutputParser(pydantic_object=Term)
        self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
        # LLM configuration is read from the environment. api_key and base_url
        # are only forwarded when they are set.
        model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
        api_key = os.getenv("OPENAI_API_KEY")
        base_url = os.getenv("OPENAI_API_BASE")
        llm_params = {"temperature": 0.3, "model": model_name}
        if api_key:
            llm_params["api_key"] = api_key
        if base_url:
            llm_params["base_url"] = base_url
        self.llm = OpenAiLLM(**llm_params)
        # Configure root logging.
        logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

    def load_all_terms(self):
        """Read every ``*_nouns.json`` file in the input directory.

        Files are read in sorted filename order, so the output is reproducible
        (``glob`` order depends on the filesystem). Each term's ``name`` is
        stripped and upper-cased, so grouping and de-duplication are
        case-insensitive. A file that cannot be opened or parsed, or that
        contains an entry missing a required key, is skipped with a warning.

        Returns:
            list[dict]: Terms with keys ``name``, ``synonymous`` and ``description``.

        Raises:
            ValueError: If no input directory was configured.
        """
        if self.EXTRACTED_NOUNS_DIR is None:
            raise ValueError("input_dir is not set")
        all_terms = []
        pattern = os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json')
        for file in sorted(glob.glob(pattern)):
            try:
                with open(file, 'r', encoding='utf-8') as f:
                    file_terms = json.load(f)
                new_terms = [
                    {
                        "name": term["name"].strip().upper(),
                        "synonymous": term["synonymous"],
                        "description": term["description"],
                    }
                    for term in file_terms
                ]
            except Exception as e:
                # Best effort: one bad file must not abort the whole merge.
                logging.warning(f"读取{file}失败: {e}")
                continue
            all_terms.extend(new_terms)
            logging.info(f"加载{file},共{len(new_terms)}条")
        # NOTE: optional loading of suffix_keywords.json, currently disabled.
        # suffix_keywords_path = os.path.join(os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)), 'data', 'nouns', 'suffix_keywords.json')
        # if os.path.exists(suffix_keywords_path):
        #     try:
        #         with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
        #             suffix_terms = json.load(f)
        #         suffix_terms = [{"name": term["name"].upper(), "synonymous": "", "description": ""} for term in suffix_terms]
        #         all_terms.extend(suffix_terms)
        #         logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}条")
        #     except Exception as e:
        #         logging.warning(f"读取{suffix_keywords_path}失败: {e}")
        return all_terms

    def group_terms_by_name(self, terms):
        """Group terms by their stripped ``name``.

        Terms whose ``name`` is missing, ``None`` or blank are dropped.

        Returns:
            defaultdict[str, list[dict]]: Maps each name to the terms sharing
            it, in input order.
        """
        name2terms = defaultdict(list)
        for term in terms:
            name = (term.get('name') or '').strip()
            if name:
                name2terms[name].append(term)
        return name2terms

    def merge_terms_with_llm(self, name, term_list, max_retries=3, retry_delay=10):
        """Ask the LLM to merge several entries of the same term into one.

        Args:
            name: Canonical (grouping) name of the term.
            term_list: Entries sharing ``name``.
            max_retries: Total number of LLM attempts.
            retry_delay: Base back-off in seconds. After failed attempt ``k``
                the method sleeps ``retry_delay * k`` seconds.

        Returns:
            dict | None: The merged term, or ``None`` if every attempt failed.
            The returned ``name`` is always ``name``, even if the LLM changes
            its spelling or case. This keeps the merged entry under the same
            key used for grouping and de-duplication.
        """
        items = json.dumps(term_list, ensure_ascii=False)
        prompt = self.MERGE_PROMPT.format(
            name=name,
            items=items,
            output_format=self.terms_parser.get_format_instructions(),
        )
        for attempt in range(1, max_retries + 1):
            try:
                response = self.llm.invoke(prompt, False)
                parsed_output = self.terms_parser.parse(response.content)
            except Exception as e:
                if attempt == max_retries:
                    logging.warning(f"解析LLM合并结果失败: {e}")
                    return None
                # Linear back-off before the next attempt.
                time.sleep(retry_delay * attempt)
                continue
            return {
                "name": name,
                "synonymous": parsed_output.synonymous,
                "description": parsed_output.description,
            }
        # Only reachable when max_retries < 1.
        return None

    def process_term(self, name_terms_tuple):
        """Merge one ``(name, term_list)`` group. Used as the thread-pool worker.

        Falls back to the group's first entry if the LLM merge fails or raises.
        """
        name, term_list = name_terms_tuple
        try:
            merged = self.merge_terms_with_llm(name, term_list)
        except Exception as e:
            logging.error(f"处理词条 {name} 时出错: {e}", exc_info=True)
            return term_list[0]
        return merged if merged else term_list[0]

    def merge(self):
        """Run the full pipeline: load, group, merge, de-duplicate and save.

        Returns:
            list[dict]: The merged terms that were written to ``OUTPUT_PATH``.
        """
        # 1. Load all terms.
        all_terms = self.load_all_terms()
        logging.info(f"共加载{len(all_terms)}条术语")
        # 2. Group them by name.
        name2terms = self.group_terms_by_name(all_terms)
        logging.info(f"共{len(name2terms)}个唯一名词")
        # 3. Single-entry groups need no merge. The rest go to the thread pool.
        merged_terms = []
        items_to_process = []
        for name, term_list in name2terms.items():
            if len(term_list) == 1:
                merged_terms.append(term_list[0])
            else:
                items_to_process.append((name, term_list))
        logging.info(f"共{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")
        if items_to_process:
            with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
                # executor.map yields results in input order; tqdm shows progress.
                for result in tqdm(executor.map(self.process_term, items_to_process), total=len(items_to_process)):
                    merged_terms.append(result)
        # 4. De-duplicate names and synonyms.
        merged_terms = self.deduplicate_synonymous_name(merged_terms)
        # 5. Save. A bare filename has an empty dirname, and os.makedirs("")
        #    would raise, so only create the directory when there is one.
        out_dir = os.path.dirname(self.OUTPUT_PATH)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
            json.dump(merged_terms, f, ensure_ascii=False, indent=2)
        logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")
        return merged_terms

    def deduplicate_synonymous_name(self, terms):
        """Drop repeated names, and synonyms that are themselves term names.

        The method does two things:

        1. Keeps only the first entry for each ``name``.
        2. Removes from each entry's ``synonymous`` list any value equal to
           some entry's ``name`` (including the entry's own name). Values are
           stripped and upper-cased before comparing, matching how
           ``load_all_terms`` normalises names.

        A missing or ``None`` ``synonymous`` is treated as empty. A single
        string is treated as a one-element list. Entries are updated in place.

        Returns:
            list[dict]: De-duplicated entries in first-seen order.
        """
        unique_names = set()
        unique_data = []
        for item in terms:
            if item["name"] in unique_names:
                continue
            unique_names.add(item["name"])
            unique_data.append(item)
        for item in unique_data:
            synonyms = item.get("synonymous") or []
            if isinstance(synonyms, str):
                synonyms = [synonyms]
            item["synonymous"] = [
                syn for syn in synonyms
                if not (isinstance(syn, str) and syn.strip().upper() in unique_names)
            ]
        return unique_data
def main():
    """Script entry point: build a TermMerger over the wiki-extracted nouns and run it.

    The input directory and the output file are both located relative to this
    script's directory. Up to 20 worker threads are used for the LLM merges.
    """
    script_dir = os.path.dirname(__file__)
    merger = TermMerger(
        input_dir=os.path.abspath(os.path.join(script_dir, '../../data/wiki_extracted_nouns')),
        output_path=os.path.join(script_dir, "..", "..", "data", "nouns", 'merged_nouns.json'),
        max_workers=20,
    )
    merger.merge()
if __name__ == "__main__":
    # Raise the httpx and openai loggers to WARNING so their lower-level
    # messages do not clutter the console output.
    for noisy_logger in ("httpx", "openai"):
        logging.getLogger(noisy_logger).setLevel(logging.WARNING)
    main()