133 lines
4.7 KiB
Python
133 lines
4.7 KiB
Python
import json
|
|
import logging
|
|
import os
|
|
|
|
from typing import Any, Callable, Dict, List, Optional, cast
|
|
|
|
from llama_index.core.base.base_retriever import BaseRetriever
|
|
from llama_index.core.callbacks.base import CallbackManager
|
|
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
|
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
|
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
|
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
|
from llama_index.core.vector_stores.utils import (
|
|
node_to_metadata_dict,
|
|
metadata_dict_to_node,
|
|
)
|
|
|
|
import bm25s
|
|
from app.engine.retriever.CHTokener import chTokenize
|
|
|
|
# Maps an instance attribute name (key) to the name under which its value is
# written by ``get_persist_args`` (value); the written names are later passed
# back to the constructor as keyword arguments by ``from_persist_dir``.
CHDEFAULT_PERSIST_ARGS = {
    "similarity_top_k": "similarity_top_k",
    "_verbose": "verbose",
}
|
|
|
|
# Name of the JSON file, inside the persist directory, that stores the
# retriever's persisted constructor arguments (see ``persist`` and
# ``from_persist_dir``).
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
|
|
|
|
class CHBM25Retriever(BaseRetriever):
    """BM25 retriever backed by ``bm25s`` that tokenizes text with ``chTokenize``.

    The retriever either builds a fresh ``bm25s.BM25`` index from a list of
    nodes, or wraps an index that already exists (for example one restored by
    ``from_persist_dir``).

    Args:
        nodes: Nodes to index. Required when ``existing_bm25`` is not given.
        existing_bm25: An already-built BM25 index to reuse. When given,
            ``nodes`` is ignored and the corpus is taken from this object.
        similarity_top_k: Maximum number of results to return per query.
        callback_manager: Forwarded to ``BaseRetriever``.
        objects: Forwarded to ``BaseRetriever``.
        object_map: Forwarded to ``BaseRetriever``.
        verbose: Forwarded to ``BaseRetriever`` and used as the
            ``show_progress`` flag while tokenizing and indexing.

    Raises:
        ValueError: If neither ``nodes`` nor ``existing_bm25`` is provided.
    """

    def __init__(
        self,
        nodes: Optional[List[BaseNode]] = None,
        existing_bm25: Optional[bm25s.BM25] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        callback_manager: Optional[CallbackManager] = None,
        objects: Optional[List[IndexNode]] = None,
        object_map: Optional[dict] = None,
        verbose: bool = False,
    ) -> None:
        self.similarity_top_k = similarity_top_k

        if existing_bm25 is not None:
            # Reuse the supplied index together with whatever corpus it carries.
            self.bm25 = existing_bm25
            self.corpus = existing_bm25.corpus
        else:
            if nodes is None:
                raise ValueError("Please pass nodes or an existing BM25 object.")

            # Keep a serializable copy of every node so that search hits can
            # be converted back into nodes at query time.
            self.corpus = [node_to_metadata_dict(node) for node in nodes]

            corpus_tokens = chTokenize(
                [node.get_content() for node in nodes],
                show_progress=verbose,
            )
            self.bm25 = bm25s.BM25()
            self.bm25.index(corpus_tokens, show_progress=verbose)

        super().__init__(
            callback_manager=callback_manager,
            object_map=object_map,
            objects=objects,
            verbose=verbose,
        )

    @classmethod
    def from_defaults(
        cls,
        index: Optional[VectorStoreIndex] = None,
        nodes: Optional[List[BaseNode]] = None,
        docstore: Optional[BaseDocumentStore] = None,
        similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
        verbose: bool = False,
    ) -> "CHBM25Retriever":
        """Build a retriever from exactly one of an index, nodes, or a docstore.

        Args:
            index: Vector store index whose docstore supplies the nodes.
            nodes: Nodes to index directly.
            docstore: Document store whose documents supply the nodes.
            similarity_top_k: Maximum number of results to return per query.
            verbose: Passed through to the constructor.

        Returns:
            A new retriever indexing the resolved nodes.

        Raises:
            ValueError: If the number of truthy sources is not exactly one,
                or if no nodes could be resolved from the given source.
        """
        # Truthiness is used on purpose: an empty ``nodes`` list counts as
        # "not provided".
        if sum(bool(val) for val in [index, nodes, docstore]) != 1:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        if index is not None:
            docstore = index.docstore

        if docstore is not None:
            nodes = cast(List[BaseNode], list(docstore.docs.values()))

        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if nodes is None:
            raise ValueError("Please pass exactly one of index, nodes, or docstore.")

        return cls(
            nodes=nodes,
            similarity_top_k=similarity_top_k,
            verbose=verbose,
        )

    def get_persist_args(self) -> Dict[str, Any]:
        """Get Persist Args Dict to Save.

        Returns:
            A mapping from persisted argument name to the current value of the
            corresponding instance attribute, for every attribute listed in
            ``CHDEFAULT_PERSIST_ARGS`` that exists on this instance.
        """
        return {
            CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
            for key in CHDEFAULT_PERSIST_ARGS
            if hasattr(self, key)
        }

    def persist(self, path: str, **kwargs: Any) -> None:
        """Persist the retriever to a directory.

        Args:
            path: Directory that receives the BM25 index, the corpus, and the
                JSON file with the retriever arguments.
            **kwargs: Extra keyword arguments forwarded to ``bm25s.BM25.save``.
        """
        self.bm25.save(path, corpus=self.corpus, **kwargs)
        args_path = os.path.join(path, CHDEFAULT_PERSIST_FILENAME)
        # Explicit encoding keeps the file independent of the platform locale.
        with open(args_path, "w", encoding="utf-8") as f:
            json.dump(self.get_persist_args(), f, indent=2)

    @classmethod
    def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
        """Load the retriever from a directory.

        Args:
            path: Directory previously written by ``persist``.
            **kwargs: Extra keyword arguments forwarded to ``bm25s.BM25.load``.

        Returns:
            A retriever wrapping the loaded index, constructed with the
            persisted arguments.
        """
        bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
        args_path = os.path.join(path, CHDEFAULT_PERSIST_FILENAME)
        with open(args_path, encoding="utf-8") as f:
            retriever_data = json.load(f)
        return cls(existing_bm25=bm25, **retriever_data)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        """Return the BM25 hits for the query as scored nodes.

        Args:
            query_bundle: Query whose ``query_str`` is tokenized and searched.

        Returns:
            At most ``similarity_top_k`` nodes, each paired with its BM25
            score as a ``float``.
        """
        query = query_bundle.query_str
        tokenized_query = chTokenize(query, show_progress=self._verbose)

        # bm25s raises when k exceeds the number of indexed documents, so cap
        # k at the corpus size whenever the corpus is available.
        top_k = self.similarity_top_k
        if self.corpus is not None:
            top_k = min(top_k, len(self.corpus))

        indexes, scores = self.bm25.retrieve(
            tokenized_query, k=top_k, show_progress=self._verbose
        )

        # batched, but only one query
        indexes = indexes[0]
        scores = scores[0]

        nodes: List[NodeWithScore] = []
        for idx, score in zip(indexes, scores):
            # idx can be an int or a dict of the node
            if isinstance(idx, dict):
                node = metadata_dict_to_node(idx)
            else:
                node_dict = self.corpus[int(idx)]
                node = metadata_dict_to_node(node_dict)
            nodes.append(NodeWithScore(node=node, score=float(score)))

        return nodes