增加对xinference的支持。

This commit is contained in:
2024-08-14 08:51:51 +08:00
parent 092d9705a7
commit 1bfb28c40c
2 changed files with 24 additions and 1 deletions
+21 -1
View File
@@ -3,6 +3,10 @@ from typing import Dict
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from llama_index.llms.xinference import Xinference
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding
def init_settings():
@@ -26,8 +30,9 @@ def init_settings():
init_azure_openai()
case "t-systems":
from .llmhub import init_llmhub
init_llmhub()
case "xinference":
init_xinference()
case _:
raise ValueError(f"Invalid model provider: {model_provider}")
@@ -52,6 +57,21 @@ def init_ollama():
# )
pass
def init_xinference():
base_url = os.getenv("BASE_URL")
model = os.getenv("MODEL")
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
Settings.llm = Xinference(model, base_url, temperature, max_tokens)
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url
embed_model_name = os.getenv("EMBEDDING_MODEL")
dimensions = os.getenv("EMBEDDING_DIM")
dimensions = int(dimensions) if dimensions is not None else None
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url)
def init_openai():
from llama_index.core.constants import DEFAULT_TEMPERATURE