新增关键字检索类
This commit is contained in:
@@ -0,0 +1,133 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, cast
|
||||
|
||||
from llama_index.core.base.base_retriever import BaseRetriever
|
||||
from llama_index.core.callbacks.base import CallbackManager
|
||||
from llama_index.core.constants import DEFAULT_SIMILARITY_TOP_K
|
||||
from llama_index.core.indices.vector_store.base import VectorStoreIndex
|
||||
from llama_index.core.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle
|
||||
from llama_index.core.storage.docstore.types import BaseDocumentStore
|
||||
from llama_index.core.vector_stores.utils import (
|
||||
node_to_metadata_dict,
|
||||
metadata_dict_to_node,
|
||||
)
|
||||
|
||||
import bm25s
|
||||
from app.engine.retriever.CHTokener import chTokenize
|
||||
|
||||
CHDEFAULT_PERSIST_ARGS = {"similarity_top_k": "similarity_top_k", "_verbose": "verbose"}
|
||||
|
||||
CHDEFAULT_PERSIST_FILENAME = "retriever.json"
|
||||
|
||||
class CHBM25Retriever(BaseRetriever):
|
||||
def __init__(
|
||||
self,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
existing_bm25: Optional[bm25s.BM25] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
callback_manager: Optional[CallbackManager] = None,
|
||||
objects: Optional[List[IndexNode]] = None,
|
||||
object_map: Optional[dict] = None,
|
||||
verbose: bool = False,
|
||||
) -> None:
|
||||
self.similarity_top_k = similarity_top_k
|
||||
if existing_bm25 is not None:
|
||||
self.bm25 = existing_bm25
|
||||
self.corpus = existing_bm25.corpus
|
||||
else:
|
||||
from nltk.corpus import stopwords
|
||||
if nodes is None:
|
||||
raise ValueError("Please pass nodes or an existing BM25 object.")
|
||||
|
||||
self.corpus = [node_to_metadata_dict(node) for node in nodes]
|
||||
|
||||
corpus_tokens = chTokenize(
|
||||
[node.get_content() for node in nodes],
|
||||
show_progress=verbose,
|
||||
)
|
||||
self.bm25 = bm25s.BM25()
|
||||
self.bm25.index(corpus_tokens, show_progress=verbose)
|
||||
super().__init__(
|
||||
callback_manager=callback_manager,
|
||||
object_map=object_map,
|
||||
objects=objects,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_defaults(
|
||||
cls,
|
||||
index: Optional[VectorStoreIndex] = None,
|
||||
nodes: Optional[List[BaseNode]] = None,
|
||||
docstore: Optional[BaseDocumentStore] = None,
|
||||
similarity_top_k: int = DEFAULT_SIMILARITY_TOP_K,
|
||||
verbose: bool = False,
|
||||
) -> "CHBM25Retriever":
|
||||
if sum(bool(val) for val in [index, nodes, docstore]) != 1:
|
||||
raise ValueError("Please pass exactly one of index, nodes, or docstore.")
|
||||
|
||||
if index is not None:
|
||||
docstore = index.docstore
|
||||
|
||||
if docstore is not None:
|
||||
nodes = cast(List[BaseNode], list(docstore.docs.values()))
|
||||
|
||||
assert (
|
||||
nodes is not None
|
||||
), "Please pass exactly one of index, nodes, or docstore."
|
||||
|
||||
return cls(
|
||||
nodes=nodes,
|
||||
similarity_top_k=similarity_top_k,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
def get_persist_args(self) -> Dict[str, Any]:
|
||||
"""Get Persist Args Dict to Save."""
|
||||
return {
|
||||
CHDEFAULT_PERSIST_ARGS[key]: getattr(self, key)
|
||||
for key in CHDEFAULT_PERSIST_ARGS
|
||||
if hasattr(self, key)
|
||||
}
|
||||
|
||||
def persist(self, path: str, **kwargs: Any) -> None:
|
||||
"""Persist the retriever to a directory."""
|
||||
self.bm25.save(path, corpus=self.corpus, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME), "w") as f:
|
||||
json.dump(self.get_persist_args(), f, indent=2)
|
||||
|
||||
@classmethod
|
||||
def from_persist_dir(cls, path: str, **kwargs: Any) -> "CHBM25Retriever":
|
||||
"""Load the retriever from a directory."""
|
||||
bm25 = bm25s.BM25.load(path, load_corpus=True, **kwargs)
|
||||
with open(os.path.join(path, CHDEFAULT_PERSIST_FILENAME)) as f:
|
||||
retriever_data = json.load(f)
|
||||
return cls(existing_bm25=bm25, **retriever_data)
|
||||
|
||||
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
|
||||
query = query_bundle.query_str
|
||||
tokenized_query = chTokenize(
|
||||
query,show_progress=self._verbose
|
||||
)
|
||||
indexes, scores = self.bm25.retrieve(
|
||||
tokenized_query, k=self.similarity_top_k, show_progress=self._verbose
|
||||
)
|
||||
|
||||
# batched, but only one query
|
||||
indexes = indexes[0]
|
||||
scores = scores[0]
|
||||
|
||||
nodes: List[NodeWithScore] = []
|
||||
for idx, score in zip(indexes, scores):
|
||||
# idx can be an int or a dict of the node
|
||||
if isinstance(idx, dict):
|
||||
node = metadata_dict_to_node(idx)
|
||||
else:
|
||||
node_dict = self.corpus[int(idx)]
|
||||
node = metadata_dict_to_node(node_dict)
|
||||
nodes.append(NodeWithScore(node=node, score=float(score)))
|
||||
|
||||
return nodes
|
||||
@@ -0,0 +1,46 @@
|
||||
from typing import Any, Dict, List, Union, Callable, NamedTuple
|
||||
from bm25s.tokenization import *
|
||||
|
||||
try:
|
||||
from tqdm.auto import tqdm
|
||||
except ImportError:
|
||||
|
||||
def tqdm(iterable, *args, **kwargs):
|
||||
return iterable
|
||||
|
||||
|
||||
def chinese_tokenizer(text: str) -> List[str]:
|
||||
import jieba
|
||||
from nltk.corpus import stopwords
|
||||
tokens = jieba.lcut(text)
|
||||
return [token for token in tokens if token not in stopwords.words('chinese')]
|
||||
|
||||
def chTokenize(
|
||||
texts,
|
||||
show_progress: bool = True,
|
||||
leave: bool = False,
|
||||
) -> Union[List[List[str]], Tokenized]:
|
||||
if isinstance(texts, str):
|
||||
texts = [texts]
|
||||
|
||||
corpus_ids = []
|
||||
token_to_index = {}
|
||||
|
||||
for text in tqdm(
|
||||
texts, desc="Split strings", leave=leave, disable=not show_progress
|
||||
):
|
||||
|
||||
splitted = chinese_tokenizer(text)
|
||||
doc_ids = []
|
||||
|
||||
for token in splitted:
|
||||
if token not in token_to_index:
|
||||
token_to_index[token] = len(token_to_index)
|
||||
|
||||
token_id = token_to_index[token]
|
||||
doc_ids.append(token_id)
|
||||
|
||||
corpus_ids.append(doc_ids)
|
||||
|
||||
return Tokenized(ids=corpus_ids, vocab=token_to_index)
|
||||
|
||||
Reference in New Issue
Block a user