Compare commits

..

4 Commits

2 changed files with 15 additions and 6 deletions
+14 -5
View File
@@ -13,6 +13,7 @@ from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReR
from llama_index.core.postprocessor.types import BaseNodePostprocessor from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
from pydantic import Field from pydantic import Field
from win32comext.shell.demos.IUniformResourceLocator import new_sh
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -165,17 +166,19 @@ class XinferenceRerank(BaseNodePostprocessor):
self, self,
model_uid: str, model_uid: str,
endpoint: str, endpoint: str,
top_n: int = 3, top_n: int = None,
threshold: float = 0.3, threshold: float = None,
return_documents: bool = False return_documents: bool = False
): ):
_model_uid = model_uid _model_uid = model_uid
_endpoint = endpoint _endpoint = endpoint
_op_n = top_n
threshold = threshold
generator, model_description = self.load_model( generator, model_description = self.load_model(
model_uid, endpoint model_uid, endpoint
) )
self._generator = generator self._generator = generator
super().__init__(top_n=top_n, model=model_uid, threshold = threshold, return_documents=return_documents) super().__init__(top_n=top_n, model=model_uid, model_uid=model_uid, threshold = threshold, return_documents=return_documents)
@classmethod @classmethod
def class_name(cls) -> str: def class_name(cls) -> str:
@@ -216,8 +219,14 @@ class XinferenceRerank(BaseNodePostprocessor):
new_node_with_score = NodeWithScore( new_node_with_score = NodeWithScore(
node=nodes[result['index']].node, score=result['relevance_score'] node=nodes[result['index']].node, score=result['relevance_score']
) )
if new_node_with_score.score >=self.threshold: if self.threshold is not None:
new_nodes.append(new_node_with_score) if new_node_with_score.score >=self.threshold:
new_nodes.append(new_node_with_score)
if self.top_n is not None:
if len(new_nodes) > self.top_n:
for index in new_nodes[5:-1]:
new_nodes.remove(-1)
event.on_end(payload={EventPayload.NODES: new_nodes}) event.on_end(payload={EventPayload.NODES: new_nodes})
+1 -1
View File
@@ -12,7 +12,7 @@ db:
queries: queries:
- select * from ProjectProperties limit 30; - select * from ProjectProperties limit 30;
- select Name, Code, Amount, Amount_Total from TotalCalculateTable - select Name, Code, Amount, Amount_Total from TotalCalculateTable
- select SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 1 limit 30; - select SerialNumber, Name, Quantity, Rate, Sum_Price from ProjectDivision where Level = 3 limit 50;
- select Name, Code, Rate, Amount from OtherFee - select Name, Code, Rate, Amount from OtherFee
#web: #web: