上传问题改写、意图识别模块代码
This commit is contained in:
@@ -0,0 +1,178 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
File: merge_nouns_with_llm.py
|
||||
Description: 合并多个nouns.json中的同名专业名词,利用LLM生成唯一合并结果
|
||||
"""
|
||||
import os
|
||||
import json
|
||||
import glob
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from collections import defaultdict
|
||||
from dotenv import load_dotenv
|
||||
from rag2_0.tool.ModelTool import OpenAiLLM
|
||||
from rag2_0.intent_recognition.DataModels import Term
|
||||
import logging
|
||||
from langchain.output_parsers import PydanticOutputParser
|
||||
from tqdm import tqdm
|
||||
import time
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
class TermMerger:
|
||||
"""专业名词合并类,用于合并多个数据源中的同名专业名词"""
|
||||
|
||||
def __init__(self, input_dir=None, output_path=None, max_workers=3):
|
||||
"""初始化名词合并器
|
||||
|
||||
Args:
|
||||
input_dir: 包含nouns.json文件的目录路径
|
||||
output_path: 合并结果的输出文件路径
|
||||
max_workers: 线程池最大工作线程数
|
||||
"""
|
||||
self.EXTRACTED_NOUNS_DIR = input_dir
|
||||
self.OUTPUT_PATH = output_path
|
||||
self.MAX_WORKERS = max_workers
|
||||
self.terms_parser = PydanticOutputParser(pydantic_object=Term)
|
||||
self.MERGE_PROMPT = '''
|
||||
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
|
||||
- 同义词(synonymous)去重合并
|
||||
- 描述(description)合并为更完整、简明的描述
|
||||
- 保持输出格式为:
|
||||
{output_format}
|
||||
原始条目:
|
||||
{items}
|
||||
'''
|
||||
# 配置LLM
|
||||
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
base_url = os.getenv("OPENAI_API_BASE")
|
||||
llm_params = {"temperature": 0.3, "model": model_name}
|
||||
if api_key:
|
||||
llm_params["api_key"] = api_key
|
||||
if base_url:
|
||||
llm_params["base_url"] = base_url
|
||||
self.llm = OpenAiLLM(**llm_params)
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
||||
|
||||
def load_all_terms(self):
|
||||
"""读取目录下所有nouns.json,返回所有Term列表"""
|
||||
all_terms = []
|
||||
for file in glob.glob(os.path.join(self.EXTRACTED_NOUNS_DIR, '*_nouns.json')):
|
||||
with open(file, 'r', encoding='utf-8') as f:
|
||||
try:
|
||||
file_terms = json.load(f)
|
||||
new_terms = [{"name": term["name"].upper(), "synonymous": term["synonymous"], "description": term["description"]} for term in file_terms]
|
||||
all_terms.extend(new_terms)
|
||||
logging.info(f"加载{file},共{len(new_terms)}条")
|
||||
except Exception as e:
|
||||
logging.warning(f"读取{file}失败: {e}")
|
||||
|
||||
# 加载suffix_keywords.json文件
|
||||
suffix_keywords_path = os.path.join(os.path.dirname(os.path.dirname(self.EXTRACTED_NOUNS_DIR)), 'data', 'nouns', 'suffix_keywords.json')
|
||||
if os.path.exists(suffix_keywords_path):
|
||||
try:
|
||||
with open(suffix_keywords_path, 'r', encoding='utf-8') as f:
|
||||
suffix_terms = json.load(f)
|
||||
suffix_terms = [{"name": term["name"].upper(), "synonymous": "", "description": ""} for term in suffix_terms]
|
||||
all_terms.extend(suffix_terms)
|
||||
logging.info(f"加载{suffix_keywords_path},共{len(suffix_terms)}条")
|
||||
except Exception as e:
|
||||
logging.warning(f"读取{suffix_keywords_path}失败: {e}")
|
||||
|
||||
return all_terms
|
||||
|
||||
def group_terms_by_name(self, terms):
|
||||
"""按name聚合Term"""
|
||||
name2terms = defaultdict(list)
|
||||
for term in terms:
|
||||
name = term.get('name', '').strip()
|
||||
if name:
|
||||
name2terms[name].append(term)
|
||||
return name2terms
|
||||
|
||||
def merge_terms_with_llm(self, name, term_list):
|
||||
"""调用LLM合并同名Term,失败最多重试三次"""
|
||||
items = json.dumps(term_list, ensure_ascii=False)
|
||||
prompt = self.MERGE_PROMPT.format(name=name, items=items, output_format=self.terms_parser.get_format_instructions())
|
||||
|
||||
max_retries = 3
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
response = self.llm.invoke(prompt, False)
|
||||
parsed_output = self.terms_parser.parse(response.content)
|
||||
return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
|
||||
except Exception as e:
|
||||
if attempt == max_retries:
|
||||
logging.warning(f"解析LLM合并结果失败: {e}")
|
||||
return None
|
||||
else:
|
||||
time.sleep(10*attempt)
|
||||
|
||||
def process_term(self, name_terms_tuple):
|
||||
"""处理单个词条,用于线程池并行处理"""
|
||||
name, term_list = name_terms_tuple
|
||||
try:
|
||||
merged = self.merge_terms_with_llm(name, term_list)
|
||||
if merged:
|
||||
return merged
|
||||
else:
|
||||
return term_list[0]
|
||||
except Exception as e:
|
||||
logging.error(f"处理词条 {name} 时出错: {e}")
|
||||
return term_list[0]
|
||||
|
||||
def merge(self):
|
||||
"""合并所有词条的入口方法"""
|
||||
# 1. 读取所有术语
|
||||
all_terms = self.load_all_terms()
|
||||
logging.info(f"共加载{len(all_terms)}条术语")
|
||||
|
||||
# 2. 按名称聚合
|
||||
name2terms = self.group_terms_by_name(all_terms)
|
||||
logging.info(f"共{len(name2terms)}个唯一名词")
|
||||
|
||||
# 3. 使用线程池并行处理
|
||||
merged_terms = []
|
||||
items_to_process = []
|
||||
|
||||
# 先处理只有一个条目的词条(不需要合并)
|
||||
for name, term_list in name2terms.items():
|
||||
if len(term_list) == 1:
|
||||
merged_terms.append(term_list[0])
|
||||
else:
|
||||
items_to_process.append((name, term_list))
|
||||
|
||||
logging.info(f"共{len(merged_terms)}个单一条目,{len(items_to_process)}个需要合并的条目")
|
||||
|
||||
# 只对需要合并的词条使用线程池处理
|
||||
if items_to_process:
|
||||
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
|
||||
# 使用tqdm显示进度
|
||||
for result in tqdm(executor.map(self.process_term, items_to_process), total=len(items_to_process)):
|
||||
merged_terms.append(result)
|
||||
|
||||
# 4. 保存合并结果
|
||||
os.makedirs(os.path.dirname(self.OUTPUT_PATH), exist_ok=True)
|
||||
with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
|
||||
json.dump(merged_terms, f, ensure_ascii=False, indent=2)
|
||||
logging.info(f"合并后结果已保存到: {self.OUTPUT_PATH}")
|
||||
|
||||
return merged_terms
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,创建TermMerger实例并执行合并"""
|
||||
|
||||
cur_path = os.path.dirname(__file__)
|
||||
input_dir = os.path.abspath(os.path.join(cur_path, '../../data/wiki_extracted_nouns'))
|
||||
output_path = os.path.join(cur_path, "..", "..", "data", "nouns", 'merged_nouns.json')
|
||||
merger = TermMerger(input_dir=input_dir, output_path=output_path, max_workers=2)
|
||||
merger.merge()
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.getLogger('httpx').setLevel(logging.WARNING)
|
||||
logging.getLogger('openai').setLevel(logging.WARNING)
|
||||
main()
|
||||
Reference in New Issue
Block a user