From 092d9705a76d4a56cdbe55a9867948e54a5f91ea Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Wed, 14 Aug 2024 08:50:19 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9XinferenceEmbedding?= =?UTF-8?q?=E7=9A=84=E6=94=AF=E6=8C=81=EF=BC=8C=E4=B8=B4=E6=AD=BB=E6=94=BE?= =?UTF-8?q?=E5=88=B0=E8=BF=99=E9=87=8C=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/xinference/__init__.py | 0 backend/app/xinference/base.py | 272 +++++++++++++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 backend/app/xinference/__init__.py create mode 100644 backend/app/xinference/base.py diff --git a/backend/app/xinference/__init__.py b/backend/app/xinference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/backend/app/xinference/base.py b/backend/app/xinference/base.py new file mode 100644 index 0000000..a16bcd7 --- /dev/null +++ b/backend/app/xinference/base.py @@ -0,0 +1,272 @@ +"""Xinference embeddings file.""" + +import logging +from enum import Enum +from http import HTTPStatus +from typing import Any, Dict, List, Optional, Union, Tuple + +from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding +from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding +from llama_index.core.schema import ImageType +from pydantic import Field + +logger = logging.getLogger(__name__) + + +# class XinferenceTextEmbeddingType(str, Enum): +# """DashScope TextEmbedding text_type.""" +# +# TEXT_TYPE_QUERY = "query" +# TEXT_TYPE_DOCUMENT = "document" +# +# +# class DashScopeTextEmbeddingModels(str, Enum): +# """DashScope TextEmbedding models.""" +# +# TEXT_EMBEDDING_V1 = "text-embedding-v1" +# TEXT_EMBEDDING_V2 = "text-embedding-v2" +# TEXT_EMBEDDING_V3 = "text-embedding-v3" +# +# +# class DashScopeBatchTextEmbeddingModels(str, Enum): +# """DashScope TextEmbedding models.""" +# +# 
TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1" +# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2" +# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3" + + +EMBED_MAX_INPUT_LENGTH = 2048 +EMBED_MAX_BATCH_SIZE = 1 + + +# class DashScopeMultiModalEmbeddingModels(str, Enum): +# """DashScope MultiModalEmbedding models.""" +# +# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1" + + +# def get_text_embedding( +# model: str, +# text: Union[str, List[str]], +# api_key: Optional[str] = None, +# **kwargs: Any, +# ) -> List[List[float]]: +# """Call DashScope text embedding. +# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details. +# +# Args: +# model (str): The `DashScopeTextEmbeddingModels` +# text (Union[str, List[str]]): text or list text to embedding. +# +# Raises: +# ImportError: need import dashscope +# +# Returns: +# List[List[float]]: The list of embedding result, if failed return empty list. +# if some of test no output, the correspond index of output is None. +# """ +# try: +# import dashscope +# except ImportError: +# raise ImportError("DashScope requires `pip install dashscope") +# if isinstance(text, str): +# text = [text] +# response = dashscope.TextEmbedding.call( +# model=model, input=text, api_key=api_key, kwargs=kwargs +# ) +# embedding_results = [None] * len(text) +# if response.status_code == HTTPStatus.OK: +# for emb in response.output["embeddings"]: +# embedding_results[emb["text_index"]] = emb["embedding"] +# else: +# logger.error("Calling TextEmbedding failed, details: %s" % response) +# +# return embedding_results +# +# +# def get_batch_text_embedding( +# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any +# ) -> Optional[str]: +# """Call DashScope batch text embedding. +# +# Args: +# model (str): The `DashScopeMultiModalEmbeddingModels` +# url (str): The url of the file to embedding which with lines of text to embedding. 
#
#     Raises:
#         ImportError: Need install dashscope package.
#
#     Returns:
#         str: The url of the embedding result, format ref:
#         https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
#     """
#     try:
#         import dashscope
#     except ImportError:
#         raise ImportError("DashScope requires `pip install dashscope")
#     response = dashscope.BatchTextEmbedding.call(
#         model=model, url=url, api_key=api_key, kwargs=kwargs
#     )
#     if response.status_code == HTTPStatus.OK:
#         return response.output["url"]
#     else:
#         logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
#         return None


