dev #2
@@ -11,6 +11,7 @@ from sqlalchemy import create_engine, Engine
|
|||||||
from app.engine.loaders.db import makeDescriptionByEngine
|
from app.engine.loaders.db import makeDescriptionByEngine
|
||||||
from app.engine.tools import ToolFactory
|
from app.engine.tools import ToolFactory
|
||||||
from app.engine.index import get_index
|
from app.engine.index import get_index
|
||||||
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||||
from app.settings import get_node_postprocessors
|
from app.settings import get_node_postprocessors
|
||||||
|
|
||||||
from llama_index.core.retrievers import BaseRetriever
|
from llama_index.core.retrievers import BaseRetriever
|
||||||
@@ -29,9 +30,6 @@ class HybridRetriever(BaseRetriever):
|
|||||||
filters = None,
|
filters = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> None:
|
) -> None:
|
||||||
from llama_index.retrievers.bm25 import BM25Retriever
|
|
||||||
from nltk.corpus import stopwords
|
|
||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self._vector_index = vector_index
|
self._vector_index = vector_index
|
||||||
self._embed_model = vector_index._embed_model
|
self._embed_model = vector_index._embed_model
|
||||||
@@ -39,9 +37,13 @@ class HybridRetriever(BaseRetriever):
|
|||||||
self._vecRetriever = vector_index.as_retriever(
|
self._vecRetriever = vector_index.as_retriever(
|
||||||
similarity_top_k=similarity_top_k,filters = filters
|
similarity_top_k=similarity_top_k,filters = filters
|
||||||
)
|
)
|
||||||
self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k,
|
|
||||||
nodes=self._vector_index.vector_store.get_nodes(None),
|
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||||
language=stopwords.words('chinese'))
|
if os.path.exists(STORAGE_DIR) and len(os.listdir(STORAGE_DIR)) > 0:
|
||||||
|
self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
|
||||||
|
else:
|
||||||
|
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=similarity_top_k,nodes=self._vector_index.vector_store.get_nodes(None))
|
||||||
|
bmRetriver.persist(STORAGE_DIR)
|
||||||
self._alpha = alpha
|
self._alpha = alpha
|
||||||
|
|
||||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import os
|
|||||||
from app.engine.loaders import get_documents
|
from app.engine.loaders import get_documents
|
||||||
from app.engine.vectordb import get_vector_store
|
from app.engine.vectordb import get_vector_store
|
||||||
from app.settings import init_settings
|
from app.settings import init_settings
|
||||||
|
from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
|
||||||
from llama_index.core.ingestion import IngestionPipeline
|
from llama_index.core.ingestion import IngestionPipeline
|
||||||
from llama_index.core.node_parser import SentenceSplitter
|
from llama_index.core.node_parser import SentenceSplitter
|
||||||
from llama_index.core.settings import Settings
|
from llama_index.core.settings import Settings
|
||||||
@@ -58,6 +59,13 @@ def persist_storage(docstore, vector_store):
|
|||||||
storage_context.persist(STORAGE_DIR)
|
storage_context.persist(STORAGE_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def persist_BMRetriever(vector_store):
|
||||||
|
STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
|
||||||
|
top_k = int(os.getenv("TOP_K", "3"))
|
||||||
|
bmRetriver = CHBM25Retriever.from_defaults(similarity_top_k=top_k,nodes=vector_store.get_nodes([]))
|
||||||
|
bmRetriver.persist(STORAGE_DIR)
|
||||||
|
|
||||||
|
|
||||||
def generate_datasource():
|
def generate_datasource():
|
||||||
init_settings()
|
init_settings()
|
||||||
logger.info("Generate index for the provided data")
|
logger.info("Generate index for the provided data")
|
||||||
@@ -75,6 +83,7 @@ def generate_datasource():
|
|||||||
|
|
||||||
# Build the index and persist storage
|
# Build the index and persist storage
|
||||||
persist_storage(docstore, vector_store)
|
persist_storage(docstore, vector_store)
|
||||||
|
persist_BMRetriever(vector_store)
|
||||||
|
|
||||||
logger.info("Finished generating the index")
|
logger.info("Finished generating the index")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user