dev #5
@@ -0,0 +1,133 @@
|
|||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
|
||||||
|
from typing import Any, Callable, Dict, List, Optional, cast
|
||||||
|
|
||||||
|
from llama_index.core.base.base_retriever import BaseRetriever
|
||||||
|
from llama_index.core.callbacks.base import CallbackManager
|
||||||
|
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||||
|
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||||
|
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
||||||
|
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
||||||
|
from llama_index.core.vector_stores.utils import (
|
||||||
|
node_to_metadata_dict,
|
||||||
|
metadata_dict_to_node,
|
||||||
|
)
|
||||||
|
|
||||||
|
import bm25s
|
||||||
|
from app.engine.retriever.CHTokener import chTokenize
|
||||||
|
|
||||||
|
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
|
||||||
|
|
||||||
|
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
|
||||||
|
|
||||||
|
class CHBM25Retriever(BaseRetriever):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
nodes: Optional[List[BaseNode]] = None,
|
||||||
|
existing_bm25: Optional[bm25s.BM25] = None,
|
||||||
|
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||||
|
callback_manager: Optional[CallbackManager] = None,
|
||||||
|
objects: Optional[List[IndexNode]] = None,
|
||||||
|
object_map: Optional[dict] = None,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.similarity_top_k = similarity_top_k
|
||||||
|
if existing_bm25 is not None:
|
||||||
|
self.bm25 = existing_bm25
|
||||||
|
self.corpus = existing_bm25.corpus
|
||||||
|
else:
|
||||||
|
from nltk.corpus import stopwords
|
||||||
|
if nodes is None:
|
||||||
|
raise ValueError("Please pass nodes or an existing BM25 object.")
|
||||||
|
|
||||||
|
self.corpus = [node_to_metadata_dict(node) for node in nodes]
|
||||||
|
|
||||||
|
corpus_tokens = chTokenize(
|
||||||
|
[node.get_content() for node in nodes],
|
||||||
|
show_progress=verbose,
|
||||||
|
)
|
||||||
|
self.bm25 = bm25s.BM25()
|
||||||
|
self.bm25.index(corpus_tokens, show_progress=verbose)
|
||||||
|
super().__init__(
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
object_map=object_map,
|
||||||
|
objects=objects,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_defaults(
|
||||||
|
cls,
|
||||||
|
index: Optional[VectorStoreIndex] = None,
|
||||||
|
nodes: Optional[List[BaseNode]] = None,
|
||||||
|
docstore: Optional[BaseDocumentStore] = None,
|
||||||
|
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||||
|
verbose: bool = False,
|
||||||
|
) -> "CHBM25Retriever":
|
||||||
|
if sum(bool(val) for val in [index, nodes, docstore]) != 1:
|
||||||
|
raise ValueError("Please pass exactly one of index, nodes, or docstore.")
|
||||||
|
|
||||||
|
if index is not None:
|
||||||
|
docstore = index.docstore
|
||||||
|
|
||||||
|
if docstore is not None:
|
||||||
|
nodes = cast(List[BaseNode], list(docstore.docs.values()))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
nodes is not None
|
||||||
|
), "Please pass exactly one of index, nodes, or docstore."
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
nodes=nodes,
|
||||||
|
similarity_top_k=similarity_top_k,
|
||||||
|
verbose=verbose,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_persist_args(self) -> Dict[str, Any]:
|
||||||
|
"""Get Persist Args Dict to Save."""
|
||||||
|
return {
|
||||||
|
CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
|
||||||
|
for key in CHDEFAULT_PERSIST_ARGS
|
||||||
|
if hasattr(self, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
def persist(self, path: str, **kwargs: Any) -> None:
|
||||||
|
"""Persist the retriever to a directory."""
|
||||||
|
self.bm25.save(path, corpus=self.corpus, **kwargs)
|
||||||
|
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w") as f:
|
||||||
|
json.dump(self.get_persist_args(), f, indent=2)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
|
||||||
|
"""Load the retriever from a directory."""
|
||||||
|
bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
|
||||||
|
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME)) as f:
|
||||||
|
retriever_data = json.load(f)
|
||||||
|
return cls(existing_bm25=bm25, **retriever_data)
|
||||||
|
|
||||||
|
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||||
|
query = query_bundle.query_str
|
||||||
|
tokenized_query = chTokenize(
|
||||||
|
query,show_progress=self._verbose
|
||||||
|
)
|
||||||
|
indexes, scores = self.bm25.retrieve(
|
||||||
|
tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
|
||||||
|
)
|
||||||
|
|
||||||
|
# batched, but only one query
|
||||||
|
indexes = indexes[0]
|
||||||
|
scores = scores[0]
|
||||||
|
|
||||||
|
nodes: List[NodeWithScore] = []
|
||||||
|
for idx, score in zip(indexes, scores):
|
||||||
|
# idx can be an int or a dict of the node
|
||||||
|
if isinstance(idx, dict):
|
||||||
|
node = metadata_dict_to_node(idx)
|
||||||
|
else:
|
||||||
|
node_dict = self.corpus[int(idx)]
|
||||||
|
node = metadata_dict_to_node(node_dict)
|
||||||
|
nodes.append(NodeWithScore(node=node, score=float(score)))
|
||||||
|
|
||||||
|
return nodes
|
||||||
@@ -0,0 +1,46 @@
|
|||||||
|
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||||
|
from bm25s.tokenization import *
|
||||||
|
|
||||||
|
try:
|
||||||
|
from tqdm.auto import tqdm
|
||||||
|
except ImportError:
|
||||||
|
|
||||||
|
def tqdm(iterable, *args, **kwargs):
|
||||||
|
return iterable
|
||||||
|
|
||||||
|
|
||||||
|
def chinese_tokenizer(text: str) -> List[str]:
|
||||||
|
import jieba
|
||||||
|
from nltk.corpus import stopwords
|
||||||
|
tokens = jieba.lcut(text)
|
||||||
|
return [token for token in tokens if token not in stopwords.words('chinese')]
|
||||||
|
|
||||||
|
def chTokenize(
|
||||||
|
texts,
|
||||||
|
show_progress: bool = True,
|
||||||
|
leave: bool = False,
|
||||||
|
) -> Union[List[List[str]], Tokenized]:
|
||||||
|
if isinstance(texts, str):
|
||||||
|
texts = [texts]
|
||||||
|
|
||||||
|
corpus_ids = []
|
||||||
|
token_to_index = {}
|
||||||
|
|
||||||
|
for text in tqdm(
|
||||||
|
texts, desc="Split strings", leave=leave, disable=not show_progress
|
||||||
|
):
|
||||||
|
|
||||||
|
splitted = chinese_tokenizer(text)
|
||||||
|
doc_ids = []
|
||||||
|
|
||||||
|
for token in splitted:
|
||||||
|
if token not in token_to_index:
|
||||||
|
token_to_index[token] = len(token_to_index)
|
||||||
|
|
||||||
|
token_id = token_to_index[token]
|
||||||
|
doc_ids.append(token_id)
|
||||||
|
|
||||||
|
corpus_ids.append(doc_ids)
|
||||||
|
|
||||||
|
return Tokenized(ids=corpus_ids, vocab=token_to_index)
|
||||||
|
|
||||||
Reference in New Issue
Block a user