# def get_multimodal_embedding(
#     model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
# ) -> List[float]:
#     """Call DashScope multimodal embedding.
#     ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
#
#     Args:
#         model (str): The `DashScopeBatchTextEmbeddingModels`
#         input (str): The input of the embedding, eg:
#         [{'factor': 1, 'text': 'hello'},
#         {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
#         {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
#
#     Raises:
#         ImportError: Need install dashscope package.
#
#     Returns:
#         List[float]: Embedding result, if failed return empty list.
#     """
#     try:
#         import dashscope
#     except ImportError:
#         raise ImportError("DashScope requires `pip install dashscope")
#     response = dashscope.MultiModalEmbedding.call(
#         model=model, input=input, api_key=api_key, kwargs=kwargs
#     )
#     if response.status_code == HTTPStatus.OK:
#         return response.output["embedding"]
#     else:
#         logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
#         return []


class XinferenceEmbedding(BaseEmbedding):
    """Text-embedding backend that delegates to a model served by Xinference.

    On construction the class connects to the Xinference REST endpoint, looks
    up the model identified by ``model_uid`` and keeps the returned model
    handle. Every embedding request is forwarded to that handle's
    ``create_embedding`` method.
    """

    # NOTE(review): `Field` is imported from `pydantic` while `PrivateAttr`
    # comes from the llama_index pydantic bridge -- confirm both resolve to
    # the same pydantic major version.
    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )

    # Underscore-prefixed names are pydantic *private* attributes, so they are
    # declared with PrivateAttr (not Field) and assigned in __init__.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        """Connect to Xinference and initialise the embedding model.

        Args:
            model_uid: UID of the model launched in Xinference; also used as
                the ``model_name`` passed to the base class.
            endpoint: URL of the Xinference REST endpoint.
            embed_batch_size: Batch size passed to the base class.
            dimensions: Forwarded unchanged to the base-class initializer.
            additional_kwargs: Forwarded unchanged to the base-class
                initializer.
            api_key: Forwarded unchanged to the base-class initializer.
            api_base: Forwarded unchanged to the base-class initializer.
            api_version: Forwarded unchanged to the base-class initializer.
            max_retries: Forwarded unchanged to the base-class initializer.
            **kwargs: Extra keyword arguments forwarded to the base class.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description = self.load_model(model_uid, endpoint)

        # `model_description` is a declared field, so it has to reach the
        # pydantic initializer. An explicitly supplied value takes precedence
        # over the one reported by the server.
        kwargs.setdefault("model_description", model_description)

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            model_name=model_uid,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            **kwargs,
        )

        # Private attributes may only be assigned once the pydantic base
        # initializer has run.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(
        self, model_uid: str, endpoint: str
    ) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and its description from Xinference.

        Args:
            model_uid: UID of the model to look up.
            endpoint: URL of the Xinference REST endpoint.

        Returns:
            A ``(generator, model_description)`` pair: the handle returned by
            ``RESTfulClient.get_model`` and the ``list_models()`` entry for
            ``model_uid``.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the handle or the description is unavailable.
        """
        try:
            # Imported lazily so the module can be imported without xinference.
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        # `.get` lets an unknown uid reach the RuntimeError below instead of
        # surfacing as a bare KeyError.
        model_description = client.list_models().get(model_uid)

        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return generator, model_description

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed ``text`` synchronously through the Xinference model handle.

        Returns the ``embedding`` entry of the first item in the response's
        ``data`` list.

        Raises:
            RuntimeError: If the model handle is not set.
        """
        if self._generator is None:
            raise RuntimeError("Xinference model handle is not initialised.")

        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` synchronously; queries are embedded like texts."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` from async code by calling the synchronous path."""
        return self._get_query_embedding(query)