From 0f6d76ddbea8238419103bfa0f4f43d26ef9198f Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Mon, 19 Aug 2024 08:27:22 +0800 Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0XinferenceRerank?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/xinference/base.py | 233 +++++++++++++++------------------ 1 file changed, 102 insertions(+), 131 deletions(-) diff --git a/backend/app/xinference/base.py b/backend/app/xinference/base.py index a16bcd7..d6bca82 100644 --- a/backend/app/xinference/base.py +++ b/backend/app/xinference/base.py @@ -7,147 +7,19 @@ from typing import Any, Dict, List, Optional, Union, Tuple from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding from llama_index.core.bridge.pydantic import PrivateAttr +from llama_index.core.callbacks import CBEventType, EventPayload from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding -from llama_index.core.schema import ImageType +from llama_index.core.postprocessor.types import BaseNodePostprocessor +from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle from pydantic import Field logger = logging.getLogger(__name__) -# class XinferenceTextEmbeddingType(str, Enum): -# """DashScope TextEmbedding text_type.""" -# -# TEXT_TYPE_QUERY = "query" -# TEXT_TYPE_DOCUMENT = "document" -# -# -# class DashScopeTextEmbeddingModels(str, Enum): -# """DashScope TextEmbedding models.""" -# -# TEXT_EMBEDDING_V1 = "text-embedding-v1" -# TEXT_EMBEDDING_V2 = "text-embedding-v2" -# TEXT_EMBEDDING_V3 = "text-embedding-v3" -# -# -# class DashScopeBatchTextEmbeddingModels(str, Enum): -# """DashScope TextEmbedding models.""" -# -# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1" -# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2" -# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3" - - EMBED_MAX_INPUT_LENGTH = 2048 EMBED_MAX_BATCH_SIZE = 1 -# class DashScopeMultiModalEmbeddingModels(str, Enum): -# 
"""DashScope MultiModalEmbedding models.""" -# -# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1" - - -# def get_text_embedding( -# model: str, -# text: Union[str, List[str]], -# api_key: Optional[str] = None, -# **kwargs: Any, -# ) -> List[List[float]]: -# """Call DashScope text embedding. -# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details. -# -# Args: -# model (str): The `DashScopeTextEmbeddingModels` -# text (Union[str, List[str]]): text or list text to embedding. -# -# Raises: -# ImportError: need import dashscope -# -# Returns: -# List[List[float]]: The list of embedding result, if failed return empty list. -# if some of test no output, the correspond index of output is None. -# """ -# try: -# import dashscope -# except ImportError: -# raise ImportError("DashScope requires `pip install dashscope") -# if isinstance(text, str): -# text = [text] -# response = dashscope.TextEmbedding.call( -# model=model, input=text, api_key=api_key, kwargs=kwargs -# ) -# embedding_results = [None] * len(text) -# if response.status_code == HTTPStatus.OK: -# for emb in response.output["embeddings"]: -# embedding_results[emb["text_index"]] = emb["embedding"] -# else: -# logger.error("Calling TextEmbedding failed, details: %s" % response) -# -# return embedding_results -# -# -# def get_batch_text_embedding( -# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any -# ) -> Optional[str]: -# """Call DashScope batch text embedding. -# -# Args: -# model (str): The `DashScopeMultiModalEmbeddingModels` -# url (str): The url of the file to embedding which with lines of text to embedding. -# -# Raises: -# ImportError: Need install dashscope package. 
-# -# Returns: -# str: The url of the embedding result, format ref: -# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details -# """ -# try: -# import dashscope -# except ImportError: -# raise ImportError("DashScope requires `pip install dashscope") -# response = dashscope.BatchTextEmbedding.call( -# model=model, url=url, api_key=api_key, kwargs=kwargs -# ) -# if response.status_code == HTTPStatus.OK: -# return response.output["url"] -# else: -# logger.error("Calling BatchTextEmbedding failed, details: %s" % response) -# return None - - -# def get_multimodal_embedding( -# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any -# ) -> List[float]: -# """Call DashScope multimodal embedding. -# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details. -# -# Args: -# model (str): The `DashScopeBatchTextEmbeddingModels` -# input (str): The input of the embedding, eg: -# [{'factor': 1, 'text': '你好'}, -# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'}, -# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}] -# -# Raises: -# ImportError: Need install dashscope package. -# -# Returns: -# List[float]: Embedding result, if failed return empty list. -# """ -# try: -# import dashscope -# except ImportError: -# raise ImportError("DashScope requires `pip install dashscope") -# response = dashscope.MultiModalEmbedding.call( -# model=model, input=input, api_key=api_key, kwargs=kwargs -# ) -# if response.status_code == HTTPStatus.OK: -# return response.output["embedding"] -# else: -# logger.error("Calling MultiModalEmbedding failed, details: %s" % response) -# return [] - class XinferenceEmbedding(BaseEmbedding): """Xinference class for text embedding. @@ -270,3 +142,102 @@ class XinferenceEmbedding(BaseEmbedding): docstring for more information. 
class XinferenceRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with a Xinference rerank model.

    A model handle is obtained from the Xinference endpoint when the
    postprocessor is constructed and reused for every rerank call.

    Args:
        model_uid: UID of the rerank model launched in Xinference.
        endpoint: URL of the Xinference endpoint.
        top_n: Number of top-ranked nodes to request from the model.
        return_documents: Forwarded as ``return_documents`` to the rerank call.

    Raises:
        ImportError: If the ``xinference`` package is not installed.
        RuntimeError: If the model cannot be obtained from the endpoint.
    """

    # NOTE(review): this file imports `Field` from `pydantic` but `PrivateAttr`
    # from `llama_index.core.bridge.pydantic` -- confirm both resolve to the
    # same pydantic major version in the deployed environment.
    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    top_n: int = Field(description="Top N nodes to return.")
    return_documents: bool = Field(
        default=False,
        description="Forwarded as `return_documents` to the Xinference rerank call.",
    )

    # Underscore-prefixed names cannot be validated pydantic fields, so the
    # connection details are stored as private attributes instead.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        top_n: int = 3,
        return_documents: bool = False,
    ):
        generator, model_description = self.load_model(model_uid, endpoint)
        # Initialise the pydantic model first: private attributes can only be
        # assigned reliably once the base class has set up its storage, and
        # `model_description` is a required field that must be supplied here.
        super().__init__(
            top_n=top_n,
            model_description=model_description,
            return_documents=return_documents,
        )
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    @classmethod
    def class_name(cls) -> str:
        """Return the identifier of this postprocessor class."""
        return "XinferenceRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query using the Xinference model.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the model.

        Returns:
            New ``NodeWithScore`` objects, in the order returned by the model,
            each scored with the model's ``relevance_score``. An empty list
            is returned when ``nodes`` is empty.

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self._model_uid,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: self.top_n,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            # Forward top_n so the configured limit is actually applied;
            # without it every candidate comes back.
            response = self._generator.rerank(
                texts,
                query_bundle.query_str,
                top_n=self.top_n,
                return_documents=self.return_documents,
            )

            new_nodes = []
            for result in response["results"]:
                # `index` refers to the position of the text in `texts`,
                # which is parallel to `nodes`.
                new_nodes.append(
                    NodeWithScore(
                        node=nodes[result["index"]].node,
                        score=result["relevance_score"],
                    )
                )
                logger.debug(
                    "Rerank result: index=%s relevance_score=%s",
                    result["index"],
                    result["relevance_score"],
                )
            event.on_end(payload={EventPayload.NODES: new_nodes})

        return new_nodes

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and its description from Xinference.

        Args:
            model_uid: UID of the model to fetch.
            endpoint: URL of the Xinference endpoint.

        Returns:
            A ``(generator, model_description)`` tuple: the handle returned by
            ``client.get_model`` and the entry for ``model_uid`` taken from
            ``client.list_models()``.

        Raises:
            ImportError: If ``xinference`` cannot be imported.
            RuntimeError: If the handle or the description is unavailable.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library."
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)

        generator = client.get_model(model_uid)
        # Use .get() so an unknown UID surfaces as the RuntimeError below
        # rather than as a bare KeyError.
        model_description = client.list_models().get(model_uid)

        # Explicit checks instead of `assert`, which is stripped under -O.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint."
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return generator, model_description