From 8d7190d0b6d0451b964a84ff9eab56d51989da63 Mon Sep 17 00:00:00 2001
From: wanyaokun <12345678>
Date: Thu, 22 Aug 2024 11:07:23 +0800
Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9E=E5=85=B3=E9=94=AE=E5=AD=97?=
 =?UTF-8?q?=E6=A3=80=E7=B4=A2=E7=B1=BB?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../app/engine/retriever/CHBM25Retriever.py   | 193 ++++++++++++++++++
 backend/app/engine/retriever/CHTokener.py     |  77 +++++++
 2 files changed, 270 insertions(+)
 create mode 100644 backend/app/engine/retriever/CHBM25Retriever.py
 create mode 100644 backend/app/engine/retriever/CHTokener.py

diff --git a/backend/app/engine/retriever/CHBM25Retriever.py b/backend/app/engine/retriever/CHBM25Retriever.py
new file mode 100644
index 0000000..fa5d5ec
--- /dev/null
+++ b/backend/app/engine/retriever/CHBM25Retriever.py
@@ -0,0 +1,193 @@
+"""BM25 keyword retriever that tokenizes text with ``chTokenize``."""
+
+import json
+import logging
+import os
+
+from typing import Any, Callable, Dict, List, Optional, cast
+
+from llama_index.core.base.base_retriever import BaseRetriever
+from llama_index.core.callbacks.base import CallbackManager
+from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
+from llama_index.core.indices.vector_store.base import VectorStoreIndex
+from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
+from llama_index.core.storage.docstore.types import BaseDocumentStore
+from llama_index.core.vector_stores.utils import (
+    node_to_metadata_dict,
+    metadata_dict_to_node,
+)
+
+import bm25s
+from app.engine.retriever.CHTokener import chTokenize
+
+# Attribute name on the retriever -> constructor keyword it is persisted under.
+CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
+
+# Name of the JSON file, inside the persist directory, holding those arguments.
+CHDEFAULT_PERSIST_FILENAME = "retriever.json"
+
+
+class CHBM25Retriever(BaseRetriever):
+    """Retrieve nodes by BM25 score over text tokenized with ``chTokenize``.
+
+    The index is either built from ``nodes`` or taken from an existing
+    ``bm25s.BM25`` object. When built from nodes, ``self.corpus`` stores one
+    metadata dict per node so that hits can be converted back into nodes.
+    """
+
+    def __init__(
+        self,
+        nodes: Optional[List[BaseNode]] = None,
+        existing_bm25: Optional[bm25s.BM25] = None,
+        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
+        callback_manager: Optional[CallbackManager] = None,
+        objects: Optional[List[IndexNode]] = None,
+        object_map: Optional[dict] = None,
+        verbose: bool = False,
+    ) -> None:
+        """Build a BM25 index from ``nodes`` or wrap ``existing_bm25``.
+
+        Args:
+            nodes: Nodes to index. Required when ``existing_bm25`` is None.
+            existing_bm25: Pre-built index; when given, ``nodes`` is ignored
+                and the corpus is read from ``existing_bm25.corpus``.
+            similarity_top_k: Maximum number of results returned per query.
+            callback_manager: Forwarded to ``BaseRetriever``.
+            objects: Forwarded to ``BaseRetriever``.
+            object_map: Forwarded to ``BaseRetriever``.
+            verbose: Show progress bars while tokenizing and indexing;
+                also forwarded to ``BaseRetriever``.
+
+        Raises:
+            ValueError: If neither ``nodes`` nor ``existing_bm25`` is given.
+        """
+        self.similarity_top_k = similarity_top_k
+        if existing_bm25 is not None:
+            self.bm25 = existing_bm25
+            self.corpus = existing_bm25.corpus
+        else:
+            if nodes is None:
+                raise ValueError("Please pass nodes or an existing BM25 object.")
+
+            self.corpus = [node_to_metadata_dict(node) for node in nodes]
+
+            corpus_tokens = chTokenize(
+                [node.get_content() for node in nodes],
+                show_progress=verbose,
+            )
+            self.bm25 = bm25s.BM25()
+            self.bm25.index(corpus_tokens, show_progress=verbose)
+        super().__init__(
+            callback_manager=callback_manager,
+            object_map=object_map,
+            objects=objects,
+            verbose=verbose,
+        )
+
+    @classmethod
+    def from_defaults(
+        cls,
+        index: Optional[VectorStoreIndex] = None,
+        nodes: Optional[List[BaseNode]] = None,
+        docstore: Optional[BaseDocumentStore] = None,
+        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
+        verbose: bool = False,
+    ) -> "CHBM25Retriever":
+        """Create a retriever from exactly one of an index, nodes or a docstore.
+
+        Args:
+            index: Index whose ``docstore`` supplies the nodes.
+            nodes: Nodes to index directly.
+            docstore: Document store whose ``docs`` values supply the nodes.
+            similarity_top_k: Maximum number of results returned per query.
+            verbose: Passed through to the constructor.
+
+        Raises:
+            ValueError: Unless exactly one of ``index``, ``nodes`` and
+                ``docstore`` is truthy.
+        """
+        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
+            raise ValueError("Please pass exactly one of index, nodes, or docstore.")
+
+        if index is not None:
+            docstore = index.docstore
+
+        if docstore is not None:
+            nodes = cast(List[BaseNode], list(docstore.docs.values()))
+
+        # Explicit check instead of ``assert`` so it survives ``python -O``.
+        if nodes is None:
+            raise ValueError("Please pass exactly one of index, nodes, or docstore.")
+
+        return cls(
+            nodes=nodes,
+            similarity_top_k=similarity_top_k,
+            verbose=verbose,
+        )
+
+    def get_persist_args(self) -> Dict[str, Any]:
+        """Return the constructor arguments to save, keyed by keyword name."""
+        return {
+            CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
+            for key in CHDEFAULT_PERSIST_ARGS
+            if hasattr(self, key)
+        }
+
+    def persist(self, path: str, **kwargs: Any) -> None:
+        """Save the BM25 index, the corpus and the constructor arguments.
+
+        Args:
+            path: Target directory.
+            **kwargs: Extra keyword arguments for ``bm25s.BM25.save``.
+        """
+        self.bm25.save(path, corpus=self.corpus, **kwargs)
+        args_file = os.path.join(path, CHDEFAULT_PERSIST_FILENAME)
+        with open(args_file, "w", encoding="utf-8") as f:
+            json.dump(self.get_persist_args(), f, indent=2)
+
+    @classmethod
+    def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
+        """Load a retriever previously written by ``persist``.
+
+        Args:
+            path: Directory that was passed to ``persist``.
+            **kwargs: Extra keyword arguments for ``bm25s.BM25.load``.
+        """
+        bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
+        args_file = os.path.join(path, CHDEFAULT_PERSIST_FILENAME)
+        with open(args_file, encoding="utf-8") as f:
+            retriever_data = json.load(f)
+        return cls(existing_bm25=bm25, **retriever_data)
+
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        """Return the best-scoring nodes for the query string.
+
+        Args:
+            query_bundle: Query whose ``query_str`` is tokenized and scored.
+
+        Returns:
+            One ``NodeWithScore`` per hit, in the order bm25s returns them.
+        """
+        tokenized_query = chTokenize(
+            query_bundle.query_str, show_progress=self._verbose
+        )
+
+        # Never ask for more hits than there are documents: bm25s raises when
+        # ``k`` exceeds the number of indexed documents.
+        top_k = self.similarity_top_k
+        if self.corpus is not None and len(self.corpus) > 0:
+            top_k = min(top_k, len(self.corpus))
+
+        indexes, scores = self.bm25.retrieve(
+            tokenized_query, k=top_k, show_progress=self._verbose
+        )
+
+        nodes: List[NodeWithScore] = []
+        # Results are batched per query; only one query was sent, so take [0].
+        for idx, score in zip(indexes[0], scores[0]):
+            # idx can be an int or a dict of the node
+            node_dict = idx if isinstance(idx, dict) else self.corpus[int(idx)]
+            node = metadata_dict_to_node(node_dict)
+            nodes.append(NodeWithScore(node=node, score=float(score)))
+
+        return nodes
\ No newline at end of file
diff --git a/backend/app/engine/retriever/CHTokener.py b/backend/app/engine/retriever/CHTokener.py
new file mode 100644
index 0000000..9c5a071
--- /dev/null
+++ b/backend/app/engine/retriever/CHTokener.py
@@ -0,0 +1,77 @@
+"""Chinese tokenization helpers producing bm25s ``Tokenized`` output."""
+
+from functools import lru_cache
+from typing import Any, Dict, FrozenSet, List, Union, Callable, NamedTuple
+
+# Star import kept as in the original module; ``Tokenized`` comes from here.
+from bm25s.tokenization import *
+
+try:
+    from tqdm.auto import tqdm
+except ImportError:
+
+    def tqdm(iterable, *args, **kwargs):
+        """Fallback used when tqdm is not installed: iterate without a bar."""
+        return iterable
+
+
+@lru_cache(maxsize=1)
+def _chinese_stopwords() -> FrozenSet[str]:
+    """Return NLTK's Chinese stop words as a frozenset, loaded only once."""
+    from nltk.corpus import stopwords
+
+    return frozenset(stopwords.words("chinese"))
+
+
+def chinese_tokenizer(text: str) -> List[str]:
+    """Split ``text`` with jieba and drop Chinese stop words.
+
+    Args:
+        text: Text to segment.
+
+    Returns:
+        The jieba tokens of ``text``, in order, without stop words.
+    """
+    import jieba
+
+    # The stop-word set is fetched once per call (and cached across calls)
+    # rather than re-read from NLTK for every single token.
+    stop_words = _chinese_stopwords()
+    return [token for token in jieba.lcut(text) if token not in stop_words]
+
+
+def chTokenize(
+    texts: Union[str, List[str]],
+    show_progress: bool = True,
+    leave: bool = False,
+) -> Tokenized:
+    """Tokenize one text or a list of texts into integer token ids.
+
+    Args:
+        texts: A single string (treated as a one-item list) or a list of
+            strings.
+        show_progress: Show a tqdm progress bar while tokenizing.
+        leave: Forwarded to tqdm's ``leave`` argument.
+
+    Returns:
+        ``Tokenized(ids, vocab)`` where ``ids`` holds one list of token ids
+        per text and ``vocab`` maps each token to its id, numbered in order
+        of first appearance.
+    """
+    if isinstance(texts, str):
+        texts = [texts]
+
+    corpus_ids = []
+    token_to_index = {}
+
+    for text in tqdm(
+        texts, desc="Split strings", leave=leave, disable=not show_progress
+    ):
+        # setdefault assigns the next free id the first time a token is seen.
+        doc_ids = [
+            token_to_index.setdefault(token, len(token_to_index))
+            for token in chinese_tokenizer(text)
+        ]
+        corpus_ids.append(doc_ids)
+
+    return Tokenized(ids=corpus_ids, vocab=token_to_index)