优化了提示词
This commit is contained in:
@@ -0,0 +1,272 @@
|
||||
"""Xinference embeddings file."""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding, dispatcher
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.callbacks import CBEventType, EventPayload
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.instrumentation.events.rerank import ReRankStartEvent, ReRankEndEvent
|
||||
from llama_index.core.postprocessor.types import BaseNodePostprocessor
|
||||
from llama_index.core.schema import ImageType, NodeWithScore, QueryBundle
|
||||
from pydantic import Field
|
||||
|
||||
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# NOTE(review): not referenced anywhere in the visible part of this file;
# presumably an upper bound on embedding input length -- confirm the unit
# (tokens vs. characters) before relying on it.
EMBED_MAX_INPUT_LENGTH = 2048
|
||||
# Default value of the ``embed_batch_size`` argument of
# ``XinferenceEmbedding.__init__``.
EMBED_MAX_BATCH_SIZE = 1
|
||||
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
    """Text embedding model backed by a Xinference REST endpoint.

    On construction the class connects to ``endpoint``, obtains a handle for
    ``model_uid`` and keeps the model description reported by the server.
    Embeddings are produced by calling ``create_embedding`` on that handle.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    # Underscore-prefixed names are private attributes, not pydantic fields,
    # so they are declared with PrivateAttr and assigned after
    # ``super().__init__()`` has run.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        """Connect to Xinference and initialise the embedding model.

        Args:
            model_uid: UID of the model launched in Xinference; also used as
                ``model_name`` of the base class.
            endpoint: URL of the Xinference endpoint.
            embed_batch_size: Batch size forwarded to the base class. When
                left at its default, the ``replica`` value reported by the
                server is used instead (the historical behaviour).
            dimensions: Embedding dimensionality. When ``None`` the value
                reported by the server is used.
            additional_kwargs: Forwarded unchanged to the base class.
            api_key: Forwarded unchanged to the base class.
            api_base: Forwarded unchanged to the base class.
            api_version: Forwarded unchanged to the base class.
            max_retries: Forwarded unchanged to the base class.
            **kwargs: Extra keyword arguments forwarded to the base class.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description, replica, model_dimensions = self.load_model(
            model_uid, endpoint
        )

        # Explicit caller values win; otherwise fall back to what the server
        # reports. An explicit value equal to the default cannot be told
        # apart from "not given" and therefore keeps the server value.
        if embed_batch_size == EMBED_MAX_BATCH_SIZE:
            embed_batch_size = replica
        if dimensions is None:
            dimensions = model_dimensions

        super().__init__(
            model_description=model_description,
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            model_name=model_uid,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            **kwargs,
        )

        # Private attributes can only be assigned once the pydantic model
        # has been initialised.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(
        self, model_uid: str, endpoint: str
    ) -> Tuple[Any, Dict[str, Any], int, int]:
        """Fetch the model handle and its description from Xinference.

        Args:
            model_uid: UID of the model to look up.
            endpoint: URL of the Xinference endpoint.

        Returns:
            A tuple ``(generator, model_description, replica, dimensions)``
            where ``generator`` is the handle returned by
            ``client.get_model`` and the last two items are read from
            ``model_description``.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the endpoint returns no handle or no description
                for ``model_uid``.
            KeyError: If the description lacks ``replica`` or ``dimensions``.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        # ``.get`` so that a missing UID is reported through the RuntimeError
        # below instead of an unrelated KeyError.
        model_description = client.list_models().get(model_uid)

        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return (
            generator,
            model_description,
            model_description["replica"],
            model_description["dimensions"],
        )

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed ``text`` synchronously.

        Calls ``create_embedding`` on the model handle and returns the
        ``embedding`` entry of the first item in the response's ``data``.

        Raises:
            RuntimeError: If the model handle has not been set.
        """
        if self._generator is None:
            raise RuntimeError("Xinference model handle is not initialised.")

        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` synchronously; queries are embedded like texts."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` from async code.

        This delegates to the synchronous implementation, so the call is
        not awaited internally.
        """
        return self._get_query_embedding(query)
|
||||
|
||||
class XinferenceRerank(BaseNodePostprocessor):
    """Node postprocessor that reranks nodes with a Xinference rerank model.

    On construction the class connects to ``endpoint`` and obtains a handle
    for ``model_uid``. ``_postprocess_nodes`` sends the node contents and the
    query to that handle's ``rerank`` method, then applies the optional score
    ``threshold`` and ``top_n`` limit.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )
    # Underscore-prefixed names are private attributes, not pydantic fields,
    # so they are declared with PrivateAttr and assigned after
    # ``super().__init__()`` has run.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()
    model: str = Field(description="Xinference rerank model name.")
    # Both limits are optional: ``None`` disables the corresponding filter.
    top_n: Optional[int] = Field(default=None, description="Top N nodes to return.")
    threshold: Optional[float] = Field(
        default=None, description="threshold nodes to return."
    )

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        top_n: Optional[int] = None,
        threshold: Optional[float] = None,
        return_documents: bool = False,
    ) -> None:
        """Connect to Xinference and initialise the reranker.

        Args:
            model_uid: UID of the rerank model launched in Xinference; also
                stored as the ``model`` field.
            endpoint: URL of the Xinference endpoint.
            top_n: Maximum number of nodes to return, or ``None`` for no
                limit.
            threshold: Minimum relevance score a node needs to be kept, or
                ``None`` to keep every node.
            return_documents: Forwarded to the base constructor as before.
                NOTE(review): no field of this class consumes it, and it is
                not passed to the ``rerank`` call -- confirm intended use.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description = self.load_model(model_uid, endpoint)

        super().__init__(
            model_description=model_description,
            top_n=top_n,
            model=model_uid,
            threshold=threshold,
            return_documents=return_documents,
        )

        # Private attributes can only be assigned once the pydantic model
        # has been initialised.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    @classmethod
    def class_name(cls) -> str:
        """Return the class identifier string."""
        return "XinferenceRerank"

    def _postprocess_nodes(
        self,
        nodes: List[NodeWithScore],
        query_bundle: Optional[QueryBundle] = None,
    ) -> List[NodeWithScore]:
        """Rerank ``nodes`` against the query and apply the configured limits.

        Args:
            nodes: Candidate nodes to rerank.
            query_bundle: Query whose ``query_str`` is sent to the reranker.

        Returns:
            New ``NodeWithScore`` objects carrying the reranker's relevance
            score, filtered by ``threshold`` (when set), ordered by
            descending score and truncated to ``top_n`` (when set).

        Raises:
            ValueError: If ``query_bundle`` is ``None``.
        """
        if query_bundle is None:
            raise ValueError("Missing query bundle in extra info.")
        if not nodes:
            return []

        # With no explicit limit every node may be returned; report that
        # count so the start event always receives an integer.
        effective_top_n = self.top_n if self.top_n is not None else len(nodes)

        dispatcher.event(
            ReRankStartEvent(
                nodes=nodes,
                top_n=effective_top_n,
                query=query_bundle,
                model_name=self.model,
            )
        )

        with self.callback_manager.event(
            CBEventType.RERANKING,
            payload={
                EventPayload.NODES: nodes,
                EventPayload.MODEL_NAME: self.model,
                EventPayload.QUERY_STR: query_bundle.query_str,
                EventPayload.TOP_K: effective_top_n,
            },
        ) as event:
            texts = [node.node.get_content() for node in nodes]
            response = self._generator.rerank(texts, query_bundle.query_str)

            # Each result refers back to the input node by position.
            new_nodes = [
                NodeWithScore(
                    node=nodes[result["index"]].node,
                    score=result["relevance_score"],
                )
                for result in response["results"]
            ]

            # A missing threshold keeps every node instead of dropping all.
            if self.threshold is not None:
                new_nodes = [
                    node for node in new_nodes if node.score >= self.threshold
                ]

            # Defensive: make sure truncation keeps the best-scoring nodes
            # even if the response is not already ordered (stable sort, so an
            # already-sorted response is left unchanged).
            new_nodes.sort(key=lambda node: node.score, reverse=True)

            if self.top_n is not None:
                new_nodes = new_nodes[: self.top_n]

            event.on_end(payload={EventPayload.NODES: new_nodes})

        dispatcher.event(ReRankEndEvent(nodes=new_nodes))
        return new_nodes

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and its description from Xinference.

        Args:
            model_uid: UID of the model to look up.
            endpoint: URL of the Xinference endpoint.

        Returns:
            A tuple ``(generator, model_description)`` where ``generator`` is
            the handle returned by ``client.get_model``.

        Raises:
            ImportError: If the ``xinference`` package is not installed.
            RuntimeError: If the endpoint returns no handle or no description
                for ``model_uid``.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        # ``.get`` so that a missing UID is reported through the RuntimeError
        # below instead of an unrelated KeyError.
        model_description = client.list_models().get(model_uid)

        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )

        return generator, model_description
|
||||
Reference in New Issue
Block a user