122 lines
4.3 KiB
Python
122 lines
4.3 KiB
Python
import os
|
|
from typing import Any, Dict, List, Optional, Union, override
|
|
|
|
import httpx
|
|
from agno.document import Document
|
|
from agno.reranker.base import Reranker
|
|
from agno.utils.log import logger
|
|
import requests
|
|
from openai import OpenAIError, Omit
|
|
|
|
|
|
class CustomReranker(Reranker):
    """Rerank documents by calling an HTTP ``/v1/rerank`` endpoint.

    Each rerank call is a JSON POST carrying the model name, the query, the
    document contents and whichever optional tuning fields below are set,
    authenticated with a ``Bearer`` token. The response is parsed as JSON and
    is expected to contain a ``results`` list whose items carry ``index``
    (position in the submitted documents) and ``relevance_score``.

    Attributes:
        model: Reranking model identifier sent to the service.
        api_key: Bearer token; falls back to the ``OPENAI_API_KEY`` env var.
        base_url: Service root; falls back to ``OPENAI_BASE_URL`` and then to
            ``https://api.openai.com/v1``. A trailing ``/v1`` is tolerated.
        top_n: Keep only the ``top_n`` best documents (``None`` keeps all).
        return_documents: Forwarded to the service when not ``None``.
        max_chunks_per_doc: Forwarded to the service when not ``None``.
        overlap_tokens: Forwarded to the service when not ``None``.
        timeout: Seconds to wait for the HTTP response (``None`` waits forever).
    """

    model: str = "BAAI/bge-reranker-v2-m3"
    api_key: Optional[str] = None
    base_url: Optional[Union[str, httpx.URL]] = None
    top_n: Optional[int] = None
    return_documents: Optional[bool] = None
    max_chunks_per_doc: Optional[int] = 1024
    overlap_tokens: Optional[int] = 80
    timeout: Optional[float] = 60.0

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        model: str | None = None,
        top_n: int | None = None,
        return_documents: bool | None = None,
        max_chunks_per_doc: int | None = None,
        overlap_tokens: int | None = None,
    ):
        """Resolve credentials and endpoint, then apply explicitly passed options.

        Options left as ``None`` keep the class-level defaults.

        Raises:
            OpenAIError: if no ``api_key`` is passed and ``OPENAI_API_KEY`` is unset.
        """
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise OpenAIError(
                "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
            )

        if base_url is None:
            base_url = os.environ.get("OPENAI_BASE_URL") or "https://api.openai.com/v1"

        # NOTE(review): Reranker is assumed to be a pydantic BaseModel (as in
        # agno). Its __init__ must run before any attribute assignment, and
        # running it afterwards would reset every field to its default.
        super().__init__()

        self.api_key = api_key
        self.base_url = base_url

        # Only override defaults for options that were actually passed, so
        # falsy-but-valid values such as False or 0 are honoured.
        if model:
            self.model = model
        if top_n is not None:
            self.top_n = top_n
        if return_documents is not None:
            self.return_documents = return_documents
        if max_chunks_per_doc is not None:
            self.max_chunks_per_doc = max_chunks_per_doc
        if overlap_tokens is not None:
            self.overlap_tokens = overlap_tokens

    @property
    def auth_headers(self) -> dict[str, str]:
        """Return the ``Authorization`` header for the configured API key."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @property
    def default_headers(self) -> dict[str, str]:
        """Return the full header set sent with every rerank request."""
        return {**self.auth_headers, "Content-Type": "application/json"}

    def _rerank_url(self) -> str:
        """Build the rerank endpoint URL without doubling ``/v1`` or slashes."""
        base = str(self.base_url).rstrip("/")
        if base.endswith("/v1"):
            return f"{base}/rerank"
        return f"{base}/v1/rerank"

    def _rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Score ``documents`` against ``query`` via the remote service.

        Sets ``reranking_score`` on each document the service returns and
        yields them sorted by descending score, truncated to ``top_n`` if set.

        Raises:
            requests.RequestException: on network errors or non-2xx responses.
            ValueError / KeyError / TypeError: if the response body is not the
                expected JSON shape.
        """
        if not documents:
            return []

        top_n = self.top_n
        if top_n is not None and top_n <= 0:
            logger.warning(f"top_n should be a positive integer, got {self.top_n}, setting top_n to None")
            top_n = None

        payload = {
            "model": self.model,
            "query": query,
            "documents": [doc.content for doc in documents],
            "top_n": top_n,
            "return_documents": self.return_documents,
            "max_chunks_per_doc": self.max_chunks_per_doc,
            "overlap_tokens": self.overlap_tokens,
        }
        # Omit unset options so the server applies its own defaults instead of
        # receiving explicit JSON nulls.
        payload = {key: value for key, value in payload.items() if value is not None}

        response = requests.post(
            self._rerank_url(),
            json=payload,
            headers=self.default_headers,
            timeout=self.timeout,
        )
        response.raise_for_status()
        logger.debug(f"Rerank response: {response.text}")
        data = response.json()

        compressed_docs: List[Document] = []
        for result in data["results"]:
            index = result["index"]
            # Reject out-of-range (including negative) indices rather than
            # letting Python's negative indexing pick the wrong document.
            if not isinstance(index, int) or not 0 <= index < len(documents):
                logger.warning(f"Ignoring rerank result with invalid index: {index!r}")
                continue
            doc = documents[index]
            doc.reranking_score = result.get("relevance_score")
            compressed_docs.append(doc)

        # Order by relevance score; documents without a score sink to the end.
        compressed_docs.sort(
            key=lambda x: x.reranking_score if x.reranking_score is not None else float("-inf"),
            reverse=True,
        )

        if top_n:
            compressed_docs = compressed_docs[:top_n]

        return compressed_docs

    def rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Rerank ``documents`` for ``query``, falling back to the input on failure.

        This is deliberately best-effort: any error (network, HTTP status,
        malformed response) is logged and the original ``documents`` list is
        returned as-is.
        """
        try:
            return self._rerank(query=query, documents=documents)
        except Exception as e:
            logger.error(f"Error reranking documents: {e}. Returning original documents")
            return documents
|