diff --git a/backend/app/xinference/base.py b/backend/app/xinference/base.py index 54cb16d..3a91374 100644 --- a/backend/app/xinference/base.py +++ b/backend/app/xinference/base.py @@ -13,6 +13,7 @@ from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReR from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle from pydantic import Field +from win32comext.shell.demos.IUniformResourceLocator import new_sh logger = logging.getLogger(__name__) @@ -165,17 +166,19 @@ class XinferenceRerank(BaseNodePostprocessor): self, model_uid: str, endpoint: str, - top_n: int = 3, - threshold: float = 0.3, + top_n: int = None, + threshold: float = None, return_documents: bool = False ): _model_uid = model_uid _endpoint = endpoint + _op_n = top_n + threshold = threshold generator, model_description = self.load_model( model_uid, endpoint ) self._generator = generator - super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents) + super().__init__(top_n=top_n, model=model_uid, model_uid=model_uid, threshold = threshold, return_documents=return_documents) @classmethod def class_name(cls) -> str: @@ -216,8 +219,14 @@ class XinferenceRerank(BaseNodePostprocessor): new_node_with_score = NodeWithScore( node=nodes[result['index']].node, score=result['relevance_score'] ) - if new_node_with_score.score >=self.threshold: - new_nodes.append(new_node_with_score) + if self.threshold is not None: + if new_node_with_score.score >=self.threshold: + new_nodes.append(new_node_with_score) + + if self.top_n is not None: + if len(new_nodes) > self.top_n: + for index in new_nodes[5:-1]: + new_nodes.remove(-1) event.on_end(payload={EventPayload.NODES: new_nodes})