Compare commits

3 Commits

Author SHA1 Message Date
ly d3df62f454 增加xinference的配置支持 2024-08-14 08:52:56 +08:00
ly 1bfb28c40c 增加对xinference的支持。 2024-08-14 08:51:51 +08:00
ly 092d9705a7 增加对XinferenceEmbedding的支持,临时放到这里。 2024-08-14 08:50:19 +08:00
5 changed files with 399 additions and 1 deletions
+103
View File
@@ -0,0 +1,103 @@
# The Llama Cloud API key.
# LLAMA_CLOUD_API_KEY=
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
#---------- Xinference ----------------
# The provider for the AI models to use.
MODEL_PROVIDER=xinference
# The OpenAI API key to use.
OPENAI_API_KEY=xinference
BASE_URL=http://10.1.0.142:9995
MODEL=Qwen2-72B-Instruct-GPTQ-Int8
# Temperature for sampling from the model.
LLM_TEMPERATURE=0.1
# Maximum number of tokens to generate.
#LLM_MAX_TOKENS=
# Name of the embedding model to use.
EMBEDDING_MODEL=bge-m3
EMBEDDING_BASE_URL=http://10.1.16.39:9995
# Dimension of the embedding model to use.
EMBEDDING_DIM=1024
##---------- OpenAI ----------------
## The provider for the AI models to use.
#MODEL_PROVIDER=openai
## The OpenAI API key to use.
#OPENAI_API_KEY=xinference
#BASE_URL=http://10.1.0.142:9995/v1
#MODEL=Qwen2-72B-Instruct-GPTQ-Int4
## Temperature for sampling from the model.
#LLM_TEMPERATURE=0.1
## Maximum number of tokens to generate.
##LLM_MAX_TOKENS=
## Name of the embedding model to use.
#EMBEDDING_MODEL=text-embedding-v2
## Dimension of the embedding model to use.
#EMBEDDING_DIM=1024
#---------- DashScope ----------------
#DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
## The provider for the AI models to use.
#MODEL_PROVIDER=dashscope
## The name of LLM model to use.
#MODEL=qwen-max
## Name of the embedding model to use.
#EMBEDDING_MODEL=text-embedding-v2
#--------------------------
# 是否启用检索重排功能
ENABLE_RERANK=true
# The questions to help users get started (multi-line).
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
# The number of similar embeddings to return when retrieving documents.
TOP_K=5
# The time in milliseconds to wait for the stream to return a response.
STREAM_TIMEOUT=60000
# 向量存储数据库类型,目前可选:chroma、qdrant
VECTOR_STORE_TYPE=chroma
# The name of the collection in your vector database
VECTOR_STORE_COLLECTION=default
# The API endpoint for your vector database
# VECTOR_STORE_HOST=
# The port for your vector database
# VECTOR_STORE_PORT=
# The local path to the vector database.
# Specify this if you are using a local vector database.
# Otherwise, use VECTOR_STORE_HOST and VECTOR_STORE_PORT config above
VECTOR_STORE_PATH=./storage_vector
PHOENIX_API_KEY=123456
PHOENIX_URL=http://localhost:6006/v1/traces
PHOENIX_PROJECT_NAME=ly_zjapp
#OTEL_SERVICE_NAME=ly_zjapp
#OTEL_RESOURCE_ATTRIBUTES=openinference.project.name=ly_zjapp
# The address to start the backend app.
APP_HOST=0.0.0.0
# The port to start the backend app.
APP_PORT=8000
FILESERVER_URL_PREFIX=/api/files
# E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key
# E2B_API_KEY=
# The system prompt for the AI model.
SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weather forecast for a given location.
-You are a Python interpreter that can run any python code in a secure environment.
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
- You are given tasks to complete and you run python code to solve them.
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
- You can install any pip package (if it exists) by running a cell with pip install.
"
+21 -1
View File
@@ -3,6 +3,10 @@ from typing import Dict
from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding
def init_settings(): def init_settings():
@@ -26,8 +30,9 @@ def init_settings():
init_azure_openai() init_azure_openai()
case "t-systems": case "t-systems":
from .llmhub import init_llmhub from .llmhub import init_llmhub
init_llmhub() init_llmhub()
case "xinference":
init_xinference()
case _: case _:
raise ValueError(f"Invalid model provider: {model_provider}") raise ValueError(f"Invalid model provider: {model_provider}")
@@ -52,6 +57,21 @@ def init_ollama():
# ) # )
pass pass
def init_xinference():
    """Configure the global ``Settings`` to use a Xinference server.

    Environment variables read:
        BASE_URL: Xinference endpoint that serves the LLM (required).
        MODEL: UID of the LLM on that endpoint (required).
        LLM_MAX_TOKENS: Optional integer; unset or empty means ``None``.
        LLM_TEMPERATURE: Optional float; unset or empty falls back to
            ``DEFAULT_XINFERENCE_TEMP``.
        EMBEDDING_BASE_URL: Optional endpoint for the embedding model;
            unset or empty falls back to ``BASE_URL``.
        EMBEDDING_MODEL: UID of the embedding model (required).
        EMBEDDING_DIM: Optional integer forwarded to ``XinferenceEmbedding``
            as ``dimensions``; unset or empty means ``None``.

    Raises:
        ValueError: If a required variable is missing or empty, or if a
            numeric variable cannot be parsed.
    """
    base_url = os.getenv("BASE_URL")
    model = os.getenv("MODEL")
    embed_model_name = os.getenv("EMBEDDING_MODEL")

    # Fail early with a readable message instead of handing None to the
    # Xinference client and failing somewhere inside it.
    required = {
        "BASE_URL": base_url,
        "MODEL": model,
        "EMBEDDING_MODEL": embed_model_name,
    }
    missing = [name for name, value in required.items() if not value]
    if missing:
        raise ValueError(
            "Missing required environment variable(s) for xinference: "
            + ", ".join(missing)
        )

    # An empty value (e.g. a bare `LLM_MAX_TOKENS=` line in .env) is treated
    # the same as an unset variable rather than crashing in int()/float().
    raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
    max_tokens = int(raw_max_tokens) if raw_max_tokens else None
    temperature = float(os.getenv("LLM_TEMPERATURE") or DEFAULT_XINFERENCE_TEMP)

    Settings.llm = Xinference(
        model_uid=model,
        endpoint=base_url,
        temperature=temperature,
        max_tokens=max_tokens,
    )

    # The embedding model may live on a different Xinference endpoint; when
    # none is configured, reuse the LLM endpoint.
    embedding_base_url = os.getenv("EMBEDDING_BASE_URL") or base_url

    raw_dimensions = os.getenv("EMBEDDING_DIM")
    dimensions = int(raw_dimensions) if raw_dimensions else None

    Settings.embed_model = XinferenceEmbedding(
        model_uid=embed_model_name,
        endpoint=embedding_base_url,
        dimensions=dimensions,
    )
