新增自定义关键词检索类

This commit is contained in:
wanyaokun
2024-08-22 11:06:22 +08:00
parent f5d6eb6a22
commit 043aea6cca
2 changed files with 17 additions and 6 deletions
+8 -6
View File
@@ -11,6 +11,7 @@ from sqlalchemy import create_engine, Engine
from app.engine.loaders.db import makeDescriptionByEngine from app.engine.loaders.db import makeDescriptionByEngine
from app.engine.tools import ToolFactory from app.engine.tools import ToolFactory
from app.engine.index import get_index from app.engine.index import get_index
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
from app.settings import get_node_postprocessors from app.settings import get_node_postprocessors
from llama_index.core.retrievers import BaseRetriever from llama_index.core.retrievers import BaseRetriever
@@ -29,9 +30,6 @@ class HybridRetriever(BaseRetriever):
filters = None, filters = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
from llama_index.retrievers.bm25 import BM25Retriever
from nltk.corpus import stopwords
super().__init__(**kwargs) super().__init__(**kwargs)
self._vector_index = vector_index self._vector_index = vector_index
self._embed_model = vector_index._embed_model self._embed_model = vector_index._embed_model
@@ -39,9 +37,13 @@ class HybridRetriever(BaseRetriever):
self._vecRetriever = vector_index.as_retriever( self._vecRetriever = vector_index.as_retriever(
similarity_top_k=similarity_top_k,filters = filters similarity_top_k=similarity_top_k,filters = filters
) )
self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k,
nodes=self._vector_index.vector_store.get_nodes(None), STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
language=stopwords.words('chinese')) if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0:
self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
else:
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=similarity_top_k,nodes=self._vector_index.vector_store.get_nodes(None))
bmRetriver.persist(STORAGE_DIR)
self._alpha = alpha self._alpha = alpha
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+9
View File
@@ -8,6 +8,7 @@ import os
from app.engine.loaders import get_documents from app.engine.loaders import get_documents
from app.engine.vectordb import get_vector_store from app.engine.vectordb import get_vector_store
from app.settings import init_settings from app.settings import init_settings
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
from llama_index.core.ingestion import IngestionPipeline from llama_index.core.ingestion import IngestionPipeline
from llama_index.core.node_parser import SentenceSplitter from llama_index.core.node_parser import SentenceSplitter
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
@@ -58,6 +59,13 @@ def persist_storage(docstore, vector_store):
storage_context.persist(STORAGE_DIR) storage_context.persist(STORAGE_DIR)
def persist_BMRetriever(vector_store):
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
top_k = int(os.getenv("TOP_K", "3"))
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=top_k,nodes=vector_store.get_nodes([]))
bmRetriver.persist(STORAGE_DIR)
def generate_datasource(): def generate_datasource():
init_settings() init_settings()
logger.info("Generate index for the provided data") logger.info("Generate index for the provided data")
@@ -75,6 +83,7 @@ def generate_datasource():
# Build the index and persist storage # Build the index and persist storage
persist_storage(docstore, vector_store) persist_storage(docstore, vector_store)
persist_BMRetriever(vector_store)
logger.info("Finished generating the index") logger.info("Finished generating the index")