"""Xinference embedding integration for LlamaIndex."""

import logging
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Tuple, Union

from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.schema import ImageType
from pydantic import Field

logger = logging.getLogger(__name__)

# NOTE(review): the commented-out DashScope code in this module looks like a
# leftover from the DashScope embedding integration; nothing in the module
# references it. It is kept verbatim for reference only.

# class XinferenceTextEmbeddingType(str, Enum):
#     """DashScope TextEmbedding text_type."""
#
#     TEXT_TYPE_QUERY = "query"
#     TEXT_TYPE_DOCUMENT = "document"
#
#
# class DashScopeTextEmbeddingModels(str, Enum):
#     """DashScope TextEmbedding models."""
#
#     TEXT_EMBEDDING_V1 = "text-embedding-v1"
#     TEXT_EMBEDDING_V2 = "text-embedding-v2"
#     TEXT_EMBEDDING_V3 = "text-embedding-v3"
#
#
# class DashScopeBatchTextEmbeddingModels(str, Enum):
#     """DashScope TextEmbedding models."""
#
#     TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
#     TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
#     TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"

# Input-length limit; not referenced by the visible code of this module.
# TODO(review): confirm the unit (characters vs. tokens) and where it is used.
EMBED_MAX_INPUT_LENGTH = 2048
# Default ``embed_batch_size`` for ``XinferenceEmbedding`` (one text per call).
EMBED_MAX_BATCH_SIZE = 1


# class DashScopeMultiModalEmbeddingModels(str, Enum):
#     """DashScope MultiModalEmbedding models."""
#
#     MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"


# def get_text_embedding(
#     model: str,
#     text: Union[str, List[str]],
#     api_key: Optional[str] = None,
#     **kwargs: Any,
# ) -> List[List[float]]:
#     """Call DashScope text embedding.
#     ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
#
#     Args:
#         model (str): The `DashScopeTextEmbeddingModels`
#         text (Union[str, List[str]]): text or list text to embedding.
#
#     Raises:
#         ImportError: need import dashscope
#
#     Returns:
#         List[List[float]]: The list of embedding result, if failed return empty list.
#         if some of test no output, the correspond index of output is None.
#     """
#     try:
#         import dashscope
#     except ImportError:
#         raise ImportError("DashScope requires `pip install dashscope")
#     if isinstance(text, str):
#         text = [text]
#     response = dashscope.TextEmbedding.call(
#         model=model, input=text, api_key=api_key, kwargs=kwargs
#     )
#     embedding_results = [None] * len(text)
#     if response.status_code == HTTPStatus.OK:
#         for emb in response.output["embeddings"]:
#             embedding_results[emb["text_index"]] = emb["embedding"]
#     else:
#         logger.error("Calling TextEmbedding failed, details: %s" % response)
#
#     return embedding_results
#
#
# def get_batch_text_embedding(
#     model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
# ) -> Optional[str]:
#     """Call DashScope batch text embedding.
#
#     Args:
#         model (str): The `DashScopeMultiModalEmbeddingModels`
#         url (str): The url of the file to embedding which with lines of text to embedding.
#
#     Raises:
#         ImportError: Need install dashscope package.
#
#     Returns:
#         str: The url of the embedding result, format ref:
#         https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
#     """
#     try:
#         import dashscope
#     except ImportError:
#         raise ImportError("DashScope requires `pip install dashscope")
#     response = dashscope.BatchTextEmbedding.call(
#         model=model, url=url, api_key=api_key, kwargs=kwargs
#     )
#     if response.status_code == HTTPStatus.OK:
#         return response.output["url"]
#     else:
#         logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
#         return None


# def get_multimodal_embedding(
#     model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
# ) -> List[float]:
#     """Call DashScope multimodal embedding.
#     ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
#
#     Args:
#         model (str): The `DashScopeBatchTextEmbeddingModels`
#         input (str): The input of the embedding, eg:
#         [{'factor': 1, 'text': '你好'},
#         {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
#         {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
#
#     Raises:
#         ImportError: Need install dashscope package.
#
#     Returns:
#         List[float]: Embedding result, if failed return empty list.
#     """
#     try:
#         import dashscope
#     except ImportError:
#         raise ImportError("DashScope requires `pip install dashscope")
#     response = dashscope.MultiModalEmbedding.call(
#         model=model, input=input, api_key=api_key, kwargs=kwargs
#     )
#     if response.status_code == HTTPStatus.OK:
#         return response.output["embedding"]
#     else:
#         logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
#         return []


class XinferenceEmbedding(BaseEmbedding):
    """Text embeddings served by a running Xinference endpoint.

    On construction, the model handle and its description are fetched from the
    Xinference server through ``xinference.client.RESTfulClient``. Each text is
    then embedded with one call to the handle's ``create_embedding`` method.

    Args:
        model_uid: UID of a model already launched on the Xinference server.
        endpoint: URL of the Xinference server.
        embed_batch_size: Batch size used by ``BaseEmbedding`` when embedding
            many texts. Defaults to ``EMBED_MAX_BATCH_SIZE``.
        dimensions, additional_kwargs, api_key, api_base, api_version,
        max_retries: Accepted but not used by this class.
        **kwargs: Forwarded to ``BaseEmbedding`` (e.g. ``callback_manager``).

    Raises:
        ImportError: If the ``xinference`` package is not installed.
        RuntimeError: If the model UID cannot be found on the endpoint.
    """

    # Description of this model UID, as returned by the server's ``list_models()``.
    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )

    # Model handle returned by ``RESTfulClient.get_model``; used for ``create_embedding``.
    _generator: Any = PrivateAttr()
    # Pydantic v2 rejects ``Field(...)`` on leading-underscore names, so these
    # must be declared as private attributes.
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        generator, model_description = self.load_model(model_uid, endpoint)

        # ``model_description`` is a required field and must be passed here.
        # ``dimensions``/``api_*``/``max_retries``/``additional_kwargs`` are not
        # fields of this model, so they are intentionally not forwarded.
        super().__init__(
            embed_batch_size=embed_batch_size,
            model_name=model_uid,
            model_description=model_description,
            **kwargs,
        )

        # Pydantic v2 only allows private attributes to be set after
        # ``BaseModel.__init__`` has created the private-attribute storage.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Connect to ``endpoint`` and fetch the handle and description of ``model_uid``.

        Args:
            model_uid: UID of the launched model on the Xinference server.
            endpoint: URL of the Xinference server.

        Returns:
            A ``(generator, model_description)`` tuple. ``generator`` is the model
            handle from ``client.get_model``, and ``model_description`` is the entry
            for ``model_uid`` in ``client.list_models()``.

        Raises:
            ImportError: If ``xinference`` is not installed.
            RuntimeError: If the model handle or description is unavailable.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        try:
            model_description = client.list_models()[model_uid]
        except KeyError:
            # An unknown UID falls through to the explicit error below instead
            # of surfacing as a bare KeyError.
            model_description = None

        # Explicit check rather than ``assert``, which is stripped under ``-O``.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )
        return generator, model_description

    @classmethod
    def class_name(cls) -> str:
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed a single text synchronously.

        Calls the model handle's ``create_embedding`` and returns the first
        vector in the response's ``data`` list.
        """
        assert self._generator is not None
        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed a query synchronously; queries are embedded like plain text."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed a query for async callers.

        This delegates to the synchronous implementation, so the call blocks
        the event loop while the request runs.
        """
        return self._get_query_embedding(query)