删除重复的名词,指定问题改写时的同义词替换

This commit is contained in:
2025-05-27 15:19:48 +08:00
parent 3dfa8c8a8a
commit 670de2f758
13 changed files with 17093 additions and 51385 deletions
+6497 -6569
View File
File diff suppressed because it is too large Load Diff
Binary file not shown.
Binary file not shown.
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
+187
View File
@@ -0,0 +1,187 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
File: deduplicate_json.py
Description: 对指定JSON文件进行去重并重新保存
"""
import os
import json
import argparse
import logging
from collections import defaultdict
from tqdm import tqdm
from concurrent.futures import ThreadPoolExecutor
from dotenv import load_dotenv
from rag2_0.tool.ModelTool import OpenAiLLM
from langchain.output_parsers import PydanticOutputParser
from pydantic import BaseModel, Field
from rag2_0.intent_recognition.DataModels import Term
# 加载环境变量
load_dotenv()
class JsonDeduplicator:
"""JSON文件去重类"""
def __init__(self, input_path=None, output_path=None, key_field="name", max_workers=3):
"""初始化JSON去重器
Args:
input_path: 输入JSON文件路径
output_path: 去重后的输出文件路径
key_field: 用于去重的键字段名
max_workers: 线程池最大工作线程数
"""
self.INPUT_PATH = input_path
self.OUTPUT_PATH = output_path or input_path.replace('.json', '_deduplicated.json')
self.KEY_FIELD = key_field
self.MAX_WORKERS = max_workers
self.item_parser = PydanticOutputParser(pydantic_object=Term)
self.MERGE_PROMPT = '''
请将以下多个描述相同名词"{name}"的条目合并为一个,合并时请:
- 同义词(synonymous)去重合并
- 描述(description)合并为更完整、简明的描述
- 保持输出格式为:
{output_format}
原始条目:
{items}
'''
# 配置LLM
model_name = os.getenv("LLM_MODEL_NAME", "gpt-3.5-turbo")
api_key = os.getenv("OPENAI_API_KEY")
base_url = os.getenv("OPENAI_API_BASE")
llm_params = {"temperature": 0.3, "model": model_name}
if api_key:
llm_params["api_key"] = api_key
if base_url:
llm_params["base_url"] = base_url
self.llm = OpenAiLLM(**llm_params)
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def load_json_data(self):
"""读取JSON文件"""
try:
with open(self.INPUT_PATH, 'r', encoding='utf-8') as f:
data = json.load(f)
logging.info(f"{self.INPUT_PATH}加载了{len(data)}条记录")
return data
except Exception as e:
logging.error(f"读取{self.INPUT_PATH}失败: {e}")
return []
def group_items_by_key(self, items):
"""按指定键字段聚合项目"""
key_to_items = defaultdict(list)
for item in items:
key = item.get(self.KEY_FIELD, '').strip()
if key:
key_to_items[key].append(item)
return key_to_items
def merge_items_with_llm(self, key, item_list):
"""调用LLM合并具有相同键的项目,失败最多重试三次"""
items = json.dumps(item_list, ensure_ascii=False)
prompt = self.MERGE_PROMPT.format(
name=key,
items=items,
output_format=self.item_parser.get_format_instructions()
)
max_retries = 3
for attempt in range(1, max_retries + 1):
try:
response = self.llm.invoke(prompt, False)
parsed_output = self.item_parser.parse(response.content)
return {"name": parsed_output.name, "synonymous": parsed_output.synonymous, "description": parsed_output.description}
except Exception as e:
if attempt == max_retries:
logging.warning(f"解析LLM合并结果失败: {e}")
return None
else:
import time
time.sleep(5*attempt)
def process_item(self, key_items_tuple):
"""处理单个键值对应的项目,用于线程池并行处理"""
key, item_list = key_items_tuple
try:
if len(item_list) == 1:
return item_list[0]
merged = self.merge_items_with_llm(key, item_list)
if merged:
return merged
else:
# 如果合并失败,返回第一个项目
return item_list[0]
except Exception as e:
logging.error(f"处理键 {key} 时出错: {e}")
return item_list[0]
def deduplicate(self):
"""去重所有项目的入口方法"""
# 1. 读取JSON数据
all_items = self.load_json_data()
if not all_items:
return []
# 2. 按键字段聚合
key_to_items = self.group_items_by_key(all_items)
logging.info(f"{len(key_to_items)}个唯一键")
# 3. 使用线程池并行处理
deduplicated_items = []
items_to_process = []
# 先处理只有一个项目的键(不需要合并)
for key, item_list in key_to_items.items():
if len(item_list) == 1:
deduplicated_items.append(item_list[0])
else:
items_to_process.append((key, item_list))
logging.info(f"{len(deduplicated_items)}个单一项目,{len(items_to_process)}个需要合并的项目")
# 只对需要合并的项目使用线程池处理
if items_to_process:
with ThreadPoolExecutor(max_workers=self.MAX_WORKERS) as executor:
# 使用tqdm显示进度
for result in tqdm(executor.map(self.process_item, items_to_process), total=len(items_to_process)):
deduplicated_items.append(result)
# 4. 保存去重结果
os.makedirs(os.path.dirname(self.OUTPUT_PATH), exist_ok=True)
with open(self.OUTPUT_PATH, 'w', encoding='utf-8') as f:
json.dump(deduplicated_items, f, ensure_ascii=False, indent=2)
logging.info(f"去重后结果已保存到: {self.OUTPUT_PATH}")
return deduplicated_items
def main():
"""主函数,解析命令行参数并执行去重"""
parser = argparse.ArgumentParser(description='对JSON文件进行去重')
input_path = 'data/wiki_extracted_nouns/技改检修计价通_nouns.json'
parser.add_argument('-i', '--input',default=input_path, help='输入JSON文件路径')
parser.add_argument('-o', '--output', help='输出JSON文件路径')
parser.add_argument('-k', '--key', default='name', help='用于去重的键字段名,默认为"name"')
parser.add_argument('-w', '--workers', type=int, default=30, help='线程池最大工作线程数,默认为2')
args = parser.parse_args()
deduplicator = JsonDeduplicator(
input_path=args.input,
output_path=args.output,
key_field=args.key,
max_workers=args.workers
)
deduplicator.deduplicate()
if __name__ == "__main__":
logging.getLogger('httpx').setLevel(logging.WARNING)
logging.getLogger('openai').setLevel(logging.WARNING)
main()
+1 -1
View File
@@ -169,7 +169,7 @@ def main():
cur_path = os.path.dirname(__file__)
input_dir = os.path.abspath(os.path.join(cur_path, '../../data/wiki_extracted_nouns'))
output_path = os.path.join(cur_path, "..", "..", "data", "nouns", 'merged_nouns.json')
merger = TermMerger(input_dir=input_dir, output_path=output_path, max_workers=2)
merger = TermMerger(input_dir=input_dir, output_path=output_path, max_workers=20)
merger.merge()
if __name__ == "__main__":
-1
View File
@@ -41,7 +41,6 @@ class DifyComparisonTester:
Returns:
dict: 包含问题和两个流程回答的字典
"""
q="qwqwwq"
def get_old_answer():
try:
return self.old_chat.create_chat_message(inputs={}, query=q, user="AutoTestDifyChat").json()
@@ -95,6 +95,7 @@ query_rewrite_prompt = """
- 采用【术语标记】规范标注关键概念
- 构建主谓宾明确的问题句式
- 保持原问题时态与语态特征
- 执行同义词替换:将synonymous中的同义词替换为对应name字段的标准术语
# 输出规范
{output_format}
+10
View File
@@ -29,6 +29,16 @@ API_KEY_LIST=[
"sk-pdyymhshpzmdduwxsezthnrgarnnhgzvmiflbpisfzxkiayt",
"sk-qhwoorywmejumyudfxbrkegxtqifsbgcdkmpjckezepgyqnz",
"sk-cpoctrgcnstaybeyuieuwjdgeakudhqdnnwdjavjudcbvvem",
"sk-wqdpapdkisovziexgcyxvumpwzbjnhqbxvcqcspzctjhyhjk",
"sk-bbntrnifrtdzhhgrtlrhvwbnaysuszviemshdakxonnnymnb",
"sk-vmpnwjxersrwybmfhfxgsvbmhsmpjldxseiyxovnysrlbuzi",
"sk-nscsxwfqigkfpfqfzebkmaickxjzbhtfwywdppmmobrrbfnw",
"sk-irbxuakhntsrusrympiubkkjbkabbfbdgpstqnxbztzdtxdq",
"sk-hcfojzczbgwgcuhzxkicxqrhadurtakwbawiesyxyvksmcoz",
"sk-wiyosqgyutjypgzibveiwkgqwfkfsnonrmvjfbvrbkoicciv",
"sk-ocglenyvxkkvzupzumoypnyndjpjqhivyqpedusunboglspz",
"sk-dtbawdwajkhdctrukundbkqwswzfzihqbebfuvqnfnounbuc",
"sk-zqiyiqtbwqgyeenkvppymfbkspriolwbnxnjakugzxyvcuql",
]
class APIKeyManager: