Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| d3df62f454 | |||
| 1bfb28c40c | |||
| 092d9705a7 |
@@ -0,0 +1,103 @@
|
||||
# The Llama Cloud API key.
|
||||
# LLAMA_CLOUD_API_KEY=
|
||||
SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zjinfo1
|
||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||
|
||||
#---------- Xinference ----------------
|
||||
# The provider for the AI models to use.
|
||||
MODEL_PROVIDER=xinference
|
||||
# The OpenAI API key to use.
|
||||
OPENAI_API_KEY=xinference
|
||||
BASE_URL=http://10.1.0.142:9995
|
||||
MODEL=Qwen2-72B-Instruct-GPTQ-Int8
|
||||
# Temperature for sampling from the model.
|
||||
LLM_TEMPERATURE=0.1
|
||||
# Maximum number of tokens to generate.
|
||||
#LLM_MAX_TOKENS=
|
||||
# Name of the embedding model to use.
|
||||
EMBEDDING_MODEL=bge-m3
|
||||
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
||||
# Dimension of the embedding model to use.
|
||||
EMBEDDING_DIM=1024
|
||||
##---------- OpenAI ----------------
|
||||
## The provider for the AI models to use.
|
||||
#MODEL_PROVIDER=openai
|
||||
## The OpenAI API key to use.
|
||||
#OPENAI_API_KEY=xinference
|
||||
#BASE_URL=http://10.1.0.142:9995/v1
|
||||
#MODEL=Qwen2-72B-Instruct-GPTQ-Int4
|
||||
## Temperature for sampling from the model.
|
||||
#LLM_TEMPERATURE=0.1
|
||||
## Maximum number of tokens to generate.
|
||||
##LLM_MAX_TOKENS=
|
||||
## Name of the embedding model to use.
|
||||
#EMBEDDING_MODEL=text-embedding-v2
|
||||
## Dimension of the embedding model to use.
|
||||
#EMBEDDING_DIM=1024
|
||||
#---------- DashScope ----------------
|
||||
#DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
|
||||
## The provider for the AI models to use.
|
||||
#MODEL_PROVIDER=dashscope
|
||||
## The name of LLM model to use.
|
||||
#MODEL=qwen-max
|
||||
## Name of the embedding model to use.
|
||||
#EMBEDDING_MODEL=text-embedding-v2
|
||||
|
||||
#--------------------------
|
||||
# 是否启用检索重排功能
|
||||
ENABLE_RERANK=true
|
||||
|
||||
|
||||
# The questions to help users get started (multi-line).
|
||||
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
|
||||
|
||||
# The number of similar embeddings to return when retrieving documents.
|
||||
TOP_K=5
|
||||
|
||||
# The time in milliseconds to wait for the stream to return a response.
|
||||
STREAM_TIMEOUT=60000
|
||||
|
||||
# 向量存储数据库类型,目前可选:chroma、qdrant
|
||||
VECTOR_STORE_TYPE=chroma
|
||||
# The name of the collection in your vector database
|
||||
VECTOR_STORE_COLLECTION=default
|
||||
|
||||
# The API endpoint for your vector database
|
||||
# VECTOR_STORE_HOST=
|
||||
|
||||
# The port for your vector database
|
||||
# VECTOR_STORE_PORT=
|
||||
|
||||
# The local path to the vector database.
|
||||
# Specify this if you are using a local vector database.
|
||||
# Otherwise, use VECTOR_STORE__HOST and VECTOR_STORE__PORT config above
|
||||
VECTOR_STORE_PATH=./storage_vector
|
||||
|
||||
|
||||
|
||||
PHOENIX_API_KEY=123456
|
||||
PHOENIX_URL=http://localhost:6006/v1/traces
|
||||
PHOENIX_PROJECT_NAME=ly_zjapp
|
||||
#OTEL_SERVICE_NAME=ly_zjapp
|
||||
#OTEL_RESOURCE_ATTRIBUTES=openinference.project.name=ly_zjapp
|
||||
# The address to start the backend app.
|
||||
APP_HOST=0.0.0.0
|
||||
|
||||
# The port to start the backend app.
|
||||
APP_PORT=8000
|
||||
|
||||
FILESERVER_URL_PREFIX=/api/files
|
||||
|
||||
# E2B_API_KEY key is required to run code interpreter tool. Get it here: https://e2b.dev/docs/getting-started/api-key
|
||||
# E2B_API_KEY=
|
||||
|
||||
# The system prompt for the AI model.
|
||||
SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weather forecast for a given location.
|
||||
-You are a Python interpreter that can run any python code in a secure environment.
|
||||
- The python code runs in a Jupyter notebook. Every time you call the 'interpreter' tool, the python code is executed in a separate cell.
|
||||
- You are given tasks to complete and you run python code to solve them.
|
||||
- It's okay to make multiple calls to interpreter tool. If you get an error or the result is not what you expected, you can call the tool again. Don't give up too soon!
|
||||
- Plot visualizations using matplotlib or any other visualization library directly in the notebook.
|
||||
- You can install any pip package (if it exists) by running a cell with pip install.
|
||||
"
|
||||
|
||||
+21
-1
@@ -3,6 +3,10 @@ from typing import Dict
|
||||
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.llms.xinference import Xinference
|
||||
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
|
||||
from app.xinference.base import XinferenceEmbedding
|
||||
|
||||
|
||||
def init_settings():
|
||||
@@ -26,8 +30,9 @@ def init_settings():
|
||||
init_azure_openai()
|
||||
case "t-systems":
|
||||
from .llmhub import init_llmhub
|
||||
|
||||
init_llmhub()
|
||||
case "xinference":
|
||||
init_xinference()
|
||||
case _:
|
||||
raise ValueError(f"Invalid model provider: {model_provider}")
|
||||
|
||||
@@ -52,6 +57,21 @@ def init_ollama():
|
||||
# )
|
||||
pass
|
||||
|
||||
def init_xinference():
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model = os.getenv("MODEL")
|
||||
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
|
||||
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
|
||||
|
||||
Settings.llm = Xinference(model, base_url, temperature, max_tokens)
|
||||
|
||||
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
|
||||
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url
|
||||
|
||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = int(dimensions) if dimensions is not None else None
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Xinference embeddings file."""
|
||||
|
||||
import logging
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from typing import Any, Dict, List, Optional, Union, Tuple
|
||||
|
||||
from llama_index.core.base.embeddings.base import BaseEmbedding, Embedding
|
||||
from llama_index.core.bridge.pydantic import PrivateAttr
|
||||
from llama_index.core.embeddings.multi_modal_base import MultiModalEmbedding
|
||||
from llama_index.core.schema import ImageType
|
||||
from pydantic import Field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# class XinferenceTextEmbeddingType(str, Enum):
|
||||
# """DashScope TextEmbedding text_type."""
|
||||
#
|
||||
# TEXT_TYPE_QUERY = "query"
|
||||
# TEXT_TYPE_DOCUMENT = "document"
|
||||
#
|
||||
#
|
||||
# class DashScopeTextEmbeddingModels(str, Enum):
|
||||
# """DashScope TextEmbedding models."""
|
||||
#
|
||||
# TEXT_EMBEDDING_V1 = "text-embedding-v1"
|
||||
# TEXT_EMBEDDING_V2 = "text-embedding-v2"
|
||||
# TEXT_EMBEDDING_V3 = "text-embedding-v3"
|
||||
#
|
||||
#
|
||||
# class DashScopeBatchTextEmbeddingModels(str, Enum):
|
||||
# """DashScope TextEmbedding models."""
|
||||
#
|
||||
# TEXT_EMBEDDING_ASYNC_V1 = "text-embedding-async-v1"
|
||||
# TEXT_EMBEDDING_ASYNC_V2 = "text-embedding-async-v2"
|
||||
# TEXT_EMBEDDING_ASYNC_V3 = "text-embedding-async-v3"
|
||||
|
||||
|
||||
EMBED_MAX_INPUT_LENGTH = 2048
|
||||
EMBED_MAX_BATCH_SIZE = 1
|
||||
|
||||
|
||||
# class DashScopeMultiModalEmbeddingModels(str, Enum):
|
||||
# """DashScope MultiModalEmbedding models."""
|
||||
#
|
||||
# MULTIMODAL_EMBEDDING_ONE_PEACE_V1 = "multimodal-embedding-one-peace-v1"
|
||||
|
||||
|
||||
# def get_text_embedding(
|
||||
# model: str,
|
||||
# text: Union[str, List[str]],
|
||||
# api_key: Optional[str] = None,
|
||||
# **kwargs: Any,
|
||||
# ) -> List[List[float]]:
|
||||
# """Call DashScope text embedding.
|
||||
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-api-details.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeTextEmbeddingModels`
|
||||
# text (Union[str, List[str]]): text or list text to embedding.
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: need import dashscope
|
||||
#
|
||||
# Returns:
|
||||
# List[List[float]]: The list of embedding result, if failed return empty list.
|
||||
# if some of test no output, the correspond index of output is None.
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# if isinstance(text, str):
|
||||
# text = [text]
|
||||
# response = dashscope.TextEmbedding.call(
|
||||
# model=model, input=text, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# embedding_results = [None] * len(text)
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# for emb in response.output["embeddings"]:
|
||||
# embedding_results[emb["text_index"]] = emb["embedding"]
|
||||
# else:
|
||||
# logger.error("Calling TextEmbedding failed, details: %s" % response)
|
||||
#
|
||||
# return embedding_results
|
||||
#
|
||||
#
|
||||
# def get_batch_text_embedding(
|
||||
# model: str, url: str, api_key: Optional[str] = None, **kwargs: Any
|
||||
# ) -> Optional[str]:
|
||||
# """Call DashScope batch text embedding.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeMultiModalEmbeddingModels`
|
||||
# url (str): The url of the file to embedding which with lines of text to embedding.
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: Need install dashscope package.
|
||||
#
|
||||
# Returns:
|
||||
# str: The url of the embedding result, format ref:
|
||||
# https://help.aliyun.com/zh/dashscope/developer-reference/text-embedding-async-api-details
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# response = dashscope.BatchTextEmbedding.call(
|
||||
# model=model, url=url, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# return response.output["url"]
|
||||
# else:
|
||||
# logger.error("Calling BatchTextEmbedding failed, details: %s" % response)
|
||||
# return None
|
||||
|
||||
|
||||
# def get_multimodal_embedding(
|
||||
# model: str, input: list, api_key: Optional[str] = None, **kwargs: Any
|
||||
# ) -> List[float]:
|
||||
# """Call DashScope multimodal embedding.
|
||||
# ref: https://help.aliyun.com/zh/dashscope/developer-reference/one-peace-multimodal-embedding-api-details.
|
||||
#
|
||||
# Args:
|
||||
# model (str): The `DashScopeBatchTextEmbeddingModels`
|
||||
# input (str): The input of the embedding, eg:
|
||||
# [{'factor': 1, 'text': '你好'},
|
||||
# {'factor': 2, 'audio': 'https://dashscope.oss-cn-beijing.aliyuncs.com/audios/cow.flac'},
|
||||
# {'factor': 3, 'image': 'https://dashscope.oss-cn-beijing.aliyuncs.com/images/256_1.png'}]
|
||||
#
|
||||
# Raises:
|
||||
# ImportError: Need install dashscope package.
|
||||
#
|
||||
# Returns:
|
||||
# List[float]: Embedding result, if failed return empty list.
|
||||
# """
|
||||
# try:
|
||||
# import dashscope
|
||||
# except ImportError:
|
||||
# raise ImportError("DashScope requires `pip install dashscope")
|
||||
# response = dashscope.MultiModalEmbedding.call(
|
||||
# model=model, input=input, api_key=api_key, kwargs=kwargs
|
||||
# )
|
||||
# if response.status_code == HTTPStatus.OK:
|
||||
# return response.output["embedding"]
|
||||
# else:
|
||||
# logger.error("Calling MultiModalEmbedding failed, details: %s" % response)
|
||||
# return []
|
||||
|
||||
class XinferenceEmbedding(BaseEmbedding):
|
||||
"""Xinference class for text embedding.
|
||||
|
||||
"""
|
||||
model_description: Dict[str, Any] = Field(
|
||||
description="The model description from Xinference."
|
||||
)
|
||||
_generator: Any = PrivateAttr()
|
||||
_model_uid: str = Field(description="The Xinference model to use.")
|
||||
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_uid: str,
|
||||
endpoint: str,
|
||||
embed_batch_size: int = EMBED_MAX_BATCH_SIZE,
|
||||
dimensions: Optional[int] = None,
|
||||
additional_kwargs: Optional[Dict[str, Any]] = None,
|
||||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
max_retries: int = 10,
|
||||
# timeout: float = 60.0,
|
||||
# reuse_client: bool = True,
|
||||
# callback_manager: Optional[CallbackManager] = None,
|
||||
# default_headers: Optional[Dict[str, str]] = None,
|
||||
# http_client: Optional[httpx.Client] = None,
|
||||
# async_http_client: Optional[httpx.AsyncClient] = None,
|
||||
# num_workers: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
generator, model_description = self.load_model(
|
||||
model_uid, endpoint
|
||||
)
|
||||
self._generator = generator
|
||||
#self._model_uid = model_uid
|
||||
#self._endpoint = endpoint
|
||||
super().__init__(
|
||||
embed_batch_size=embed_batch_size,
|
||||
dimensions=dimensions,
|
||||
#callback_manager=callback_manager,
|
||||
model_name=model_uid,
|
||||
additional_kwargs=additional_kwargs,
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
api_version=api_version,
|
||||
max_retries=max_retries,
|
||||
# reuse_client=reuse_client,
|
||||
# timeout=timeout,
|
||||
# default_headers=default_headers,
|
||||
# num_workers=num_workers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]:
|
||||
try:
|
||||
from xinference.client import RESTfulClient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Could not import Xinference library."
|
||||
'Please install Xinference with `pip install "xinference[all]"`'
|
||||
)
|
||||
|
||||
client = RESTfulClient(endpoint)
|
||||
|
||||
try:
|
||||
assert isinstance(client, RESTfulClient)
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not create RESTfulClient instance."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
generator = client.get_model(model_uid)
|
||||
model_description = client.list_models()[model_uid]
|
||||
|
||||
try:
|
||||
assert generator is not None
|
||||
assert model_description is not None
|
||||
except AssertionError:
|
||||
raise RuntimeError(
|
||||
"Could not get model from endpoint."
|
||||
"Please make sure Xinference endpoint is running at the correct port."
|
||||
)
|
||||
|
||||
model = model_description["model_name"]
|
||||
|
||||
return generator, model_description
|
||||
|
||||
@classmethod
|
||||
def class_name(cls) -> str:
|
||||
return "XinferenceEmbedding"
|
||||
|
||||
def _get_text_embedding(self, text: str) -> Embedding:
|
||||
"""
|
||||
Embed the input text synchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_text_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
assert self._generator is not None
|
||||
|
||||
response = self._generator.create_embedding(input=text)
|
||||
return response['data'][0]['embedding']
|
||||
|
||||
def _get_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query synchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_query_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
return self._get_text_embedding(query)
|
||||
|
||||
async def _aget_query_embedding(self, query: str) -> Embedding:
|
||||
"""
|
||||
Embed the input query asynchronously.
|
||||
|
||||
Subclasses should implement this method. Reference get_query_embedding's
|
||||
docstring for more information.
|
||||
"""
|
||||
return self._get_query_embedding(query)
|
||||
@@ -23,6 +23,9 @@ llama-index-callbacks-arize-phoenix = "^0.1.4"
|
||||
llama-index-llms-dashscope = "^0.1.2"
|
||||
llama-index-embeddings-dashscope = "^0.1.4"
|
||||
llama-index-postprocessor-dashscope-rerank-custom = "0.1.0"
|
||||
#xinference = "^0.14.1"
|
||||
xinference.client = "^0.14.1"
|
||||
llama-index-llms-xinference = "^0.1.2"
|
||||
qdrant-client="^1.10.1"
|
||||
llama-index-vector-stores-qdrant = "^0.2.14"
|
||||
chroma="^0.5.5"
|
||||
|
||||
Reference in New Issue
Block a user