From 043aea6cca895d88daa1ef4e5033b6eb4164985b Mon Sep 17 00:00:00 2001
From: wanyaokun <12345678>
Date: Thu, 22 Aug 2024 11:06:22 +0800
Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E8=87=AA=E5=AE=9A=E4=B9=89?=
 =?UTF-8?q?=E5=85=B3=E9=94=AE=E8=AF=8D=E6=A3=80=E7=B4=A2=E7=B1=BB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/app/engine/__init__.py | 21 +++++++++++++++------
 backend/app/engine/generate.py | 20 ++++++++++++++++++++
 2 files changed, 35 insertions(+), 6 deletions(-)

diff --git a/backend/app/engine/__init__.py b/backend/app/engine/__init__.py
index fc36c14..12a45fc 100644
--- a/backend/app/engine/__init__.py
+++ b/backend/app/engine/__init__.py
@@ -11,6 +11,7 @@ from sqlalchemy import create_engine, Engine
 from app.engine.loaders.db import makeDescriptionByEngine
 from app.engine.tools import ToolFactory
 from app.engine.index import get_index
+from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
 from app.settings import get_node_postprocessors
 
 from llama_index.core.retrievers import BaseRetriever
@@ -29,9 +30,6 @@ class HybridRetriever(BaseRetriever):
         filters = None,
         **kwargs: Any,
     ) -> None:
-        from llama_index.retrievers.bm25 import BM25Retriever
-        from nltk.corpus import stopwords
-
         super().__init__(**kwargs)
         self._vector_index = vector_index
         self._embed_model = vector_index._embed_model
@@ -39,9 +37,20 @@ class HybridRetriever(BaseRetriever):
         self._vecRetriever = vector_index.as_retriever(
             similarity_top_k=similarity_top_k,filters = filters
         )
-        self._bm25Retriever = BM25Retriever.from_defaults(similarity_top_k=similarity_top_k,
-                                                          nodes=self._vector_index.vector_store.get_nodes(None),
-                                                          language=stopwords.words('chinese'))
+
+        # Reuse a persisted BM25 index when one exists; otherwise build it
+        # from the vector store's nodes and persist it for later start-ups.
+        STORAGE_DIR = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
+        if os.path.isdir(STORAGE_DIR) and os.listdir(STORAGE_DIR):
+            self._bm25Retriever = CHBM25Retriever.from_persist_dir(STORAGE_DIR)
+        else:
+            # Assign to self in this branch as well, so the attribute is set
+            # on a first run with no persisted index.
+            self._bm25Retriever = CHBM25Retriever.from_defaults(
+                similarity_top_k=similarity_top_k,
+                nodes=self._vector_index.vector_store.get_nodes(None),
+            )
+            self._bm25Retriever.persist(STORAGE_DIR)
         self._alpha = alpha
 
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
diff --git a/backend/app/engine/generate.py b/backend/app/engine/generate.py
index 115c175..87ecfa1 100644
--- a/backend/app/engine/generate.py
+++ b/backend/app/engine/generate.py
@@ -8,6 +8,7 @@ import os
 from app.engine.loaders import get_documents
 from app.engine.vectordb import get_vector_store
 from app.settings import init_settings
+from app.engine.retriever.CHBM25Retriever import CHBM25Retriever
 from llama_index.core.ingestion import IngestionPipeline
 from llama_index.core.node_parser import SentenceSplitter
 from llama_index.core.settings import Settings
@@ -58,6 +59,24 @@ def persist_storage(docstore, vector_store):
     storage_context.persist(STORAGE_DIR)
 
 
+def persist_BMRetriever(vector_store):
+    """Build a BM25 retriever over the vector store's nodes and persist it.
+
+    The target directory is read from the ``BM_RETRIEVER_PATH`` environment
+    variable (default ``storage_bm``) and the retriever's top-k from
+    ``TOP_K`` (default ``3``).
+    """
+    storage_dir = os.getenv("BM_RETRIEVER_PATH", "storage_bm")
+    top_k = int(os.getenv("TOP_K", "3"))
+    # None (rather than an empty id list) matches the call made when
+    # HybridRetriever builds the same index.
+    bm_retriever = CHBM25Retriever.from_defaults(
+        similarity_top_k=top_k,
+        nodes=vector_store.get_nodes(None),
+    )
+    bm_retriever.persist(storage_dir)
+
+
 def generate_datasource():
     init_settings()
     logger.info("Generate index for the provided data")
@@ -75,6 +94,7 @@ def generate_datasource():
 
     # Build the index and persist storage
     persist_storage(docstore, vector_store)
+    persist_BMRetriever(vector_store)
 
     logger.info("Finished generating the index")
 