新增异步意图识别器和相关功能,优化意图识别和槽位填充逻辑,支持异步处理和多线程检索,改进API调用的错误处理和日志记录,增强文档检索和关键词提取功能。
This commit is contained in:
@@ -10,11 +10,13 @@ Description: 专业名词向量化和检索的核心逻辑
|
||||
import os
|
||||
import json
|
||||
import shutil
|
||||
import asyncio
|
||||
from typing import List, Dict, Any, Tuple, Optional
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain_community.vectorstores import FAISS
|
||||
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
|
||||
import logging
|
||||
import httpx
|
||||
|
||||
def get_embedding_model(api_key: str = None) -> Embeddings:
|
||||
"""
|
||||
@@ -350,4 +352,148 @@ class ProfessionalNounRetriever:
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"查询FAISS索引失败: {e}", exc_info=True)
|
||||
return []
|
||||
return []
|
||||
|
||||
|
||||
class AsyncProfessionalNounRetriever:
    """Asynchronous retriever for professional nouns stored in a local FAISS index.

    Blocking work (index loading and retriever invocation) is pushed onto
    worker threads with ``asyncio.to_thread`` so the event loop is not blocked.
    Prefer the :meth:`create` factory, which also loads the index; the plain
    constructor only records configuration.
    """

    # Minimum candidate pool handed to the MMR re-ranker.
    _MMR_MIN_FETCH_K = 20

    def __init__(self,
                 embedding_model: Optional["Embeddings"] = None,
                 api_key: Optional[str] = None,
                 index_dir: Optional[str] = None):
        """
        Store the configuration; the index itself is loaded lazily.

        Args:
            embedding_model: Embedding model; when None the default model
                returned by ``get_embedding_model`` is used.
            api_key: SiliconFlow API key, only used when ``embedding_model``
                is None.
            index_dir: Directory holding the FAISS index; when None a default
                path relative to this file is used.
        """
        # Compare against None explicitly so that a model object which happens
        # to be falsy is not silently replaced by the default model.
        if embedding_model is not None:
            self.embedding_model = embedding_model
        else:
            self.embedding_model = get_embedding_model(api_key)

        self.index_dir = index_dir
        if self.index_dir is None:
            current_dir = os.path.dirname(os.path.abspath(__file__))
            self.index_dir = os.path.join(
                current_dir, "..", "..", "data", "nouns", "professional_nouns_index"
            )

        # The index is not loaded in the constructor; see _load_index_async.
        self.faiss_index = None

    @classmethod
    async def create(cls,
                     embedding_model: Optional["Embeddings"] = None,
                     api_key: Optional[str] = None,
                     index_dir: Optional[str] = None) -> "AsyncProfessionalNounRetriever":
        """
        Asynchronous factory: build an instance and load its index.

        Args:
            embedding_model: Embedding model; when None the default is used.
            api_key: SiliconFlow API key, only used when ``embedding_model``
                is None.
            index_dir: Directory holding the FAISS index; None selects the
                default path.

        Returns:
            The new instance. If loading failed, ``faiss_index`` is None and
            :meth:`query_async` will retry the load on its first call.
        """
        instance = cls(embedding_model, api_key, index_dir)
        await instance._load_index_async()
        return instance

    async def _load_index_async(self) -> None:
        """
        Load the FAISS index from ``self.index_dir`` (internal).

        Any failure is logged and leaves ``self.faiss_index`` set to None
        instead of raising.
        """
        try:
            # Loading is blocking, so run it in a worker thread.
            # NOTE(security): allow_dangerous_deserialization=True enables
            # pickle loading; index_dir must only ever point at trusted files.
            self.faiss_index = await asyncio.to_thread(
                FAISS.load_local,
                folder_path=self.index_dir,
                embeddings=self.embedding_model,
                allow_dangerous_deserialization=True
            )
            logging.info(f"异步成功从 {self.index_dir} 加载FAISS索引")
        except Exception as e:
            logging.warning(f"异步加载FAISS索引失败: {e}")
            self.faiss_index = None

    async def _invoke_retriever_async(self, retriever, query_text: str):
        """
        Run ``retriever.invoke(query_text)`` in a worker thread (internal).

        Args:
            retriever: Object exposing a synchronous ``invoke`` method.
            query_text: Query text.

        Returns:
            Whatever ``retriever.invoke`` returns.
        """
        return await asyncio.to_thread(retriever.invoke, query_text)

    @staticmethod
    def _metadata_key(document) -> str:
        """Serialize a document's metadata to a canonical, hashable JSON string."""
        return json.dumps(document.metadata, sort_keys=True, ensure_ascii=False)

    async def query_async(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
        """
        Query the FAISS index for the professional nouns most similar to a text.

        Three retrievers (plain similarity, MMR, score threshold) run
        concurrently and their results are combined by metadata.

        Args:
            query_text: Query text.
            top_k: Number of results requested from the similarity and MMR
                retrievers, default 5. Non-positive values yield ``[]``.
            use_intersection: When True keep only entries returned by all
                three retrievers; when False return their union.

        Returns:
            Metadata dicts without duplicates, ordered by the plain similarity
            ranking (union mode appends entries found only by the other
            retrievers afterwards). If the combination is empty, the plain
            similarity results are returned. Any error yields ``[]``.
        """
        try:
            if top_k <= 0:
                return []

            # Lazily load the index if it is not available yet.
            if self.faiss_index is None:
                logging.warning("FAISS索引未加载,尝试加载...")
                await self._load_index_async()
                if self.faiss_index is None:
                    logging.warning("异步加载FAISS索引失败,无法执行查询")
                    return []

            similarity_retriever = self.faiss_index.as_retriever(search_kwargs={"k": top_k})
            # fetch_k is the candidate pool MMR selects from; it must not be
            # smaller than k, otherwise MMR can never return k documents.
            mmr_retriever = self.faiss_index.as_retriever(
                search_type="mmr",
                search_kwargs={
                    "k": top_k,
                    "fetch_k": max(self._MMR_MIN_FETCH_K, 2 * top_k),
                    "lambda_mult": 0.5,
                }
            )
            threshold_retriever = self.faiss_index.as_retriever(
                search_type="similarity_score_threshold",
                search_kwargs={"score_threshold": 0.5}
            )

            # Run the three retrievals concurrently.
            results = await asyncio.gather(
                self._invoke_retriever_async(similarity_retriever, query_text),
                self._invoke_retriever_async(mmr_retriever, query_text),
                self._invoke_retriever_async(threshold_retriever, query_text)
            )

            # Metadata dicts are not hashable, so compare canonical JSON strings.
            # Lists (not sets) are kept so the retriever ranking survives.
            ranked_keys = [[self._metadata_key(doc) for doc in docs] for docs in results]

            if use_intersection:
                shared = set(ranked_keys[1]) & set(ranked_keys[2])
                candidates = [key for key in ranked_keys[0] if key in shared]
            else:
                candidates = [key for keys in ranked_keys for key in keys]

            # dict.fromkeys removes duplicates while preserving first-seen order.
            selected = list(dict.fromkeys(candidates))

            if not selected:
                # Nothing in common: fall back to the plain similarity results.
                logging.warning("三种检索方式无交集,使用普通检索结果")
                selected = list(dict.fromkeys(ranked_keys[0]))

            return [json.loads(key) for key in selected]

        except Exception as e:
            logging.error(f"异步查询FAISS索引失败: {e}", exc_info=True)
            return []
|
||||
|
||||
Reference in New Issue
Block a user