import os
from typing import Any, Dict, List, Optional, Union, override

import httpx
import requests
from agno.document import Document
from agno.reranker.base import Reranker
from agno.utils.log import logger
from openai import OpenAIError, Omit

# Upper bound (in seconds) for the rerank HTTP call, so that an unresponsive
# endpoint cannot block the caller forever.
_REQUEST_TIMEOUT_SECONDS = 60


class CustomReranker(Reranker):
    """Reranker that scores documents through an HTTP ``/v1/rerank`` endpoint.

    The request is a JSON POST authenticated with a bearer token. The response
    is expected to carry a ``results`` list whose items expose ``index`` (the
    position of the document in the request) and ``relevance_score``.
    """

    model: str = "BAAI/bge-reranker-v2-m3"
    api_key: Optional[str] = None
    base_url: Optional[Union[str, httpx.URL]] = None
    top_n: Optional[int] = None
    return_documents: Optional[bool] = None
    max_chunks_per_doc: Optional[int] = 1024
    overlap_tokens: Optional[int] = 80

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        model: str | None = None,
        top_n: int | None = None,
        return_documents: bool | None = None,
        max_chunks_per_doc: int | None = None,
        overlap_tokens: int | None = None,
    ):
        """Configure the reranker.

        Args:
            api_key: Bearer token; falls back to the ``OPENAI_API_KEY``
                environment variable.
            base_url: Service root; falls back to ``OPENAI_BASE_URL`` and then
                to ``https://api.openai.com/v1``.
            model: Reranking model name; the class default is used when empty.
            top_n: Maximum number of documents to return.
            return_documents: Forwarded to the service as ``return_documents``.
            max_chunks_per_doc: Forwarded to the service as
                ``max_chunks_per_doc``.
            overlap_tokens: Forwarded to the service as ``overlap_tokens``.

        Raises:
            OpenAIError: If no API key is passed and ``OPENAI_API_KEY`` is
                not set.
        """
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise OpenAIError(
                "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
            )

        if base_url is None:
            base_url = os.environ.get("OPENAI_BASE_URL")
        if base_url is None:
            base_url = "https://api.openai.com/v1"

        # Initialise the base class before touching instance attributes.
        # NOTE(review): the base class is not visible here; if it is a
        # pydantic-style model, attribute access before this call fails.
        super().__init__()

        self.api_key = api_key
        self.base_url = base_url
        self.model = model or self.model
        # Compare against None so explicit falsy values (False, 0) are kept
        # instead of being silently replaced by the class defaults.
        if return_documents is not None:
            self.return_documents = return_documents
        if top_n is not None:
            self.top_n = top_n
        if overlap_tokens is not None:
            self.overlap_tokens = overlap_tokens
        if max_chunks_per_doc is not None:
            self.max_chunks_per_doc = max_chunks_per_doc

    @property
    @override
    def auth_headers(self) -> dict[str, str]:
        """Return the bearer-token authorization header."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @property
    @override
    def default_headers(self) -> dict[str, str | Omit]:
        """Return the default request headers.

        NOTE(review): this property relies on ``default_headers``,
        ``organization``, ``project`` and ``_custom_headers``, none of which
        are defined in this class. Each lookup is guarded so the property
        does not raise when the base class lacks them — confirm whether the
        property is needed at all.
        """
        inherited = getattr(super(), "default_headers", None) or {}
        organization = getattr(self, "organization", None)
        project = getattr(self, "project", None)
        custom_headers = getattr(self, "_custom_headers", None) or {}
        return {
            **inherited,
            "X-Stainless-Async": "false",
            "OpenAI-Organization": organization if organization is not None else Omit(),
            "OpenAI-Project": project if project is not None else Omit(),
            **custom_headers,
        }

    def _rerank_url(self) -> str:
        """Build the rerank endpoint URL without duplicating the ``/v1`` segment."""
        base = str(self.base_url).rstrip("/")
        if base.endswith("/v1"):
            return f"{base}/rerank"
        return f"{base}/v1/rerank"

    def _rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Score ``documents`` against ``query`` and return them best-first.

        Each returned document has its ``reranking_score`` set in place.

        Raises:
            requests.RequestException: On transport errors, timeouts or a
                non-2xx status.
            ValueError: If the response body has no ``results`` list.
        """
        if not documents:
            return []

        top_n = self.top_n
        if top_n is not None and top_n <= 0:
            logger.warning(f"top_n should be a positive integer, got {self.top_n}, setting top_n to None")
            top_n = None

        candidate_fields = {
            "model": self.model,
            "query": query,
            "documents": [doc.content for doc in documents],
            "top_n": top_n,
            "return_documents": self.return_documents,
            "max_chunks_per_doc": self.max_chunks_per_doc,
            "overlap_tokens": self.overlap_tokens,
        }
        # Leave unset options out of the request instead of sending JSON null.
        payload = {key: value for key, value in candidate_fields.items() if value is not None}
        headers = {**self.auth_headers, "Content-Type": "application/json"}

        response = requests.request(
            "POST",
            self._rerank_url(),
            json=payload,
            headers=headers,
            timeout=_REQUEST_TIMEOUT_SECONDS,
        )
        logger.debug(f"Rerank response: {response.text}")
        response.raise_for_status()

        body = response.json()
        results = body.get("results") if isinstance(body, dict) else None
        if results is None:
            # Raising lets rerank() fall back to the caller's original list
            # rather than returning an empty one.
            raise ValueError("Rerank response does not contain a 'results' list")

        reranked: list[Document] = []
        for item in results:
            doc = documents[item["index"]]
            doc.reranking_score = item["relevance_score"]
            reranked.append(doc)

        # Highest score first; documents without a score sort last.
        reranked.sort(
            key=lambda d: d.reranking_score if d.reranking_score is not None else float("-inf"),
            reverse=True,
        )

        if top_n:
            reranked = reranked[:top_n]
        return reranked

    def rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Rerank ``documents`` for ``query``; on any failure return them unchanged.

        This is a deliberate best-effort boundary: errors are logged, not raised.
        """
        try:
            return self._rerank(query=query, documents=documents)
        except Exception as e:
            logger.error(f"Error reranking documents: {e}. Returning original documents")
            return documents