# Listing metadata (not part of the program): 2025-04-08 11:38:01 +08:00, 122 lines, 4.3 KiB, Python.

import os
from typing import Any, Dict, List, Optional, Union, override
import httpx
from agno.document import Document
from agno.reranker.base import Reranker
from agno.utils.log import logger
import requests
from openai import OpenAIError, Omit
class CustomReranker(Reranker):
    """Reranker that scores documents by POSTing to an HTTP ``/v1/rerank`` endpoint.

    The request carries a bearer token and a JSON payload containing the
    model name, the query and the document texts. Documents are returned
    sorted by the relevance score reported by the service.
    """

    # Name of the reranking model sent in the request payload.
    model: str = "BAAI/bge-reranker-v2-m3"
    # Bearer token; resolved from OPENAI_API_KEY when not passed explicitly.
    api_key: Optional[str] = None
    # Service root; resolved from OPENAI_BASE_URL when not passed explicitly.
    base_url: Optional[Union[str, httpx.URL]] = None
    # Maximum number of documents to keep after reranking (None keeps all).
    top_n: Optional[int] = None
    # Forwarded verbatim to the service as "return_documents".
    return_documents: Optional[bool] = None
    # Forwarded verbatim to the service as "max_chunks_per_doc".
    max_chunks_per_doc: Optional[int] = 1024
    # Forwarded verbatim to the service as "overlap_tokens".
    overlap_tokens: Optional[int] = 80

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        model: str | None = None,
        top_n: int | None = None,
        return_documents: bool | None = None,
        max_chunks_per_doc: int | None = None,
        overlap_tokens: int | None = None,
    ):
        """Resolve credentials and endpoint, then initialise the fields.

        Args:
            api_key: Bearer token. Falls back to the ``OPENAI_API_KEY``
                environment variable.
            base_url: Service root. Falls back to ``OPENAI_BASE_URL`` and
                then to ``https://api.openai.com/v1``.
            model: Overrides the class default model name.
            top_n: Overrides the class default ``top_n``.
            return_documents: Overrides the class default.
            max_chunks_per_doc: Overrides the class default.
            overlap_tokens: Overrides the class default.

        Raises:
            OpenAIError: If no API key is passed and ``OPENAI_API_KEY`` is
                not set.
        """
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise OpenAIError(
                "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
            )

        if base_url is None:
            base_url = os.environ.get("OPENAI_BASE_URL")
        if base_url is None:
            base_url = "https://api.openai.com/v1"

        # Only explicitly supplied values override the declared field
        # defaults. Testing against None (rather than truthiness) keeps
        # deliberate falsy values such as return_documents=False.
        overrides: dict[str, Any] = {
            "model": model,
            "top_n": top_n,
            "return_documents": return_documents,
            "max_chunks_per_doc": max_chunks_per_doc,
            "overlap_tokens": overlap_tokens,
        }
        supplied = {name: value for name, value in overrides.items() if value is not None}

        # NOTE(review): agno's Reranker is a pydantic BaseModel, so the values
        # are handed to the base initialiser instead of being assigned on
        # ``self`` beforehand (attribute assignment before BaseModel.__init__
        # is not supported, and a bare super().__init__() would reset them).
        super().__init__(api_key=api_key, base_url=base_url, **supplied)

    @property
    @override
    def auth_headers(self) -> dict[str, str]:
        """Return the ``Authorization`` header built from ``api_key``."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @property
    @override
    def default_headers(self) -> dict[str, str | Omit]:
        """Return default request headers in the style of the OpenAI client.

        This class does not itself define ``organization``, ``project`` or
        ``_custom_headers``, nor an inherited ``default_headers``; each is
        looked up defensively so the property works whether or not a base
        class or caller provides them.
        """
        inherited = getattr(super(), "default_headers", {})
        organization = getattr(self, "organization", None)
        project = getattr(self, "project", None)
        custom_headers = getattr(self, "_custom_headers", None) or {}
        return {
            **inherited,
            "X-Stainless-Async": "false",
            "OpenAI-Organization": organization if organization is not None else Omit(),
            "OpenAI-Project": project if project is not None else Omit(),
            **custom_headers,
        }

    def _rerank_url(self) -> str:
        """Build the rerank endpoint URL from ``base_url``.

        A base URL that already ends in ``/v1`` (such as the built-in
        default) is not given a second ``/v1`` segment.
        """
        base = str(self.base_url).rstrip("/")
        if base.endswith("/v1"):
            return f"{base}/rerank"
        return f"{base}/v1/rerank"

    def _rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Call the rerank endpoint and return documents ordered by score.

        Args:
            query: The search query to score documents against.
            documents: Candidate documents; their ``content`` is sent to the
                service.

        Returns:
            The documents referenced in the response, each with
            ``reranking_score`` set, sorted by descending score and
            truncated to ``top_n`` when it is a positive integer.

        Raises:
            requests.HTTPError: If the service answers with an error status.
            KeyError, ValueError: If the response body is not the expected
                JSON shape.
        """
        if not documents:
            return []

        top_n = self.top_n
        if top_n is not None and top_n <= 0:
            logger.warning(f"top_n should be a positive integer, got {self.top_n}, setting top_n to None")
            top_n = None

        payload: dict[str, Any] = {
            "model": self.model,
            "query": query,
            "documents": [doc.content for doc in documents],
        }
        # Unset optional parameters are left out so the service applies its
        # own defaults instead of receiving explicit JSON nulls.
        optional_params = {
            "top_n": top_n,
            "return_documents": self.return_documents,
            "max_chunks_per_doc": self.max_chunks_per_doc,
            "overlap_tokens": self.overlap_tokens,
        }
        payload.update({name: value for name, value in optional_params.items() if value is not None})

        headers = {**self.auth_headers, "Content-Type": "application/json"}

        # NOTE(review): no timeout is set, so a stalled server blocks the
        # caller indefinitely — consider passing timeout=... here.
        response = requests.post(self._rerank_url(), json=payload, headers=headers)
        response.raise_for_status()
        logger.debug(f"Rerank response: {response.text}")

        # NOTE(review): assumes a JSON body shaped like
        # {"results": [{"index": int, "relevance_score": float}, ...]};
        # confirm against the provider's API reference.
        results = response.json()["results"]

        reranked: list[Document] = []
        for item in results:
            doc = documents[item["index"]]
            doc.reranking_score = item["relevance_score"]
            reranked.append(doc)

        # Highest score first; documents without a score sink to the end.
        reranked.sort(
            key=lambda d: d.reranking_score if d.reranking_score is not None else float("-inf"),
            reverse=True,
        )

        if top_n is not None:
            reranked = reranked[:top_n]
        return reranked

    def rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Rerank ``documents`` for ``query``, falling back to the input on error.

        Any exception raised while calling the service or parsing its
        response is logged and the original, unmodified list is returned, so
        a reranker outage degrades results rather than failing the caller.
        """
        try:
            return self._rerank(query=query, documents=documents)
        except Exception as e:
            logger.error(f"Error reranking documents: {e}. Returning original documents")
            return documents