"""Xinference rerank postprocessor that filters nodes by a relevance-score threshold."""

from typing import List, Optional

import requests
from llama_index.core.bridge.pydantic import Field
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.instrumentation import get_dispatcher
from llama_index.core.instrumentation.events.rerank import (
    ReRankEndEvent,
    ReRankStartEvent,
)
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle
from llama_index.postprocessor.xinference_rerank import XinferenceRerank

# Module-level instrumentation dispatcher used to emit rerank start/end events.
dispatcher = get_dispatcher(__name__)
|
|
|
|
class CustomXinFerenceRerank(XinferenceRerank):
    """Xinference reranker that drops nodes below a relevance-score threshold.

    Sends the query and the node texts to ``{base_url}/v1/rerank`` like the
    parent ``XinferenceRerank``. It keeps only nodes whose relevance score is
    strictly greater than ``score_threshold``. It returns at most ``top_n`` of
    them, ordered by descending score.
    """

    # Exclusive lower bound on the relevance score for a node to be kept.
    # (The schema description string below is Chinese for "score threshold".)
    score_threshold: float = Field(default=0.3, description="分数阈值")
    # Optional HTTP timeout for the rerank request; None keeps the previous
    # behaviour of waiting indefinitely.
    request_timeout: Optional[float] = Field(
        default=None,
        description=(
            "Timeout in seconds for the HTTP request to Xinference; "
            "None means no timeout."
        ),
    )

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query using the Xinference rerank API.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: The query to rank against; required.

        Returns:
            Nodes whose relevance score exceeds ``score_threshold``. They are
            sorted by descending score and truncated to ``top_n``. The result
            is an empty list when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is None.
            Exception: If the Xinference endpoint responds with a non-200 status.
        """
        dispatcher.event(
            ReRankStartEvent(
                query=query_bundle,
                nodes=nodes,
                top_n=self.top_n,
                model_name=self.model,
            )
        )
        if query_bundle is None:
            raise ValueError("Missing query bundle.")
        if not nodes:
            # Pair the ReRankStartEvent above so listeners always see an end event.
            dispatcher.event(ReRankEndEvent(nodes=[]))
            return []

        query_str = self.get_query_str(query_bundle)
        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            headers = {"Content-Type": "application/json"}
            json_data = {
                "model": self.model,
                "query": query_str,
                "documents": [
                    node.node.get_content(metadata_mode=MetadataMode.EMBED)
                    for node in nodes
                ],
            }
            response = requests.post(
                url=f"{self.base_url}/v1/rerank",
                headers=headers,
                json=json_data,
                timeout=self.request_timeout,
            )
            # Decode the response body as UTF-8 regardless of the server's charset header.
            response.encoding = "utf-8"
            if response.status_code != 200:
                raise Exception(
                    f"Xinference call failed with status code {response.status_code}. "
                    f"Details: {response.text}"
                )

            # Each result references the original node by its position in `nodes`.
            scored_nodes = (
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["relevance_score"],
                )
                for result in response.json()["results"]
            )
            # Always sort before truncating so the output order does not depend
            # on how many nodes passed the threshold.
            rerank_nodes = sorted(
                (n for n in scored_nodes if n.score > self.score_threshold),
                key=lambda n: n.score,
                reverse=True,
            )[: self.top_n]
            event.on_end(payload={EventPayload.NODES: rerank_nodes})

        dispatcher.event(ReRankEndEvent(nodes=rerank_nodes))
        return rerank_nodes