diff --git a/backend/.env.example b/backend/.env.example
index 855d9d7..9be4eb2 100644
--- a/backend/.env.example
+++ b/backend/.env.example
@@ -26,6 +26,7 @@ RERANK_ENABLED=true
 RERANK_PROVIDER=ollama
 RERANK_MODEL= /models/bge-reranker-base
 RERANK_TOP_N=5
+RERANK_THRESHOLD=0.3
 
 #---------- model - Xinference ----------------
 #MODEL_PROVIDER=xinference
diff --git a/backend/app/engine/rerank/ollamRerank.py b/backend/app/engine/rerank/ollamRerank.py
new file mode 100644
index 0000000..97f9ec0
--- /dev/null
+++ b/backend/app/engine/rerank/ollamRerank.py
@@ -0,0 +1,111 @@
+from typing import Any, List, Optional
+from llama_index.core.postprocessor import SentenceTransformerRerank
+from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
+from llama_index.core.callbacks import CBEventType, EventPayload
+from llama_index.core.bridge.pydantic import PrivateAttr
+
+
+class OllamaRerank(SentenceTransformerRerank):
+    """Cross-encoder reranker that also drops nodes scoring below a threshold.
+
+    After scoring, every node whose rerank score is below ``score_threshold``
+    is discarded before the ``top_n`` cut is applied.
+    """
+
+    _score_threshold: float = PrivateAttr()
+
+    def __init__(
+        self,
+        top_n: int = 2,
+        model: str = "cross-encoder/stsb-distilroberta-base",
+        device: Optional[str] = None,
+        keep_retrieval_score: Optional[bool] = False,
+        score_threshold: float = 0.3,
+    ):
+        """Create the reranker.
+
+        Args:
+            top_n: Maximum number of nodes returned after reranking.
+            model: Model identifier forwarded to the parent class.
+            device: Device forwarded to the parent class.
+            keep_retrieval_score: When true, each node's previous score is
+                saved in its metadata under ``"retrieval_score"``.
+            score_threshold: Nodes scoring below this value are dropped.
+        """
+        # Initialise the pydantic model first; assigning a private attribute
+        # before the base ``__init__`` has run is not reliable.
+        super().__init__(
+            top_n=top_n,
+            model=model,
+            device=device,
+            keep_retrieval_score=keep_retrieval_score,
+        )
+        self._score_threshold = score_threshold
+
+    @classmethod
+    def class_name(cls) -> str:
+        """Return the identifier used for this postprocessor class."""
+        return "OllamaRerank"
+
+    def _postprocess_nodes(
+        self,
+        nodes: List[NodeWithScore],
+        query_bundle: Optional[QueryBundle] = None,
+    ) -> List[NodeWithScore]:
+        """Rerank ``nodes`` against the query and drop low-scoring ones.
+
+        Args:
+            nodes: Candidate nodes; their ``score`` is overwritten in place.
+            query_bundle: Query whose ``query_str`` is paired with each node.
+
+        Returns:
+            At most ``top_n`` nodes at or above the threshold, best first.
+
+        Raises:
+            ValueError: if ``query_bundle`` is ``None``.
+        """
+        if query_bundle is None:
+            raise ValueError("Missing query bundle in extra info.")
+        if not nodes:
+            return []
+
+        query_and_nodes = [
+            (
+                query_bundle.query_str,
+                node.node.get_content(metadata_mode=MetadataMode.EMBED),
+            )
+            for node in nodes
+        ]
+
+        with self.callback_manager.event(
+            CBEventType.RERANKING,
+            payload={
+                EventPayload.NODES: nodes,
+                EventPayload.MODEL_NAME: self.model,
+                EventPayload.QUERY_STR: query_bundle.query_str,
+                EventPayload.TOP_K: self.top_n,
+            },
+        ) as event:
+            scores = self._model.predict(query_and_nodes)
+
+            assert len(scores) == len(nodes)
+
+            for node, score in zip(nodes, scores):
+                if self.keep_retrieval_score:
+                    node.node.metadata["retrieval_score"] = node.score
+                node.score = score
+
+            # Filter into a new list so the caller's ``nodes`` list is not
+            # mutated while it is being traversed.
+            kept = [
+                node
+                for node in nodes
+                if not node.score < self._score_threshold
+            ]
+
+            new_nodes = sorted(kept, key=lambda x: -x.score if x.score else 0)[
+                : self.top_n
+            ]
+            event.on_end(payload={EventPayload.NODES: new_nodes})
+
+        return new_nodes
diff --git a/backend/app/settings.py b/backend/app/settings.py
index 761a8c1..57f098c 100644
--- a/backend/app/settings.py
+++ b/backend/app/settings.py
@@ -88,7 +88,34 @@ class OllamaPlatform(ModelPlatform):
         pass
 
     def rerank(self):
-        pass
+        """Build the rerank postprocessor list from environment settings.
+
+        Reads ``RERANK_MODEL`` (a path appended to the current working
+        directory), ``RERANK_TOP_N`` and ``RERANK_THRESHOLD``.
+
+        Returns:
+            A single-element list holding the configured ``OllamaRerank``.
+
+        Raises:
+            ValueError: if ``RERANK_MODEL`` is unset or blank.
+        """
+        from app.engine.rerank.ollamRerank import OllamaRerank
+
+        model_rel = (os.getenv('RERANK_MODEL') or '').strip()
+        if not model_rel:
+            raise ValueError("RERANK_MODEL must be set to enable reranking.")
+        model_path = os.getcwd() + model_rel
+        # Environment values arrive as strings; convert them explicitly.
+        top_n = int(os.getenv('RERANK_TOP_N', 5))
+        threshold = float(os.getenv('RERANK_THRESHOLD', 0.3))
+        reranker = OllamaRerank(
+            model=model_path,
+            top_n=top_n,
+            device="cpu",
+            score_threshold=threshold,
+        )
+        return [reranker]
+
 
 @register(ModelPlateCategory,'xinference')
 class XinferencePlatform(ModelPlatform):