"""Xinference embedding and rerank integrations for LlamaIndex."""

import logging
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union, Tuple

from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.callbacks import CBEventType, EventPayload
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
from pydantic import Field

logger = logging.getLogger(__name__)

EMBED_MAX_INPUT_LENGTH = 2048
EMBED_MAX_BATCH_SIZE = 1


def _load_xinference_model(
    model_uid: str, endpoint: str
) -> Tuple[Any, Dict[str, Any]]:
    """Connect to an Xinference server and fetch an already-launched model.

    Args:
        model_uid: UID of the model on the Xinference server.
        endpoint: Base URL of the Xinference server.

    Returns:
        ``(generator, model_description)``: the handle returned by
        ``RESTfulClient.get_model(model_uid)`` and this model's entry in
        ``RESTfulClient.list_models()``.

    Raises:
        ImportError: if the ``xinference`` package is not installed.
        RuntimeError: if the model handle or its description is unavailable.
    """
    try:
        from xinference.client import RESTfulClient
    except ImportError as err:
        raise ImportError(
            "Could not import Xinference library. "
            'Please install Xinference with `pip install "xinference[all]"`'
        ) from err

    client = RESTfulClient(endpoint)
    generator = client.get_model(model_uid)
    # `.get` so an unknown UID reaches the explicit error below rather than
    # escaping as a bare KeyError.
    model_description = client.list_models().get(model_uid)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if generator is None or model_description is None:
        raise RuntimeError(
            "Could not get model from endpoint. "
            "Please make sure Xinference endpoint is running at the correct port."
        )
    return generator, model_description


class XinferenceEmbedding(BaseEmbedding):
    """Text embeddings computed by a model running on an Xinference server.

    The model must already be launched on the server. It is looked up by
    ``model_uid`` at construction time, and each text is embedded with a
    single ``create_embedding`` call.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )

    # Underscore names cannot be pydantic `Field`s (pydantic v2 raises a
    # NameError); they are private attributes assigned after `super().__init__`.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        """Connect to ``endpoint`` and bind the model ``model_uid``.

        Args:
            model_uid: UID of a launched Xinference embedding model.
            endpoint: Base URL of the Xinference server.
            embed_batch_size: Batch size used by the base class; the caller's
                value is honoured (it was previously overwritten by the
                model's replica count).
            dimensions: Embedding size; defaults to the model's reported
                ``dimensions`` when not given.
            additional_kwargs, api_key, api_base, api_version, max_retries:
                Forwarded unchanged to the base-class constructor.
            **kwargs: Extra keyword arguments for the base class.
        """
        generator, model_description, _replica, model_dimensions = self.load_model(
            model_uid, endpoint
        )
        if dimensions is None:
            dimensions = model_dimensions
        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            model_name=model_uid,
            model_description=model_description,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            **kwargs,
        )
        # Private attributes can only be set once pydantic has initialised
        # the instance.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(
        self, model_uid: str, endpoint: str
    ) -> Tuple[Any, Dict[str, Any], int, int]:
        """Fetch the model handle plus its replica count and embedding size.

        Returns:
            ``(generator, model_description, replica, dimensions)``, with the
            last two read from the model description.

        Raises:
            ImportError: if ``xinference`` is not installed.
            RuntimeError: if the model cannot be obtained from the endpoint.
            KeyError: if the description lacks ``replica`` or ``dimensions``.
        """
        generator, model_description = _load_xinference_model(model_uid, endpoint)
        replica = model_description["replica"]
        dimensions = model_description["dimensions"]
        return generator, model_description, replica, dimensions

    @classmethod
    def class_name(cls) -> str:
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed one text synchronously via ``create_embedding``.

        Returns ``response["data"][0]["embedding"]`` from the server reply.
        """
        if self._generator is None:
            raise RuntimeError("Xinference model is not loaded.")
        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a query synchronously; queries are embedded like any text."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed a query; delegates to the synchronous implementation."""
        return self._get_query_embedding(query)


class XinferenceRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with an Xinference rerank model.

    Nodes whose relevance score is below ``threshold`` are dropped, and at
    most ``top_n`` of the best-scoring remaining nodes are returned.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    model: str = Field(description="The Xinference rerank model UID.")
    top_n: int = Field(description="Top N nodes to return.")
    threshold: float = Field(
        description="Minimum relevance score a node needs to be returned."
    )
    return_documents: bool = Field(
        default=False,
        description="Stored for API compatibility; not sent to the rerank call.",
    )

    # Underscore names cannot be pydantic `Field`s; they are private
    # attributes assigned after `super().__init__`.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        top_n: int = 3,
        threshold: float = 0.3,
        return_documents: bool = False,
    ) -> None:
        """Connect to ``endpoint`` and bind the rerank model ``model_uid``."""
        generator, model_description = self.load_model(model_uid, endpoint)
        super().__init__(
            top_n=top_n,
            model=model_uid,
            model_description=model_description,
            threshold=threshold,
            return_documents=return_documents,
        )
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    @classmethod
    def class_name(cls) -> str:
        return "XinferenceRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and keep the best ``top_n``.

        Raises:
            ValueError: if ``query_bundle`` is missing.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        dispatcher.event(
            ReRankStartEvent(
                nodes=nodes,
                top_n=self.top_n,
                query=query_bundle,
                model_name=self.model,
            )
        )

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            response = self._generator.rerank(texts, query_bundle.query_str)

            # Each result refers back to the input list via its "index".
            scored = [
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["relevance_score"],
                )
                for result in response["results"]
            ]
            # Sort by relevance so truncating to top_n keeps the best matches
            # regardless of the order the server returned them in.
            scored.sort(key=lambda scored_node: scored_node.score, reverse=True)
            new_nodes = [
                scored_node
                for scored_node in scored
                if scored_node.score >= self.threshold
            ][: self.top_n]

            event.on_end(payload={EventPayload.NODES: new_nodes})

        dispatcher.event(ReRankEndEvent(nodes=new_nodes))
        return new_nodes

    def load_model(
        self, model_uid: str, endpoint: str
    ) -> Tuple[Any, Dict[str, Any]]:
        """Fetch ``(generator, model_description)`` for ``model_uid``.

        Raises:
            ImportError: if ``xinference`` is not installed.
            RuntimeError: if the model cannot be obtained from the endpoint.
        """
        return _load_xinference_model(model_uid, endpoint)