增加对XinferenceEmbedding的支持,临时放到这里。
This commit is contained in:
@@ -0,0 +1,272 @@
|
|||||||
|
"""Xinference embeddings file."""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from enum import Enum
|
||||||
|
from http import HTTPStatus
|
||||||
|
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||||
|
|
||||||
|
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
||||||
|
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||||
|
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||||
|
from llama_index.core.schema import ImageType
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# class XinferenceTextEmbeddingType(str, Enum):
|
||||||
|
# """DashScope TextEmbedding text_type."""
|
||||||
|
#
|
||||||
|
# TEXT_TYPE_QUERY = "query"
|
||||||
|
# TEXT_TYPE_DOCUMENT = "document"
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class DashScopeTextEmbeddingModels(str, Enum):
|
||||||
|
# """DashScope TextEmbedding models."""
|
||||||
|
#
|
||||||
|
# TEXT_EMBEDDING_V1 = "text-embedding-v1"
|
||||||
|
# TEXT_EMBEDDING_V2 = "text-embedding-v2"
|
||||||
|
# TEXT_EMBEDDING_V3 = "text-embedding-v3"
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# class DashScopeBatchTextEmbeddingModels(str, Enum):
|
||||||
|
# """DashScope TextEmbedding models."""
|
||||||
|
#
|
||||||
|
# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
|
||||||
|
# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
|
||||||
|
# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"
|
||||||
|
|
||||||
|
|
||||||
|
# Presumably the maximum input length accepted per embedding request.
# NOTE(review): not referenced by the visible code in this file, and the unit
# (characters vs. tokens) is not established here -- confirm before relying on it.
EMBED_MAX_INPUT_LENGTH = 2048
|
||||||
|
# Default value of the `embed_batch_size` argument of XinferenceEmbedding.
EMBED_MAX_BATCH_SIZE = 1
|
||||||
|
|
||||||
|
|
||||||
|
# class DashScopeMultiModalEmbeddingModels(str, Enum):
|
||||||
|
# """DashScope MultiModalEmbedding models."""
|
||||||
|
#
|
||||||
|
# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
|
||||||
|
|
||||||
|
|
||||||
|
# def get_text_embedding(
|
||||||
|
# model: str,
|
||||||
|
# text: Union[str, List[str]],
|
||||||
|
# api_key: Optional[str] = None,
|
||||||
|
# **kwargs: Any,
|
||||||
|
# ) -> List[List[float]]:
|
||||||
|
# """Call DashScope text embedding.
|
||||||
|
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# model (str): The `DashScopeTextEmbeddingModels`
|
||||||
|
# text (Union[str, List[str]]): text or list text to embedding.
|
||||||
|
#
|
||||||
|
# Raises:
|
||||||
|
# ImportError: need import dashscope
|
||||||
|
#
|
||||||
|
# Returns:
|
||||||
|
# List[List[float]]: The list of embedding result, if failed return empty list.
|
||||||
|
# if some of test no output, the correspond index of output is None.
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# import dashscope
|
||||||
|
# except ImportError:
|
||||||
|
# raise ImportError("DashScope requires `pip install dashscope")
|
||||||
|
# if isinstance(text, str):
|
||||||
|
# text = [text]
|
||||||
|
# response = dashscope.TextEmbedding.call(
|
||||||
|
# model=model, input=text, api_key=api_key, kwargs=kwargs
|
||||||
|
# )
|
||||||
|
# embedding_results = [None] * len(text)
|
||||||
|
# if response.status_code == HTTPStatus.OK:
|
||||||
|
# for emb in response.output["embeddings"]:
|
||||||
|
# embedding_results[emb["text_index"]] = emb["embedding"]
|
||||||
|
# else:
|
||||||
|
# logger.error("Calling TextEmbedding failed, details: %s" % response)
|
||||||
|
#
|
||||||
|
# return embedding_results
|
||||||
|
#
|
||||||
|
#
|
||||||
|
# def get_batch_text_embedding(
|
||||||
|
# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
|
||||||
|
# ) -> Optional[str]:
|
||||||
|
# """Call DashScope batch text embedding.
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# model (str): The `DashScopeMultiModalEmbeddingModels`
|
||||||
|
# url (str): The url of the file to embedding which with lines of text to embedding.
|
||||||
|
#
|
||||||
|
# Raises:
|
||||||
|
# ImportError: Need install dashscope package.
|
||||||
|
#
|
||||||
|
# Returns:
|
||||||
|
# str: The url of the embedding result, format ref:
|
||||||
|
# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# import dashscope
|
||||||
|
# except ImportError:
|
||||||
|
# raise ImportError("DashScope requires `pip install dashscope")
|
||||||
|
# response = dashscope.BatchTextEmbedding.call(
|
||||||
|
# model=model, url=url, api_key=api_key, kwargs=kwargs
|
||||||
|
# )
|
||||||
|
# if response.status_code == HTTPStatus.OK:
|
||||||
|
# return response.output["url"]
|
||||||
|
# else:
|
||||||
|
# logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
|
||||||
|
# return None
|
||||||
|
|
||||||
|
|
||||||
|
# def get_multimodal_embedding(
|
||||||
|
# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
|
||||||
|
# ) -> List[float]:
|
||||||
|
# """Call DashScope multimodal embedding.
|
||||||
|
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
|
||||||
|
#
|
||||||
|
# Args:
|
||||||
|
# model (str): The `DashScopeBatchTextEmbeddingModels`
|
||||||
|
# input (str): The input of the embedding, eg:
|
||||||
|
# [{'factor': 1, 'text': '你好'},
|
||||||
|
# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
|
||||||
|
# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
|
||||||
|
#
|
||||||
|
# Raises:
|
||||||
|
# ImportError: Need install dashscope package.
|
||||||
|
#
|
||||||
|
# Returns:
|
||||||
|
# List[float]: Embedding result, if failed return empty list.
|
||||||
|
# """
|
||||||
|
# try:
|
||||||
|
# import dashscope
|
||||||
|
# except ImportError:
|
||||||
|
# raise ImportError("DashScope requires `pip install dashscope")
|
||||||
|
# response = dashscope.MultiModalEmbedding.call(
|
||||||
|
# model=model, input=input, api_key=api_key, kwargs=kwargs
|
||||||
|
# )
|
||||||
|
# if response.status_code == HTTPStatus.OK:
|
||||||
|
# return response.output["embedding"]
|
||||||
|
# else:
|
||||||
|
# logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
|
||||||
|
# return []
|
||||||
|
|
||||||
|
class XinferenceEmbedding(BaseEmbedding):
    """Text embedding backed by a model served from an Xinference endpoint.

    On construction the class connects to the given endpoint, fetches a handle
    for ``model_uid`` and stores the model description reported by the server.
    Embeddings are then obtained by calling ``create_embedding`` on that handle.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    # Underscore-prefixed names are private attributes in pydantic, so they are
    # declared with PrivateAttr rather than Field.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        """Connect to Xinference and initialise the embedding model.

        Args:
            model_uid: UID of the model launched on the Xinference server; it
                is also forwarded to the base class as ``model_name``.
            endpoint: URL of the Xinference endpoint.
            embed_batch_size: Batch size forwarded to the base class.
            dimensions: Forwarded unchanged to the base class.
            additional_kwargs: Forwarded unchanged to the base class.
            api_key: Forwarded unchanged to the base class.
            api_base: Forwarded unchanged to the base class.
            api_version: Forwarded unchanged to the base class.
            max_retries: Forwarded unchanged to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description = self.load_model(model_uid, endpoint)

        # `model_description` is a required field, so it has to reach the base
        # initialiser. A value explicitly supplied by the caller still wins.
        kwargs.setdefault("model_description", model_description)

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            model_name=model_uid,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            **kwargs,
        )

        # Private attributes may only be assigned once the pydantic model has
        # been initialised, hence after super().__init__().
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and its description from Xinference.

        Args:
            model_uid: UID of the model to look up.
            endpoint: URL of the Xinference endpoint.

        Returns:
            A ``(generator, model_description)`` pair: the handle returned by
            ``client.get_model`` and the entry for ``model_uid`` taken from
            ``client.list_models()``.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the handle or the description is unavailable.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)

        generator = client.get_model(model_uid)
        # .get() so that an unknown UID reaches the explicit error below
        # instead of surfacing as a bare KeyError.
        model_description = client.list_models().get(model_uid)

        # Explicit checks instead of `assert`, which is stripped under `-O`.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return generator, model_description

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed ``text`` synchronously.

        Calls ``create_embedding`` on the Xinference model handle and returns
        the ``embedding`` of the first item in the response's ``data`` list.

        Raises:
            RuntimeError: If the model handle is ``None``.
        """
        if self._generator is None:
            raise RuntimeError("Xinference model handle is not initialised.")

        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` synchronously; queries are embedded exactly like texts."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` from async code.

        NOTE: this delegates directly to the synchronous implementation, so
        the call is not awaited or offloaded.
        """
        return self._get_query_embedding(query)
|
||||||
Reference in New Issue
Block a user