def init_openai(): def init_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.constants import DEFAULT_TEMPERATURE
View File
+272
View File
@@ -0,0 +1,272 @@
"""Xinference embeddings file."""
import logging
from enum import Enum
from http import HTTPStatus
from typing import Any, Dict, List, Optional, Union, Tuple
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
from llama_index.core.bridge.pydantic import PrivateAttr
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
from llama_index.core.schema import ImageType
from pydantic import Field
logger = logging.getLogger(__name__)
# class XinferenceTextEmbeddingType(str, Enum):
# """DashScope TextEmbedding text_type."""
#
# TEXT_TYPE_QUERY = "query"
# TEXT_TYPE_DOCUMENT = "document"
#
#
# class DashScopeTextEmbeddingModels(str, Enum):
# """DashScope TextEmbedding models."""
#
# TEXT_EMBEDDING_V1 = "text-embedding-v1"
# TEXT_EMBEDDING_V2 = "text-embedding-v2"
# TEXT_EMBEDDING_V3 = "text-embedding-v3"
#
#
# class DashScopeBatchTextEmbeddingModels(str, Enum):
# """DashScope TextEmbedding models."""
#
# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"
EMBED_MAX_INPUT_LENGTH = 2048
EMBED_MAX_BATCH_SIZE = 1
# class DashScopeMultiModalEmbeddingModels(str, Enum):
# """DashScope MultiModalEmbedding models."""
#
# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
# def get_text_embedding(
# model: str,
# text: Union[str, List[str]],
# api_key: Optional[str] = None,
# **kwargs: Any,
# ) -> List[List[float]]:
# """Call DashScope text embedding.
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
#
# Args:
# model (str): The `DashScopeTextEmbeddingModels`
# text (Union[str, List[str]]): text or list text to embedding.
#
# Raises:
# ImportError: need import dashscope
#
# Returns:
# List[List[float]]: The list of embedding result, if failed return empty list.
# if some of test no output, the correspond index of output is None.
# """
# try:
# import dashscope
# except ImportError:
# raise ImportError("DashScope requires `pip install dashscope")
# if isinstance(text, str):
# text = [text]
# response = dashscope.TextEmbedding.call(
# model=model, input=text, api_key=api_key, kwargs=kwargs
# )
# embedding_results = [None] * len(text)
# if response.status_code == HTTPStatus.OK:
# for emb in response.output["embeddings"]:
# embedding_results[emb["text_index"]] = emb["embedding"]
# else:
# logger.error("Calling TextEmbedding failed, details: %s" % response)
#
# return embedding_results
#
#
# def get_batch_text_embedding(
# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
# ) -> Optional[str]:
# """Call DashScope batch text embedding.
#
# Args:
# model (str): The `DashScopeMultiModalEmbeddingModels`
# url (str): The url of the file to embedding which with lines of text to embedding.
#
# Raises:
# ImportError: Need install dashscope package.
#
# Returns:
# str: The url of the embedding result, format ref:
# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
# """
# try:
# import dashscope
# except ImportError:
# raise ImportError("DashScope requires `pip install dashscope")
# response = dashscope.BatchTextEmbedding.call(
# model=model, url=url, api_key=api_key, kwargs=kwargs
# )
# if response.status_code == HTTPStatus.OK:
# return response.output["url"]
# else:
# logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
# return None
# def get_multimodal_embedding(
# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
# ) -> List[float]:
# """Call DashScope multimodal embedding.
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
#
# Args:
# model (str): The `DashScopeBatchTextEmbeddingModels`
# input (str): The input of the embedding, eg:
# [{'factor': 1, 'text': '你好'},
# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
#
# Raises:
# ImportError: Need install dashscope package.
#
# Returns:
# List[float]: Embedding result, if failed return empty list.
# """
# try:
# import dashscope
# except ImportError:
# raise ImportError("DashScope requires `pip install dashscope")
# response = dashscope.MultiModalEmbedding.call(
# model=model, input=input, api_key=api_key, kwargs=kwargs
# )
# if response.status_code == HTTPStatus.OK:
# return response.output["embedding"]
# else:
# logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
# return []
class XinferenceEmbedding(BaseEmbedding):
    """Text embedding model served by a Xinference endpoint.

    On construction the model handle and its description are fetched from
    the endpoint; each embedding request is then delegated to that handle's
    ``create_embedding`` method.
    """

    model_description: Dict[str, Any] = Field(
        description="The model description from Xinference."
    )

    # Private, non-validated state. Underscore-prefixed names cannot be
    # regular fields, so they are declared as private attributes and
    # assigned in __init__ after the base class has been initialised.
    _generator: Any = PrivateAttr()
    _model_uid: str = PrivateAttr()
    _endpoint: str = PrivateAttr()

    def __init__(
        self,
        model_uid: str,
        endpoint: str,
        embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
        dimensions: Optional[int] = None,
        additional_kwargs: Optional[Dict[str, Any]] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        api_version: Optional[str] = None,
        max_retries: int = 10,
        **kwargs: Any,
    ) -> None:
        """Connect to ``endpoint`` and bind to the model ``model_uid``.

        Args:
            model_uid: UID of the embedding model on the Xinference server.
            endpoint: Base URL of the Xinference server.
            embed_batch_size: Batch size forwarded to the base class.
            dimensions: Forwarded to the base class unchanged.
            additional_kwargs: Forwarded to the base class unchanged.
            api_key: Forwarded to the base class unchanged.
            api_base: Forwarded to the base class unchanged.
            api_version: Forwarded to the base class unchanged.
            max_retries: Forwarded to the base class unchanged.
            **kwargs: Extra keyword arguments forwarded to the base class.

        Raises:
            ImportError: If the Xinference client library is not installed.
            RuntimeError: If the model cannot be obtained from the endpoint.
        """
        generator, model_description = self.load_model(model_uid, endpoint)

        super().__init__(
            embed_batch_size=embed_batch_size,
            dimensions=dimensions,
            model_name=model_uid,
            # Declared as a field on this class, so it must be supplied here.
            model_description=model_description,
            additional_kwargs=additional_kwargs,
            api_key=api_key,
            api_base=api_base,
            api_version=api_version,
            max_retries=max_retries,
            **kwargs,
        )

        # Private attributes are assigned only after the base initializer has
        # run, so that the model's private-attribute storage already exists.
        self._generator = generator
        self._model_uid = model_uid
        self._endpoint = endpoint

    def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, Dict[str, Any]]:
        """Fetch the model handle and its description from a Xinference server.

        Does not read instance state, so it is safe to call before the
        instance has been initialised.

        Args:
            model_uid: UID of the model to look up.
            endpoint: Base URL of the Xinference server.

        Returns:
            A ``(generator, model_description)`` tuple: the model handle
            returned by the client and that model's entry from the client's
            model listing.

        Raises:
            ImportError: If the Xinference client library is not installed.
            RuntimeError: If the model handle or its description is missing.
        """
        try:
            from xinference.client import RESTfulClient
        except ImportError as err:
            raise ImportError(
                "Could not import Xinference library. "
                'Please install Xinference with `pip install "xinference[all]"`'
            ) from err

        client = RESTfulClient(endpoint)
        generator = client.get_model(model_uid)
        # .get() so that an unknown UID reaches the RuntimeError below
        # instead of surfacing as a bare KeyError.
        model_description = client.list_models().get(model_uid)

        # Explicit checks rather than assert: asserts are stripped under -O.
        if generator is None or model_description is None:
            raise RuntimeError(
                "Could not get model from endpoint. "
                "Please make sure Xinference endpoint is running at the correct port."
            )
        return generator, model_description

    @classmethod
    def class_name(cls) -> str:
        """Return the name identifying this embedding class."""
        return "XinferenceEmbedding"

    def _get_text_embedding(self, text: str) -> Embedding:
        """Embed ``text`` synchronously via the Xinference model handle.

        Returns the ``embedding`` of the first item in the response's
        ``data`` list.
        """
        response = self._generator.create_embedding(input=text)
        return response["data"][0]["embedding"]

    def _get_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` synchronously; queries are embedded like any text."""
        return self._get_text_embedding(query)

    async def _aget_query_embedding(self, query: str) -> Embedding:
        """Embed ``query`` from async code.

        Delegates to the synchronous implementation, so the call does not
        yield to the event loop while the request is in flight.
        """
        return self._get_query_embedding(query)
+3
View File
@@ -23,6 +23,9 @@ llama-index-callbacks-arize-phoenix = "^0.1.4"
llama-index-llms-dashscope = "^0.1.2" llama-index-llms-dashscope = "^0.1.2"
llama-index-embeddings-dashscope = "^0.1.4" llama-index-embeddings-dashscope = "^0.1.4"
llama-index-postprocessor-dashscope-rerank-custom = "0.1.0" llama-index-postprocessor-dashscope-rerank-custom = "0.1.0"
#xinference = "^0.14.1"
xinference-client = "^0.14.1"
llama-index-llms-xinference = "^0.1.2"
qdrant-client="^1.10.1" qdrant-client="^1.10.1"
llama-index-vector-stores-qdrant = "^0.2.14" llama-index-vector-stores-qdrant = "^0.2.14"
chroma="^0.5.5" chroma="^0.5.5"