from typing import Any, List, Optional

from llama_index.core.postprocessor import SentenceTransformerRerank
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.bridge.pydantic import PrivateAttr


class OllamaRerank(SentenceTransformerRerank):
    """Cross-encoder reranker that also drops nodes below a score threshold.

    Behaves like ``SentenceTransformerRerank`` (every node is re-scored
    against the query and the best ``top_n`` are returned), with one extra
    step: nodes whose new score is below ``score_threshold`` are discarded
    before the top-n cut.

    NOTE: despite the class name, nothing in this class talks to Ollama;
    scoring is done through ``self._model.predict``.
    """

    # Minimum rerank score a node must reach to be kept.
    _score_threshold: float = PrivateAttr()

    def __init__(
        self,
        top_n: int = 2,
        model: str = "cross-encoder/stsb-distilroberta-base",
        device: Optional[str] = None,
        keep_retrieval_score: Optional[bool] = False,
        score_threshold: float = 0.3,
    ):
        """Create the reranker.

        Args:
            top_n: Maximum number of nodes returned after reranking.
            model: Name of the cross-encoder model handed to the parent class.
            device: Device string handed to the parent class (``None`` lets
                the parent decide).
            keep_retrieval_score: When true, the pre-rerank score is copied
                into each node's metadata under ``"retrieval_score"``.
            score_threshold: Nodes whose rerank score is below this value
                are dropped.
        """
        # Keyword arguments keep this call correct even if the parent
        # signature gains or reorders parameters.
        super().__init__(
            top_n=top_n,
            model=model,
            device=device,
            keep_retrieval_score=keep_retrieval_score,
        )
        # Assign private state only after the base model is initialised:
        # pydantic sets up private-attribute storage during __init__, so an
        # earlier assignment can fail or be discarded (notably on pydantic v2).
        self._score_threshold = score_threshold

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier string for this postprocessor class."""
        return "OllamaRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Re-score ``nodes`` against the query, filter, and keep the best.

        Each node's ``score`` is overwritten in place with the model's
        prediction. The ``nodes`` list itself is not modified.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is paired with every node.

        Returns:
            At most ``top_n`` nodes whose new score is at or above the
            threshold, ordered from highest to lowest score.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        query_and_nodes = [
            (
                query_bundle.query_str,
                node.node.get_content(metadata_mode=MetadataMode.EMBED),
            )
            for node in nodes
        ]

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            scores = self._model.predict(query_and_nodes)

            # Internal sanity check: one score per (query, node) pair.
            assert len(scores) == len(nodes)

            for node, score in zip(nodes, scores):
                if self.keep_retrieval_score:
                    # Preserve the original retrieval score before replacing it.
                    node.node.metadata["retrieval_score"] = node.score
                node.score = score

            # Build a new list instead of removing from ``nodes``: the input
            # list belongs to the caller and was already handed to the start
            # event payload above.
            kept = [
                node for node in nodes if node.score >= self._score_threshold
            ]

            # Highest score first; the sort is stable, so ties keep input order.
            new_nodes = sorted(kept, key=lambda n: n.score, reverse=True)[
                : self.top_n
            ]
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes