新增异步意图识别器和相关功能,优化意图识别和槽位填充逻辑,支持异步处理和多线程检索,改进API调用的错误处理和日志记录,增强文档检索和关键词提取功能。

This commit is contained in:
2025-07-03 15:40:36 +08:00
parent 68e3677c34
commit c52627abeb
6 changed files with 1146 additions and 40 deletions
@@ -10,11 +10,13 @@ Description: 专业名词向量化和检索的核心逻辑
import os
import json
import shutil
import asyncio
from typing import List, Dict, Any, Tuple, Optional
from langchain.embeddings.base import Embeddings
from langchain_community.vectorstores import FAISS
from rag2_0.tool.ModelTool import SiliconFlowEmbeddings
import logging
import httpx
def get_embedding_model(api_key: str = None) -> Embeddings:
"""
@@ -350,4 +352,148 @@ class ProfessionalNounRetriever:
except Exception as e:
logging.error(f"查询FAISS索引失败: {e}", exc_info=True)
return []
return []
class AsyncProfessionalNounRetriever:
"""异步专业名词检索类"""
def __init__(self,
embedding_model: Optional[Embeddings] = None,
api_key: str = None,
index_dir: str = None):
"""
初始化异步检索器
Args:
embedding_model: 嵌入模型,如果为None则使用默认模型
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
index_dir: 索引目录路径,默认为None使用默认路径
"""
# 设置嵌入模型
if embedding_model:
self.embedding_model = embedding_model
else:
self.embedding_model = get_embedding_model(api_key)
# 设置索引路径
self.index_dir = index_dir
if self.index_dir is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
self.index_dir = os.path.join(current_dir, "..", "..", "data", "nouns", "professional_nouns_index")
# 初始化索引为None,不在构造函数中加载
self.faiss_index = None
@classmethod
async def create(cls,
embedding_model: Optional[Embeddings] = None,
api_key: str = None,
index_dir: str = None):
"""
异步工厂方法:创建并初始化异步检索器实例
Args:
embedding_model: 嵌入模型,如果为None则使用默认模型
api_key: SiliconFlow API密钥,仅在embedding_model为None时使用
index_dir: 索引目录路径,默认为None使用默认路径
Returns:
初始化完成的AsyncProfessionalNounRetriever实例
"""
instance = cls(embedding_model, api_key, index_dir)
await instance._load_index_async()
return instance
async def _load_index_async(self) -> None:
"""
异步从本地加载FAISS索引 (内部方法)
"""
try:
# 由于FAISS加载可能是CPU密集型操作,使用线程池执行器来避免阻塞事件循环
self.faiss_index = await asyncio.to_thread(
FAISS.load_local,
folder_path=self.index_dir,
embeddings=self.embedding_model,
allow_dangerous_deserialization=True
)
logging.info(f"异步成功从 {self.index_dir} 加载FAISS索引")
except Exception as e:
logging.warning(f"异步加载FAISS索引失败: {e}")
self.faiss_index = None
async def _invoke_retriever_async(self, retriever, query_text: str):
"""
异步调用检索器 (内部方法)
Args:
retriever: 检索器实例
query_text: 查询文本
Returns:
检索结果
"""
# 由于LangChain的retriever.invoke可能不是异步的,使用线程池执行器
return await asyncio.to_thread(retriever.invoke, query_text)
async def query_async(self, query_text: str, top_k: int = 5, use_intersection: bool = True) -> List[Dict]:
"""
异步查询FAISS索引,获取最相似的专业名词
Args:
query_text: 查询文本
top_k: 返回的结果数量,默认为5
use_intersection: 是否使用三种检索方式的交集,默认为True
Returns:
相似度最高的专业名词列表
"""
try:
# 检查索引是否已加载
if self.faiss_index is None:
logging.warning("FAISS索引未加载,尝试加载...")
await self._load_index_async()
if self.faiss_index is None:
logging.warning("异步加载FAISS索引失败,无法执行查询")
return []
# 使用三种检索方式并取交集
retriever1 = self.faiss_index.as_retriever(search_kwargs={"k": top_k})
retriever2 = self.faiss_index.as_retriever(
search_type="mmr",
search_kwargs={"k": top_k, "fetch_k": 3, "lambda_mult": 0.5}
)
retriever3 = self.faiss_index.as_retriever(
search_type="similarity_score_threshold",
search_kwargs={"score_threshold": 0.5}
)
# 并行执行三个检索任务
results = await asyncio.gather(
self._invoke_retriever_async(retriever1, query_text),
self._invoke_retriever_async(retriever2, query_text),
self._invoke_retriever_async(retriever3, query_text)
)
# 用json.dumps将dict转为字符串,便于取交集
set1 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False) for i in results[0])
set2 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False) for i in results[1])
set3 = set(json.dumps(i.metadata, sort_keys=True, ensure_ascii=False) for i in results[2])
# 如果use_intersection为True,取交集;否则取并集
if use_intersection:
intersection = set1 & set2 & set3
else:
intersection = set1 | set2 | set3
# 如果交集为空,使用第一种检索方式的结果
if not intersection:
logging.warning("三种检索方式无交集,使用普通检索结果")
return [json.loads(item) for item in set1]
# 转回dict
return [json.loads(item) for item in intersection]
except Exception as e:
logging.error(f"异步查询FAISS索引失败: {e}", exc_info=True)
return []