From bbe3fd0b0b6b99435547db1aae1fa0c6878c32af Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Mon, 19 Aug 2024 15:39:28 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9XinferenceRerank=E7=B1=BB?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E6=9C=80=E5=A4=9AN=E6=9D=A1=E7=9F=A5?= =?UTF-8?q?=E8=AF=86=E5=92=8C=E6=9C=80=E5=B0=8F=E5=8C=B9=E9=85=8D=E5=BA=A6?= =?UTF-8?q?=E8=BF=87=E6=BB=A4=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/xinference/base.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/backend/app/xinference/base.py b/backend/app/xinference/base.py index 54cb16d..3a91374 100644 --- a/backend/app/xinference/base.py +++ b/backend/app/xinference/base.py @@ -13,6 +13,7 @@ from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReR from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle from pydantic import Field +from win32comext.shell.demos.IUniformResourceLocator import new_sh logger = logging.getLogger(__name__) @@ -165,17 +166,19 @@ class XinferenceRerank(BaseNodePostprocessor): self, model_uid: str, endpoint: str, - top_n: int = 3, - threshold: float = 0.3, + top_n: int = None, + threshold: float = None, return_documents: bool = False ): _model_uid = model_uid _endpoint = endpoint + _op_n = top_n + threshold = threshold generator, model_description = self.load_model( model_uid, endpoint ) self._generator = generator - super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents) + super().__init__(top_n=top_n, model=model_uid, model_uid=model_uid, threshold = threshold, return_documents=return_documents) @classmethod def class_name(cls) -> str: @@ -216,8 +219,14 @@ class XinferenceRerank(BaseNodePostprocessor): new_node_with_score = NodeWithScore( node=nodes[result['index']].node, score=result['relevance_score'] ) - if new_node_with_score.score >=self.threshold: - new_nodes.append(new_node_with_score) + if self.threshold is not None: + if new_node_with_score.score >=self.threshold: + new_nodes.append(new_node_with_score) + + if self.top_n is not None: + if len(new_nodes) > self.top_n: + for index in new_nodes[5:-1]: + new_nodes.remove(-1) event.on_end(payload={EventPayload.NODES: new_nodes})