From 1bfb28c40cea93a1ef1f5c9a7bb58cefe65a7c95 Mon Sep 17 00:00:00 2001
From: paituo <330435863@qq.com>
Date: Wed, 14 Aug 2024 08:51:51 +0800
Subject: [PATCH] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9xinference=E7=9A=84?=
 =?UTF-8?q?=E6=94=AF=E6=8C=81=E3=80=82?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 backend/app/settings.py | 39 ++++++++++++++++++++++++++++++++++++++-
 backend/pyproject.toml  |  3 +++
 2 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/backend/app/settings.py b/backend/app/settings.py
index 0158074..82e83e8 100644
--- a/backend/app/settings.py
+++ b/backend/app/settings.py
@@ -3,6 +3,10 @@ from typing import Dict
 
 from llama_index.core.constants import DEFAULT_TEMPERATURE
 from llama_index.core.settings import Settings
+from llama_index.llms.xinference import Xinference
+from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
+
+from app.xinference.base import XinferenceEmbedding
 
 
 def init_settings():
@@ -26,8 +30,9 @@ def init_settings():
             init_azure_openai()
         case "t-systems":
             from .llmhub import init_llmhub
-
             init_llmhub()
+        case "xinference":
+            init_xinference()
         case _:
             raise ValueError(f"Invalid model provider: {model_provider}")
 
@@ -52,6 +57,38 @@ def init_ollama():
     # )
     pass
 
+
+def init_xinference():
+    """Point Settings.llm and Settings.embed_model at an Xinference server.
+
+    Environment variables read:
+        BASE_URL: Xinference endpoint for the LLM (required).
+        MODEL: model identifier handed to the LLM client (required).
+        LLM_MAX_TOKENS: optional integer forwarded to the LLM client.
+        LLM_TEMPERATURE: optional float; defaults to DEFAULT_XINFERENCE_TEMP.
+        EMBEDDING_BASE_URL: optional embedding endpoint; BASE_URL is used
+            when it is unset or empty.
+        EMBEDDING_MODEL: model identifier handed to the embedding client.
+
+    Raises:
+        ValueError: if BASE_URL or MODEL is unset or empty.
+    """
+    base_url = os.getenv("BASE_URL")
+    model = os.getenv("MODEL")
+    if not base_url or not model:
+        raise ValueError("BASE_URL and MODEL must be set for xinference")
+
+    max_tokens = os.getenv("LLM_MAX_TOKENS")
+    max_tokens = int(max_tokens) if max_tokens is not None else None
+    temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
+
+    Settings.llm = Xinference(model, base_url, temperature, max_tokens)
+
+    # Embeddings may live on a separate endpoint; fall back to the LLM one.
+    embedding_base_url = os.getenv("EMBEDDING_BASE_URL") or base_url
+    embed_model_name = os.getenv("EMBEDDING_MODEL")
+    Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
+
 
 def init_openai():
     from llama_index.core.constants import DEFAULT_TEMPERATURE
diff --git a/backend/pyproject.toml b/backend/pyproject.toml
index 527c0ef..f244939 100644
--- a/backend/pyproject.toml
+++ b/backend/pyproject.toml
@@ -23,6 +23,9 @@ llama-index-callbacks-arize-phoenix = "^0.1.4"
 llama-index-llms-dashscope = "^0.1.2"
 llama-index-embeddings-dashscope = "^0.1.4"
 llama-index-postprocessor-dashscope-rerank-custom = "0.1.0"
+#xinference = "^0.14.1"
+xinference-client = "^0.14.1"
+llama-index-llms-xinference = "^0.1.2"
 qdrant-client="^1.10.1"
 llama-index-vector-stores-qdrant = "^0.2.14"
 chroma="^0.5.5"