from typing import List, Optional

import requests
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
    ReRankEndEvent,
    ReRankStartEvent,
)
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.postprocessor.xinference_rerank import XinferenceRerank

dispatcher = get_dispatcher(__name__)


class CustomXinFerenceRerank(XinferenceRerank):
    """Xinference reranker that additionally filters results by score.

    Behaves like a rerank postprocessor backed by the ``/v1/rerank`` endpoint
    at ``base_url``, but keeps only nodes whose relevance score is strictly
    greater than ``score_threshold`` and returns at most ``top_n`` of them,
    ordered from highest to lowest score.

    ``top_n``, ``model`` and ``base_url`` are read from the instance; they are
    presumably declared on ``XinferenceRerank`` (not visible in this file).
    """

    # Score threshold: nodes scoring at or below this value are discarded.
    score_threshold: float = Field(default=0.3, description="分数阈值")

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and drop low-scoring ones.

        Args:
            nodes: Candidate nodes; their content (in ``MetadataMode.EMBED``)
                is sent to the rerank endpoint as the documents.
            query_bundle: The query to rank against. Required.

        Returns:
            Nodes whose relevance score exceeds ``score_threshold``, sorted by
            score in descending order and truncated to ``top_n`` entries. An
            empty list is returned when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
            Exception: If the rerank endpoint answers with a non-200 status.
        """
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )

        if query_bundle is None:
            raise ValueError("Missing query bundle.")
        if not nodes:
            return []

        query_str = self.get_query_str(query_bundle)
        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            headers = {"Content-Type": "application/json"}
            json_data = {
                "model": self.model,
                "query": query_str,
                "documents": [
                    node.node.get_content(metadata_mode=MetadataMode.EMBED)
                    for node in nodes
                ],
            }
            # NOTE(review): no timeout is passed, so this call can block
            # indefinitely if the server never answers. Consider making a
            # timeout configurable.
            response = requests.post(
                url=f"{self.base_url}/v1/rerank", headers=headers, json=json_data
            )
            response.encoding = "utf-8"
            if response.status_code != 200:
                raise Exception(
                    f"Xinference call failed with status code {response.status_code}."
                    f"Details: {response.text}"
                )

            # Each result refers back to its input document by position.
            scored_nodes = [
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["relevance_score"],
                )
                for result in response.json()["results"]
            ]
            rerank_nodes = [
                node for node in scored_nodes if node.score > self.score_threshold
            ]
            # Always order by score (best first) before truncating, so the
            # output ordering does not depend on how many nodes passed the
            # threshold. The sort is stable, so ties keep the server's order.
            rerank_nodes.sort(key=lambda node: node.score, reverse=True)
            rerank_nodes = rerank_nodes[: self.top_n]

            event.on_end(payload={EventPayload.NODES: rerank_nodes})

        dispatcher.event(ReRankEndEvent(nodes=rerank_nodes))
        return rerank_nodes