321 lines
11 KiB
Python
321 lines
11 KiB
Python
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
File: ProfessionalNounVector.py
|
|
Date: 2025-05-15
|
|
Author: oyyz
|
|
Description: 专业名词向量化和检索的核心逻辑
|
|
"""
|
|
|
|
import os
|
|
import json
|
|
import shutil
|
|
from typing import List, Dict, Any, Tuple, Optional
|
|
from langchain.embeddings.base import Embeddings
|
|
from langchain_community.vectorstores import FAISS
|
|
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
|
|
import logging
|
|
|
|
def get_embedding_model(api_key: str = None) -> Embeddings:
|
|
"""
|
|
获取嵌入模型
|
|
|
|
Args:
|
|
api_key: API密钥,如果为None则从环境变量获取
|
|
|
|
Returns:
|
|
嵌入模型实例
|
|
"""
|
|
if not api_key:
|
|
api_key = os.getenv("SILICONFLOW_API_KEY", "sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr")
|
|
return SiliconFlowEmbeddings(api_key=api_key)
|
|
|
|
|
|
class ProfessionalNounVectorizer:
|
|
"""专业名词向量化和保存类"""
|
|
|
|
def __init__(self,
|
|
embedding_model: Optional[Embeddings] = None,
|
|
api_key: str = None,
|
|
output_dir: str = None):
|
|
"""
|
|
初始化向量化器
|
|
|
|
Args:
|
|
embedding_model: 嵌入模型,如果为None则使用默认模型
|
|
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
|
|
|
|
output_dir: 索引输出目录,默认为None使用默认路径
|
|
"""
|
|
# 设置嵌入模型
|
|
if embedding_model:
|
|
self.embedding_model = embedding_model
|
|
else:
|
|
self.embedding_model = get_embedding_model(api_key)
|
|
|
|
|
|
# 设置输出目录
|
|
self.output_dir = output_dir
|
|
if self.output_dir is None:
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
self.output_dir = os.path.join(current_dir, "..", "..", "data", "nouns")
|
|
|
|
# 设置索引路径
|
|
self.index_path = os.path.join(self.output_dir, "professional_nouns_index")
|
|
|
|
def _loadfile(self, file_paths: List[str]) -> List[Dict[str, Any]]:
|
|
"""
|
|
加载多个专业术语JSON文件并合并
|
|
|
|
Args:
|
|
file_paths: JSON文件路径列表
|
|
|
|
Returns:
|
|
合并后的术语列表
|
|
"""
|
|
merged_terms = []
|
|
|
|
try:
|
|
for file_path in file_paths:
|
|
if not os.path.exists(file_path):
|
|
logging.warning(f"文件不存在: {file_path}")
|
|
continue
|
|
|
|
with open(file_path, "r", encoding="utf-8") as f:
|
|
terms_data = json.load(f)
|
|
|
|
if isinstance(terms_data, list):
|
|
merged_terms.extend(terms_data)
|
|
logging.info(f"从 {file_path} 加载了 {len(terms_data)} 条专业名词")
|
|
else:
|
|
logging.warning(f"文件格式错误: {file_path},应为JSON数组")
|
|
|
|
logging.info(f"总共加载了 {len(merged_terms)} 条专业名词")
|
|
return merged_terms
|
|
except Exception as e:
|
|
logging.error(f"加载多个文件失败: {e}")
|
|
return []
|
|
|
|
def vectorize_files_and_save(self, file_paths: List[str]) -> bool:
|
|
"""
|
|
处理多个文件:加载多个术语文件、创建索引并保存
|
|
|
|
Args:
|
|
file_paths: JSON文件路径列表
|
|
|
|
Returns:
|
|
处理成功返回True,否则返回False
|
|
"""
|
|
try:
|
|
# 加载多个文件的术语
|
|
terms = self._loadfile(file_paths)
|
|
|
|
if not terms:
|
|
logging.warning("未找到术语数据,退出处理")
|
|
return False
|
|
|
|
# 根据名称去重
|
|
unique_terms = {}
|
|
for term in terms:
|
|
name = term.get("name", "")
|
|
if name and name not in unique_terms:
|
|
unique_terms[name] = term
|
|
|
|
# 转换回列表
|
|
deduplicated_terms = list(unique_terms.values())
|
|
logging.info(f"去重后剩余 {len(deduplicated_terms)} 条专业名词")
|
|
|
|
# 准备数据
|
|
texts, metadatas = self._prepare_terms_for_faiss(deduplicated_terms)
|
|
|
|
# 创建索引
|
|
faiss_index = self._create_index(texts, metadatas)
|
|
|
|
# 保存索引
|
|
self._save_index(faiss_index)
|
|
|
|
logging.info("完成多文件专业名词向量化和索引创建")
|
|
return True
|
|
except Exception as e:
|
|
logging.error(f"多文件向量化处理失败: {e}")
|
|
return False
|
|
|
|
|
|
def _prepare_terms_for_faiss(self, terms: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict]]:
|
|
"""
|
|
将术语准备为FAISS可用的格式 (内部方法)
|
|
|
|
Args:
|
|
terms: 术语列表
|
|
|
|
Returns:
|
|
格式化的术语文本列表和元数据列表
|
|
"""
|
|
texts = []
|
|
metadatas = []
|
|
|
|
for term in terms:
|
|
name = term["name"]
|
|
texts.append(name.strip())
|
|
synonyms = term.get("synonymous", [])
|
|
description = term.get("description", "")
|
|
# 记录元数据
|
|
metadatas.append({
|
|
"name": name,
|
|
"synonyms": synonyms,
|
|
"description": description
|
|
})
|
|
|
|
if len(synonyms) > 0:
|
|
synonyms_str = ', '.join(synonyms)
|
|
texts.append(synonyms_str.strip())
|
|
metadatas.append({
|
|
"name": name,
|
|
"synonyms": synonyms,
|
|
"description": description
|
|
})
|
|
|
|
if len(description) > 0:
|
|
texts.append(description.strip())
|
|
metadatas.append({
|
|
"name": name,
|
|
"synonyms": synonyms,
|
|
"description": description
|
|
})
|
|
|
|
return texts, metadatas
|
|
|
|
def _create_index(self, texts: List[str], metadatas: List[Dict]) -> FAISS:
|
|
"""
|
|
创建FAISS索引 (内部方法)
|
|
|
|
Args:
|
|
texts: 文本列表
|
|
metadatas: 元数据列表
|
|
|
|
Returns:
|
|
FAISS索引
|
|
"""
|
|
logging.info(f"正在创建FAISS索引,共 {len(texts)} 条数据...")
|
|
return FAISS.from_texts(texts=texts, embedding=self.embedding_model, metadatas=metadatas)
|
|
|
|
def _save_index(self, faiss_index: FAISS) -> None:
|
|
"""
|
|
保存FAISS索引到本地 (内部方法)
|
|
|
|
Args:
|
|
faiss_index: 要保存的FAISS索引
|
|
"""
|
|
try:
|
|
# 确保输出目录存在
|
|
os.makedirs(self.output_dir, exist_ok=True)
|
|
|
|
# 如果索引目录已存在,先删除
|
|
if os.path.exists(self.index_path):
|
|
shutil.rmtree(self.index_path)
|
|
|
|
# 保存FAISS索引
|
|
faiss_index.save_local(folder_path=self.index_path)
|
|
logging.info(f"FAISS索引已保存至 {self.index_path}")
|
|
except Exception as e:
|
|
logging.error(f"保存FAISS索引失败: {e}")
|
|
raise e
|
|
|
|
|
|
class ProfessionalNounRetriever:
|
|
"""专业名词检索类"""
|
|
|
|
def __init__(self,
|
|
embedding_model: Optional[Embeddings] = None,
|
|
api_key: str = None,
|
|
index_dir: str = None):
|
|
"""
|
|
初始化检索器并加载索引
|
|
|
|
Args:
|
|
embedding_model: 嵌入模型,如果为None则使用默认模型
|
|
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
|
|
index_dir: 索引目录路径,默认为None使用默认路径
|
|
"""
|
|
# 设置嵌入模型
|
|
if embedding_model:
|
|
self.embedding_model = embedding_model
|
|
else:
|
|
self.embedding_model = get_embedding_model(api_key)
|
|
|
|
# 设置索引路径
|
|
self.index_dir = index_dir
|
|
if self.index_dir is None:
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
self.index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index")
|
|
|
|
# 在构造函数中加载索引
|
|
self.faiss_index = None
|
|
self._load_index()
|
|
|
|
def _load_index(self) -> None:
|
|
"""
|
|
从本地加载FAISS索引 (内部方法)
|
|
"""
|
|
try:
|
|
# 加载FAISS索引,启用不安全反序列化(仅用于可信数据源)
|
|
self.faiss_index = FAISS.load_local(
|
|
folder_path=self.index_dir,
|
|
embeddings=self.embedding_model,
|
|
allow_dangerous_deserialization=True
|
|
)
|
|
logging.info(f"成功从 {self.index_dir} 加载FAISS索引")
|
|
except Exception as e:
|
|
logging.warning(f"加载FAISS索引失败: {e}")
|
|
self.faiss_index = None
|
|
|
|
def query(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
|
|
"""
|
|
查询FAISS索引,获取最相似的专业名词 (唯一对外接口)
|
|
|
|
Args:
|
|
query_text: 查询文本
|
|
top_k: 返回的结果数量,默认为5
|
|
use_intersection: 是否使用三种检索方式的交集,默认为True
|
|
|
|
Returns:
|
|
相似度最高的专业名词列表
|
|
"""
|
|
try:
|
|
# 检查索引是否已加载
|
|
if self.faiss_index is None:
|
|
logging.warning("FAISS索引未加载,无法执行查询")
|
|
return []
|
|
|
|
# 使用三种检索方式并取交集
|
|
retriever1 = self.faiss_index.as_retriever(search_kwargs={"k": top_k})
|
|
retriever2 = self.faiss_index.as_retriever(
|
|
search_type="mmr",
|
|
search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5}
|
|
)
|
|
retriever3 = self.faiss_index.as_retriever(
|
|
search_type="similarity_score_threshold",
|
|
search_kwargs={"score_threshold": 0.5}
|
|
)
|
|
|
|
# 用json.dumps将dict转为字符串,便于取交集
|
|
set1 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
|
|
for i in retriever1.invoke(query_text))
|
|
set2 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
|
|
for i in retriever2.invoke(query_text))
|
|
set3 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False)
|
|
for i in retriever3.invoke(query_text))
|
|
|
|
intersection = set1 | set2 | set3
|
|
|
|
# 如果交集为空,使用第一种检索方式的结果
|
|
if not intersection:
|
|
logging.warning("三种检索方式无交集,使用普通检索结果")
|
|
return [json.loads(item) for item in set1]
|
|
|
|
# 转回dict
|
|
return [json.loads(item) for item in intersection]
|
|
|
|
except Exception as e:
|
|
logging.error(f"查询FAISS索引失败: {e}")
|
|
return [] |