#!/usr/bin/env python # -*- coding: utf-8 -*- """ File: ProfessionalNounVector.py Date: 2025-05-15 Author: oyyz Description: 专业名词向量化和检索的核心逻辑 """ import os import json import shutil from typing import List, Dict, Any, Tuple, Optional from langchain.embeddings.base import Embeddings from langchain_community.vectorstores import FAISS from rag2_0.tool.ModelTool import SiliconFlowEmbeddings import logging def get_embedding_model(api_key: str = None) -> Embeddings: """ 获取嵌入模型 Args: api_key: API密钥,如果为None则从环境变量获取 Returns: 嵌入模型实例 """ if not api_key: api_key = os.getenv("SILICONFLOW_API_KEY", "sk-ftnofbucchwnscojohyxwmfzgaykdxihafnlphohsinftkbr") return SiliconFlowEmbeddings(api_key=api_key) class ProfessionalNounVectorizer: """专业名词向量化和保存类""" def __init__(self, embedding_model: Optional[Embeddings] = None, api_key: str = None, output_dir: str = None): """ 初始化向量化器 Args: embedding_model: 嵌入模型,如果为None则使用默认模型 api_key: SiliconFlow API密钥,仅在embedding_model为None时使用 output_dir: 索引输出目录,默认为None使用默认路径 """ # 设置嵌入模型 if embedding_model: self.embedding_model = embedding_model else: self.embedding_model = get_embedding_model(api_key) # 设置输出目录 self.output_dir = output_dir if self.output_dir is None: current_dir = os.path.dirname(os.path.abspath(__file__)) self.output_dir = os.path.join(current_dir, "..", "..", "data", "nouns") # 设置索引路径 self.index_path = os.path.join(self.output_dir, "professional_nouns_index") def _loadfile(self, file_paths: List[str]) -> List[Dict[str, Any]]: """ 加载多个专业术语JSON文件并合并 Args: file_paths: JSON文件路径列表 Returns: 合并后的术语列表 """ merged_terms = [] try: for file_path in file_paths: if not os.path.exists(file_path): logging.warning(f"文件不存在: {file_path}") continue with open(file_path, "r", encoding="utf-8") as f: terms_data = json.load(f) if isinstance(terms_data, list): merged_terms.extend(terms_data) logging.info(f"从 {file_path} 加载了 {len(terms_data)} 条专业名词") else: logging.warning(f"文件格式错误: {file_path},应为JSON数组") logging.info(f"总共加载了 {len(merged_terms)} 条专业名词") return merged_terms except Exception as e: logging.error(f"加载多个文件失败: {e}") return [] def vectorize_files_and_save(self, file_paths: List[str]) -> bool: """ 处理多个文件:加载多个术语文件、创建索引并保存 Args: file_paths: JSON文件路径列表 Returns: 处理成功返回True,否则返回False """ try: # 加载多个文件的术语 terms = self._loadfile(file_paths) if not terms: logging.warning("未找到术语数据,退出处理") return False # 根据名称去重 unique_terms = {} for term in terms: name = term.get("name", "") if name and name not in unique_terms: unique_terms[name] = term # 转换回列表 deduplicated_terms = list(unique_terms.values()) logging.info(f"去重后剩余 {len(deduplicated_terms)} 条专业名词") # 准备数据 texts, metadatas = self._prepare_terms_for_faiss(deduplicated_terms) # 创建索引 faiss_index = self._create_index(texts, metadatas) # 保存索引 self._save_index(faiss_index) logging.info("完成多文件专业名词向量化和索引创建") return True except Exception as e: logging.error(f"多文件向量化处理失败: {e}") return False def _prepare_terms_for_faiss(self, terms: List[Dict[str, Any]]) -> Tuple[List[str], List[Dict]]: """ 将术语准备为FAISS可用的格式 (内部方法) Args: terms: 术语列表 Returns: 格式化的术语文本列表和元数据列表 """ texts = [] metadatas = [] for term in terms: name = term["name"] texts.append(name.strip()) synonyms = term.get("synonymous", []) description = term.get("description", "") # 记录元数据 metadatas.append({ "name": name, "synonyms": synonyms, "description": description }) if len(synonyms) > 0: synonyms_str = ', '.join(synonyms) texts.append(synonyms_str.strip()) metadatas.append({ "name": name, "synonyms": synonyms, "description": description }) if len(description) > 0: texts.append(description.strip()) metadatas.append({ "name": name, "synonyms": synonyms, "description": description }) return texts, metadatas def _create_index(self, texts: List[str], metadatas: List[Dict]) -> FAISS: """ 创建FAISS索引 (内部方法) Args: texts: 文本列表 metadatas: 元数据列表 Returns: FAISS索引 """ logging.info(f"正在创建FAISS索引,共 {len(texts)} 条数据...") return FAISS.from_texts(texts=texts, embedding=self.embedding_model, metadatas=metadatas) def _save_index(self, faiss_index: FAISS) -> None: """ 保存FAISS索引到本地 (内部方法) Args: faiss_index: 要保存的FAISS索引 """ try: # 确保输出目录存在 os.makedirs(self.output_dir, exist_ok=True) # 如果索引目录已存在,先删除 if os.path.exists(self.index_path): shutil.rmtree(self.index_path) # 保存FAISS索引 faiss_index.save_local(folder_path=self.index_path) logging.info(f"FAISS索引已保存至 {self.index_path}") except Exception as e: logging.error(f"保存FAISS索引失败: {e}") raise e class ProfessionalNounRetriever: """专业名词检索类""" def __init__(self, embedding_model: Optional[Embeddings] = None, api_key: str = None, index_dir: str = None): """ 初始化检索器并加载索引 Args: embedding_model: 嵌入模型,如果为None则使用默认模型 api_key: SiliconFlow API密钥,仅在embedding_model为None时使用 index_dir: 索引目录路径,默认为None使用默认路径 """ # 设置嵌入模型 if embedding_model: self.embedding_model = embedding_model else: self.embedding_model = get_embedding_model(api_key) # 设置索引路径 self.index_dir = index_dir if self.index_dir is None: current_dir = os.path.dirname(os.path.abspath(__file__)) self.index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index") # 在构造函数中加载索引 self.faiss_index = None self._load_index() def _load_index(self) -> None: """ 从本地加载FAISS索引 (内部方法) """ try: # 加载FAISS索引,启用不安全反序列化(仅用于可信数据源) self.faiss_index = FAISS.load_local( folder_path=self.index_dir, embeddings=self.embedding_model, allow_dangerous_deserialization=True ) logging.info(f"成功从 {self.index_dir} 加载FAISS索引") except Exception as e: logging.warning(f"加载FAISS索引失败: {e}") self.faiss_index = None def query(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]: """ 查询FAISS索引,获取最相似的专业名词 (唯一对外接口) Args: query_text: 查询文本 top_k: 返回的结果数量,默认为5 use_intersection: 是否使用三种检索方式的交集,默认为True Returns: 相似度最高的专业名词列表 """ try: # 检查索引是否已加载 if self.faiss_index is None: logging.warning("FAISS索引未加载,无法执行查询") return [] # 使用三种检索方式并取交集 retriever1 = self.faiss_index.as_retriever(search_kwargs={"k": top_k}) retriever2 = self.faiss_index.as_retriever( search_type="mmr", search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5} ) retriever3 = self.faiss_index.as_retriever( search_type="similarity_score_threshold", search_kwargs={"score_threshold": 0.5} ) # 用json.dumps将dict转为字符串,便于取交集 set1 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False) for i in retriever1.invoke(query_text)) set2 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False) for i in retriever2.invoke(query_text)) set3 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False) for i in retriever3.invoke(query_text)) intersection = set1 | set2 | set3 # 如果交集为空,使用第一种检索方式的结果 if not intersection: logging.warning("三种检索方式无交集,使用普通检索结果") return [json.loads(item) for item in set1] # 转回dict return [json.loads(item) for item in intersection] except Exception as e: logging.error(f"查询FAISS索引失败: {e}") return []