dev #1
+102
-131
@@ -7,147 +7,19 @@ from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.schema import ImageType
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# class XinferenceTextEmbeddingType(str, Enum):
|
||||
# """DashScope TextEmbedding text_type."""
|
||||
#
|
||||
# TEXT_TYPE_QUERY = "query"
|
||||
# TEXT_TYPE_DOCUMENT = "document"
|
||||
#
|
||||
#
|
||||
# class DashScopeTextEmbeddingModels(str, Enum):
|
||||
# """DashScope TextEmbedding models."""
|
||||
#
|
||||
# TEXT_EMBEDDING_V1 = "text-embedding-v1"
|
||||
# TEXT_EMBEDDING_V2 = "text-embedding-v2"
|
||||
# TEXT_EMBEDDING_V3 = "text-embedding-v3"
|
||||
#
|
||||
#
|
||||
# class DashScopeBatchTextEmbeddingModels(str, Enum):
|
||||
# """DashScope TextEmbedding models."""
|
||||
#
|
||||
# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
|
||||
# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
|
||||
# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"
|
||||
|
||||
|
||||
EMBED_MAX_INPUT_LENGTH = 2048
|
||||
EMBED_MAX_BATCH_SIZE = 1
|
||||
|
||||
|
||||
# class DashScopeMultiModalEmbeddingModels(str, Enum):
|
||||
# """DashScope MultiModalEmbedding models."""
|
||||
#
|
||||
# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
|
||||
|
||||
|
||||
# def get_text_embedding(
|
||||
# model: str,
|
||||
# text: Union[str, List[str]],
|
||||
# api_key: Optional[str] = None,
|
||||
# **kwargs: Any,
|
||||
# ) -> List[List[float]]:
|
||||
# """Call DashScope text embedding.
|
||||
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeTextEmbeddingModels`
|
||||
# text (Union[str, List[str]]): text or list text to embedding.
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: need import dashscope
|
||||
#
|
||||
# Returns:
|
||||
# List[List[float]]: The list of embedding result, if failed return empty list.
|
||||
# if some of test no output, the correspond index of output is None.
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# if isinstance(text, str):
|
||||
# text = [text]
|
||||
# response = dashscope.TextEmbedding.call(
|
||||
# model=model, input=text, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# embedding_results = [None] * len(text)
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# for emb in response.output["embeddings"]:
|
||||
# embedding_results[emb["text_index"]] = emb["embedding"]
|
||||
# else:
|
||||
# logger.error("Calling TextEmbedding failed, details: %s" % response)
|
||||
#
|
||||
# return embedding_results
|
||||
#
|
||||
#
|
||||
# def get_batch_text_embedding(
|
||||
# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
|
||||
# ) -> Optional[str]:
|
||||
# """Call DashScope batch text embedding.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeMultiModalEmbeddingModels`
|
||||
# url (str): The url of the file to embedding which with lines of text to embedding.
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: Need install dashscope package.
|
||||
#
|
||||
# Returns:
|
||||
# str: The url of the embedding result, format ref:
|
||||
# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# response = dashscope.BatchTextEmbedding.call(
|
||||
# model=model, url=url, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# return response.output["url"]
|
||||
# else:
|
||||
# logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
|
||||
# return None
|
||||
|
||||
|
||||
# def get_multimodal_embedding(
|
||||
# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
|
||||
# ) -> List[float]:
|
||||
# """Call DashScope multimodal embedding.
|
||||
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeBatchTextEmbeddingModels`
|
||||
# input (str): The input of the embedding, eg:
|
||||
# [{'factor': 1, 'text': '你好'},
|
||||
# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
|
||||
# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: Need install dashscope package.
|
||||
#
|
||||
# Returns:
|
||||
# List[float]: Embedding result, if failed return empty list.
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# response = dashscope.MultiModalEmbedding.call(
|
||||
# model=model, input=input, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# return response.output["embedding"]
|
||||
# else:
|
||||
# logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
|
||||
# return []
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
"""Xinference class for text embedding.
|
||||
|
||||
@@ -270,3 +142,102 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
docstring for more information.
|
||||
"""
|
||||
return self._get_query_embedding(query)
|
||||
|
||||
class XinferenceRerank(BaseNodePostprocessor):
|
||||
"""Xinference class for rerank.
|
||||
|
||||
"""
|
||||
model_description: Dict[str, Any] = Field(
|
||||
description="The model description from Xinference."
|
||||
)
|
||||
_generator: Any = PrivateAttr()
|
||||
_model_uid: str = Field(description="The Xinference model to use.")
|
||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||
#model: str = Field(description="Dashscope rerank model name.")
|
||||
top_n: int = Field(description="Top N nodes to return.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_uid: str,
|
||||
endpoint: str,
|
||||
top_n: int = 3,
|
||||
return_documents: bool = False
|
||||
):
|
||||
generator, model_description = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
super().__init__(top_n=top_n, model=model_uid, return_documents=return_documents)
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "XinferenceRerank"
|
||||
|
||||
def _postprocess_nodes(
|
||||
self,
|
||||
nodes: List[NodeWithScore],
|
||||
query_bundle: Optional[QueryBundle] = None,
|
||||
) -> List[NodeWithScore]:
|
||||
if query_bundle is None:
|
||||
raise ValueError("Missing query bundle in extra info.")
|
||||
if len(nodes) == 0:
|
||||
return []
|
||||
|
||||
with self.callback_manager.event(
|
||||
CBEventType.RERANKING,
|
||||
payload={
|
||||
EventPayload.NODES: nodes,
|
||||
EventPayload.MODEL_NAME: self._model_uid,
|
||||
EventPayload.QUERY_STR: query_bundle.query_str,
|
||||
EventPayload.TOP_K: self.top_n,
|
||||
},
|
||||
) as event:
|
||||
texts = [node.node.get_content() for node in nodes]
|
||||
response = self._generator.rerank(texts,query_bundle.query_str)
|
||||
new_nodes = []
|
||||
for result in response['results']:
|
||||
new_node_with_score = NodeWithScore(
|
||||
node=nodes[result['index']].node, score=result['relevance_score']
|
||||
)
|
||||
print(new_node_with_score.node.get_content)
|
||||
print('\n')
|
||||
print(new_node_with_score.score)
|
||||
new_nodes.append(new_node_with_score)
|
||||
event.on_end(payload={EventPayload.NODES: new_nodes})
|
||||
|
||||
return new_nodes
|
||||
|
||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Xinference library."
|
||||
'Please install Xinference with `pip install "xinference[all]"`'
|
||||
)
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
try:
|
||||
assert isinstance(client, RESTfulClient)
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not create RESTfulClient instance."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
generator = client.get_model(model_uid)
|
||||
model_description = client.list_models()[model_uid]
|
||||
|
||||
try:
|
||||
assert generator is not None
|
||||
assert model_description is not None
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not get model from endpoint."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
model = model_description["model_name"]
|
||||
|
||||
return generator, model_description
|
||||
Reference in New Issue
Block a user