迁入项目
This commit is contained in:
@@ -0,0 +1,121 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, override
|
||||
|
||||
import httpx
|
||||
from agno.document import Document
|
||||
from agno.reranker.base import Reranker
|
||||
from agno.utils.log import logger
|
||||
import requests
|
||||
from openai import OpenAIError, Omit
|
||||
|
||||
|
||||
class CustomReranker(Reranker):
    """Rerank documents by POSTing them to an HTTP ``/v1/rerank`` endpoint.

    The request body is JSON carrying the query, the document texts and the
    tuning options declared below. The response is expected to be a JSON
    object with a ``results`` list whose items carry ``index`` (position in
    the submitted documents) and ``relevance_score``.
    """

    # Model identifier sent in the request payload.
    model: str = "BAAI/bge-reranker-v2-m3"
    # Bearer token; read from OPENAI_API_KEY when not passed to __init__.
    api_key: Optional[str] = None
    # Service root; read from OPENAI_BASE_URL when not passed to __init__.
    base_url: Optional[Union[str, httpx.URL]] = None
    # Maximum number of documents to return; None means no limit.
    top_n: Optional[int] = None
    # Forwarded verbatim to the service as "return_documents".
    return_documents: Optional[bool] = None
    # Forwarded verbatim to the service as "max_chunks_per_doc".
    max_chunks_per_doc: Optional[int] = 1024
    # Forwarded verbatim to the service as "overlap_tokens".
    overlap_tokens: Optional[int] = 80
    # Seconds to wait for the HTTP call before giving up.
    timeout: Optional[float] = 60.0

    def __init__(
        self,
        *,
        api_key: str | None = None,
        base_url: str | httpx.URL | None = None,
        model: str | None = None,
        top_n: int | None = None,
        return_documents: bool | None = None,
        max_chunks_per_doc: int | None = None,
        overlap_tokens: int | None = None,
        timeout: float | None = None,
    ):
        """Configure the reranker.

        Args:
            api_key: Bearer token. Falls back to the ``OPENAI_API_KEY``
                environment variable.
            base_url: Service root. Falls back to ``OPENAI_BASE_URL``, then
                to ``https://api.openai.com/v1``.
            model: Model identifier; the class default is kept when omitted
                or empty.
            top_n: Maximum number of documents to return.
            return_documents: Forwarded to the service.
            max_chunks_per_doc: Forwarded to the service.
            overlap_tokens: Forwarded to the service.
            timeout: HTTP timeout in seconds; ``None`` keeps the class
                default.

        Raises:
            OpenAIError: If no API key is passed and ``OPENAI_API_KEY`` is
                not set.
        """
        if api_key is None:
            api_key = os.environ.get("OPENAI_API_KEY")
        if api_key is None:
            raise OpenAIError(
                "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable"
            )

        if base_url is None:
            base_url = os.environ.get("OPENAI_BASE_URL")
        if base_url is None:
            base_url = "https://api.openai.com/v1"

        # Initialise the base class before assigning attributes.
        # NOTE(review): Reranker is not visible here; if it is a pydantic-style
        # model, initialising it afterwards would reset or reject the
        # assignments below. For a plain base class the order is irrelevant.
        super().__init__()

        self.api_key = api_key
        self.base_url = base_url
        if model:
            self.model = model
        # "is not None" so that explicit falsy values (e.g.
        # return_documents=False) are honoured instead of silently dropped.
        if return_documents is not None:
            self.return_documents = return_documents
        if top_n is not None:
            self.top_n = top_n
        if overlap_tokens is not None:
            self.overlap_tokens = overlap_tokens
        if max_chunks_per_doc is not None:
            self.max_chunks_per_doc = max_chunks_per_doc
        if timeout is not None:
            self.timeout = timeout

    @property
    @override
    def auth_headers(self) -> dict[str, str]:
        """Return the ``Authorization`` header built from ``api_key``."""
        return {"Authorization": f"Bearer {self.api_key}"}

    @property
    @override
    def default_headers(self) -> dict[str, str | Omit]:
        """Return the parent's default headers extended with client headers.

        NOTE(review): ``organization``, ``project``, ``_custom_headers`` and
        the parent's ``default_headers`` are not defined in this class, so
        they are read defensively; confirm whether this property is needed.
        """
        inherited = getattr(super(), "default_headers", {})
        organization = getattr(self, "organization", None)
        project = getattr(self, "project", None)
        return {
            **inherited,
            "X-Stainless-Async": "false",
            "OpenAI-Organization": organization if organization is not None else Omit(),
            "OpenAI-Project": project if project is not None else Omit(),
            **getattr(self, "_custom_headers", {}),
        }

    def _rerank_url(self) -> str:
        """Build the rerank endpoint URL from ``base_url``."""
        base = str(self.base_url).rstrip("/")
        # The default base URL already ends in "/v1"; avoid "/v1/v1/rerank".
        if not base.endswith("/v1"):
            base = f"{base}/v1"
        return f"{base}/rerank"

    def _rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Score ``documents`` against ``query`` and return them best-first.

        Each returned document has ``reranking_score`` set in place. Only
        documents listed in the service response are returned, truncated to
        ``top_n`` when it is a positive integer.

        Raises:
            requests.RequestException: On transport errors or a non-2xx status.
            ValueError: If the response body has no ``results`` list.
        """
        if not documents:
            return []

        top_n = self.top_n
        if top_n is not None and top_n <= 0:
            logger.warning(f"top_n should be a positive integer, got {self.top_n}, setting top_n to None")
            top_n = None

        options = {
            "model": self.model,
            "query": query,
            "documents": [doc.content for doc in documents],
            "top_n": top_n,
            "return_documents": self.return_documents,
            "max_chunks_per_doc": self.max_chunks_per_doc,
            "overlap_tokens": self.overlap_tokens,
        }
        # Leave unset options out rather than sending JSON nulls.
        payload = {key: value for key, value in options.items() if value is not None}
        headers = {**self.auth_headers, "Content-Type": "application/json"}

        response = requests.request(
            "POST",
            self._rerank_url(),
            json=payload,
            headers=headers,
            timeout=self.timeout,
        )
        response.raise_for_status()

        body = response.json()
        results = body.get("results") if isinstance(body, dict) else None
        if not isinstance(results, list):
            raise ValueError("Rerank response does not contain a 'results' list")

        reranked: list[Document] = []
        for item in results:
            doc = documents[item["index"]]
            doc.reranking_score = item["relevance_score"]
            reranked.append(doc)

        # Highest score first; documents without a score sink to the end.
        reranked.sort(
            key=lambda d: d.reranking_score if d.reranking_score is not None else float("-inf"),
            reverse=True,
        )

        if top_n:
            reranked = reranked[:top_n]
        return reranked

    def rerank(self, query: str, documents: List[Document]) -> List[Document]:
        """Rerank ``documents``; on any failure log it and return them unchanged."""
        try:
            return self._rerank(query=query, documents=documents)
        except Exception as e:
            # Deliberate best-effort: a reranking failure must not break retrieval.
            logger.error(f"Error reranking documents: {e}. Returning original documents")
            return documents
|
||||
Reference in New Issue
Block a user