dev #5
@@ -13,6 +13,7 @@ from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReR
|
|||||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
|
from win32comext.shell.demos.IUniformResourceLocator import new_sh
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -165,17 +166,19 @@ class XinferenceRerank(BaseNodePostprocessor):
|
|||||||
self,
|
self,
|
||||||
model_uid: str,
|
model_uid: str,
|
||||||
endpoint: str,
|
endpoint: str,
|
||||||
top_n: int = 3,
|
top_n: int = None,
|
||||||
threshold: float = 0.3,
|
threshold: float = None,
|
||||||
return_documents: bool = False
|
return_documents: bool = False
|
||||||
):
|
):
|
||||||
_model_uid = model_uid
|
_model_uid = model_uid
|
||||||
_endpoint = endpoint
|
_endpoint = endpoint
|
||||||
|
_op_n = top_n
|
||||||
|
threshold = threshold
|
||||||
generator, model_description = self.load_model(
|
generator, model_description = self.load_model(
|
||||||
model_uid, endpoint
|
model_uid, endpoint
|
||||||
)
|
)
|
||||||
self._generator = generator
|
self._generator = generator
|
||||||
super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents)
|
super().__init__(top_n=top_n, model=model_uid, model_uid=model_uid, threshold = threshold, return_documents=return_documents)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def class_name(cls) -> str:
|
def class_name(cls) -> str:
|
||||||
@@ -216,9 +219,15 @@ class XinferenceRerank(BaseNodePostprocessor):
|
|||||||
new_node_with_score = NodeWithScore(
|
new_node_with_score = NodeWithScore(
|
||||||
node=nodes[result['index']].node, score=result['relevance_score']
|
node=nodes[result['index']].node, score=result['relevance_score']
|
||||||
)
|
)
|
||||||
|
if self.threshold is not None:
|
||||||
if new_node_with_score.score >=self.threshold:
|
if new_node_with_score.score >=self.threshold:
|
||||||
new_nodes.append(new_node_with_score)
|
new_nodes.append(new_node_with_score)
|
||||||
|
|
||||||
|
if self.top_n is not None:
|
||||||
|
if len(new_nodes) > self.top_n:
|
||||||
|
for index in new_nodes[5:-1]:
|
||||||
|
new_nodes.remove(-1)
|
||||||
|
|
||||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||||
|
|
||||||
dispatcher.event(
|
dispatcher.event(
|
||||||
|
|||||||
Reference in New Issue
Block a user