From 95fbc820b9ceac21341c9717a4d41092eda2bf62 Mon Sep 17 00:00:00 2001 From: paituo <330435863@qq.com> Date: Tue, 10 Sep 2024 08:38:57 +0800 Subject: [PATCH] =?UTF-8?q?=E7=94=B1=E4=BA=8Exinference=200.15.0=20?= =?UTF-8?q?=E6=9C=80=E6=96=B0=E7=89=88=E5=B7=B2=E8=A7=A3=E5=86=B3=E5=92=8C?= =?UTF-8?q?llamaindex=200.2.0=E7=89=88=E6=9C=AC=E7=9A=84=E5=86=B2=E7=AA=81?= =?UTF-8?q?=E9=97=AE=E9=A2=98=EF=BC=8C=E6=89=80=E4=BB=A5=E6=97=A0=E9=9C=80?= =?UTF-8?q?=E8=87=AA=E5=B7=B1=E5=AE=9E=E7=8E=B0embeddings=E5=92=8Crerank?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- backend/app/xinference/__init__.py | 0 backend/app/xinference/base.py | 272 ----------------------------- 2 files changed, 272 deletions(-) delete mode 100644 backend/app/xinference/__init__.py delete mode 100644 backend/app/xinference/base.py diff --git a/backend/app/xinference/__init__.py b/backend/app/xinference/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/backend/app/xinference/base.py b/backend/app/xinference/base.py deleted file mode 100644 index f256ec8..0000000 --- a/backend/app/xinference/base.py +++ /dev/null @@ -1,272 +0,0 @@ -"""Xinference embeddings file.""" - -import logging -from enum import Enum -from http import HTTPStatus -from typing import Any, Dict, List, Optional, Union, Tuple - -from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher -from llama_index.core.bridge.pydantic import PrivateAttr -from llama_index.core.callbacks import CBEventType, EventPayload -from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding -from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent -from llama_index.core.postprocessor.types import BaseNodePostprocessor -from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle -from pydantic import Field - -logger = logging.getLogger(__name__) - - -EMBED_MAX_INPUT_LENGTH = 2048 -EMBED_MAX_BATCH_SIZE = 1 - - -class XinferenceEmbedding(BaseEmbedding): - """Xinference class for text embedding. - - """ - model_description: Dict[str, Any] = Field( - description="The model description from Xinference." - ) - _generator: Any = PrivateAttr() - _model_uid: str = Field(description="The Xinference model to use.") - _endpoint: str = Field(description="The Xinference endpoint URL to use.") - - def __init__( - self, - model_uid: str, - endpoint: str, - embed_batch_size: int = EMBED_MAX_BATCH_SIZE, - dimensions: Optional[int] = None, - additional_kwargs: Optional[Dict[str, Any]] = None, - api_key: Optional[str] = None, - api_base: Optional[str] = None, - api_version: Optional[str] = None, - max_retries: int = 10, - # timeout: float = 60.0, - # reuse_client: bool = True, - # callback_manager: Optional[CallbackManager] = None, - # default_headers: Optional[Dict[str, str]] = None, - # http_client: Optional[httpx.Client] = None, - # async_http_client: Optional[httpx.AsyncClient] = None, - # num_workers: Optional[int] = None, - **kwargs: Any, - ) -> None: - generator, model_description, embed_batch_size, dimensions = self.load_model( - model_uid, endpoint - ) - self._generator = generator - #self._model_uid = model_uid - #self._endpoint = endpoint - super().__init__( - embed_batch_size=embed_batch_size, - dimensions=dimensions, - #callback_manager=callback_manager, - model_name=model_uid, - additional_kwargs=additional_kwargs, - api_key=api_key, - api_base=api_base, - api_version=api_version, - max_retries=max_retries, - # reuse_client=reuse_client, - # timeout=timeout, - # default_headers=default_headers, - # num_workers=num_workers, - **kwargs, - ) - - def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: - try: - from xinference.client import RESTfulClient - except ImportError: - raise ImportError( - "Could not import Xinference library." - 'Please install Xinference with `pip install "xinference[all]"`' - ) - - client = RESTfulClient(endpoint) - - try: - assert isinstance(client, RESTfulClient) - except AssertionError: - raise RuntimeError( - "Could not create RESTfulClient instance." - "Please make sure Xinference endpoint is running at the correct port." - ) - - generator = client.get_model(model_uid) - model_description = client.list_models()[model_uid] - - try: - assert generator is not None - assert model_description is not None - except AssertionError: - raise RuntimeError( - "Could not get model from endpoint." - "Please make sure Xinference endpoint is running at the correct port." - ) - - model = model_description["model_name"] - replica = model_description['replica'] - dimensions = model_description['dimensions'] - max_tokens = model_description['max_tokens'] - - return generator, model_description, replica, dimensions - - @classmethod - def class_name(cls) -> str: - return "XinferenceEmbedding" - - def _get_text_embedding(self, text: str) -> Embedding: - """ - Embed the input text synchronously. - - Subclasses should implement this method. Reference get_text_embedding's - docstring for more information. - """ - assert self._generator is not None - - response = self._generator.create_embedding(input=text) - return response['data'][0]['embedding'] - - def _get_query_embedding(self, query: str) -> Embedding: - """ - Embed the input query synchronously. - - Subclasses should implement this method. Reference get_query_embedding's - docstring for more information. - """ - return self._get_text_embedding(query) - - async def _aget_query_embedding(self, query: str) -> Embedding: - """ - Embed the input query asynchronously. - - Subclasses should implement this method. Reference get_query_embedding's - docstring for more information. - """ - return self._get_query_embedding(query) - -class XinferenceRerank(BaseNodePostprocessor): - """Xinference class for rerank. - - """ - model_description: Dict[str, Any] = Field( - description="The model description from Xinference." - ) - _generator: Any = PrivateAttr() - _model_uid: str = Field(description="The Xinference model to use.") - _endpoint: str = Field(description="The Xinference endpoint URL to use.") - model: str = Field(description="Dashscope rerank model name.") - top_n: int = Field(description="Top N nodes to return.") - threshold: float = Field(description="threshold nodes to return.") - - def __init__( - self, - model_uid: str, - endpoint: str, - top_n: int = None, - threshold: float = None, - return_documents: bool = False - ): - _model_uid = model_uid - _endpoint = endpoint - _op_n = top_n - threshold = threshold - generator, model_description = self.load_model( - model_uid, endpoint - ) - self._generator = generator - super().__init__(top_n=top_n, model=model_uid, model_uid=model_uid, threshold = threshold, return_documents=return_documents) - - @classmethod - def class_name(cls) -> str: - return "XinferenceRerank" - - def _postprocess_nodes( - self, - nodes: List[NodeWithScore], - query_bundle: Optional[QueryBundle] = None, - ) -> List[NodeWithScore]: - if query_bundle is None: - raise ValueError("Missing query bundle in extra info.") - if len(nodes) == 0: - return [] - - dispatcher.event( - ReRankStartEvent( - nodes = nodes, - top_n = self.top_n, - query = query_bundle, - model_name = self.model - ) - ) - - with self.callback_manager.event( - CBEventType.RERANKING, - payload={ - EventPayload.NODES: nodes, - EventPayload.MODEL_NAME: self._model_uid, - EventPayload.QUERY_STR: query_bundle.query_str, - EventPayload.TOP_K: self.top_n, - }, - ) as event: - texts = [node.node.get_content() for node in nodes] - response = self._generator.rerank(texts,query_bundle.query_str) - new_nodes = [] - for result in response['results']: - new_node_with_score = NodeWithScore( - node=nodes[result['index']].node, score=result['relevance_score'] - ) - if self.threshold is not None: - if new_node_with_score.score >=self.threshold: - new_nodes.append(new_node_with_score) - - if self.top_n is not None: - if len(new_nodes) > self.top_n: - for index in new_nodes[self.top_n:-1]: - new_nodes.remove(index) - - event.on_end(payload={EventPayload.NODES: new_nodes}) - - dispatcher.event( - ReRankEndEvent( - nodes= new_nodes - ) - ) - return new_nodes - - def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: - try: - from xinference.client import RESTfulClient - except ImportError: - raise ImportError( - "Could not import Xinference library." - 'Please install Xinference with `pip install "xinference[all]"`' - ) - - client = RESTfulClient(endpoint) - - try: - assert isinstance(client, RESTfulClient) - except AssertionError: - raise RuntimeError( - "Could not create RESTfulClient instance." - "Please make sure Xinference endpoint is running at the correct port." - ) - - generator = client.get_model(model_uid) - model_description = client.list_models()[model_uid] - - try: - assert generator is not None - assert model_description is not None - except AssertionError: - raise RuntimeError( - "Could not get model from endpoint." - "Please make sure Xinference endpoint is running at the correct port." - ) - - model = model_description["model_name"] - - return generator, model_description \ No newline at end of file