优化模型初始化代码
This commit is contained in:
+286
-190
@@ -1,6 +1,6 @@
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
from abc import abstractmethod
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.core.settings import Settings
|
||||
from llama_index.llms.xinference import Xinference
|
||||
@@ -9,229 +9,322 @@ from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||
from app.engine.loaders import getProjectInfos
|
||||
from app.api.routers.request.base import ProjectInfo
|
||||
from util.register import *
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
from modelProvide.customDashScope import CustomDashScope
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def get_node_postprocessors():
    """Build the node post-processors (rerankers) selected via environment.

    Environment variables read here:
        RERANK_ENABLED: reranking is skipped when this is unset or spells
            "false" in any letter case.
        RERANK_PROVIDER: key used to look up the platform class registered
            under ``ModelPlateCategory``.

    Returns:
        ``[]`` when reranking is disabled; otherwise whatever the selected
        platform's ``rerank()`` returns.

    Raises:
        ValueError: if no platform is registered for RERANK_PROVIDER.
    """
    rerank_enabled = os.getenv("RERANK_ENABLED")
    # Test for None *before* calling a str method: the previous
    # `os.getenv(...).title()` raised AttributeError when the variable was
    # unset, which made the `is None` guard unreachable.
    if rerank_enabled is None or rerank_enabled.title() == 'False':
        return []

    rerank_provider = os.getenv("RERANK_PROVIDER")
    platform_cls = ClsRegister.get(ModelPlateCategory, rerank_provider)
    if platform_cls is None:
        raise ValueError(f"Invalid rerank provider: {rerank_provider}")

    # The reranker is built by the platform itself (XinferencePlatform.rerank
    # reads the RERANK_* variables); an XinferenceRerank constructed here was
    # always overwritten by this call or dropped by the raise above.
    return platform_cls().rerank()
|
||||
|
||||
def _resolve_platform(provider, role):
    """Instantiate the platform class registered for ``provider``.

    Args:
        provider: registry key, looked up under ``ModelPlateCategory``.
        role: word used in the error message ("model" or "embedding").

    Raises:
        ValueError: if nothing is registered under ``provider``.
    """
    platform_cls = ClsRegister.get(ModelPlateCategory, provider)
    if platform_cls is None:
        raise ValueError(f"Invalid {role} provider: {provider}")
    return platform_cls()


def init_settings():
    """Configure the global llama_index ``Settings`` from environment variables.

    Environment variables read here:
        MODEL_PROVIDER: registry key of the platform that supplies the LLM.
        EMBEDDING_PROVIDER: registry key of the platform that supplies the
            embedding model.
        CHUNK_SIZE: integer, defaults to 1024.
        CHUNK_OVERLAP: integer, defaults to 20.

    Raises:
        ValueError: if either provider has no registered platform.
    """
    # The registry is the single source of truth for providers. The former
    # `match` dispatch to the init_* functions was redundant: everything it
    # assigned was overwritten by the registry lookups below.
    model_provider = os.getenv("MODEL_PROVIDER")
    Settings.llm = _resolve_platform(model_provider, "model").model()

    # The embedding provider is resolved from its own registry entry. The old
    # code tested the LLM platform instance here, so an unknown
    # EMBEDDING_PROVIDER surfaced as "NoneType is not callable" instead of the
    # intended ValueError.
    embedding_provider = os.getenv("EMBEDDING_PROVIDER")
    Settings.embed_model = _resolve_platform(embedding_provider, "embedding").embedding()

    Settings.llm.callback_manager = CallbackManager()
    Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
    Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||
|
||||
def init_ollama():
    """Placeholder for Ollama setup; does nothing and returns None.

    The earlier implementation is kept below as commented-out reference.
    """
    # from llama_index.embeddings.ollama import OllamaEmbedding
    # from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
    #
    # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
    # request_timeout = float(
    #     os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
    # )
    # Settings.embed_model = OllamaEmbedding(
    #     base_url=base_url,
    #     model_name=os.getenv("EMBEDDING_MODEL"),
    # )
    # Settings.llm = Ollama(
    #     base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
    # )
    return None
|
||||
class ModelPlatform:
|
||||
@abstractmethod
|
||||
def model(self):
|
||||
pass
|
||||
|
||||
def init_xinference():
|
||||
base_url = os.getenv("BASE_URL")
|
||||
model = os.getenv("MODEL")
|
||||
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
|
||||
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
|
||||
@abstractmethod
|
||||
def embedding(self):
|
||||
pass
|
||||
|
||||
Settings.llm = Xinference(model, base_url, temperature, max_tokens)
|
||||
@abstractmethod
|
||||
def rerank(self):
|
||||
pass
|
||||
|
||||
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
|
||||
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url
|
||||
@register(ModelPlateCategory,'ollama')
|
||||
class OllamaPlatform(ModelPlatform):
|
||||
def model(self):
|
||||
#from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
#from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
|
||||
#
|
||||
# base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
# request_timeout = float(
|
||||
# os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
|
||||
# )
|
||||
# Settings.llm = Ollama(
|
||||
# base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
|
||||
# )
|
||||
pass
|
||||
|
||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = int(dimensions) if dimensions is not None else None
|
||||
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
|
||||
def embedding(self):
|
||||
#from llama_index.embeddings.ollama import OllamaEmbedding
|
||||
# base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
|
||||
# Settings.embed_model = OllamaEmbedding(
|
||||
# base_url=base_url,
|
||||
# model_name=os.getenv("EMBEDDING_MODEL"),
|
||||
# )
|
||||
pass
|
||||
|
||||
def init_openai():
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
from llama_index.llms.openai import OpenAI
|
||||
def rerank(self):
|
||||
pass
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = OpenAI(**config)
|
||||
@register(ModelPlateCategory,'xinference')
|
||||
class XinferencePlatform(ModelPlatform):
|
||||
def model(self):
    """Create the Xinference LLM from environment settings.

    Reads MODEL, BASE_URL, LLM_MAX_TOKENS (optional integer) and
    LLM_TEMPERATURE (float, defaults to DEFAULT_XINFERENCE_TEMP).
    """
    raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
    if raw_max_tokens is None:
        token_limit = None
    else:
        token_limit = int(raw_max_tokens)
    sampling_temp = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
    return Xinference(os.getenv("MODEL"), os.getenv("BASE_URL"), sampling_temp, token_limit)
|
||||
|
||||
def embedding(self):
|
||||
base_url = os.getenv("BASE_URL")
|
||||
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
|
||||
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = OpenAIEmbedding(**config)
|
||||
embed_model_name = os.getenv("EMBEDDING_MODEL")
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
dimensions = int(dimensions) if dimensions is not None else None
|
||||
return XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
|
||||
|
||||
def rerank(self):
    """Return a one-element list holding an XinferenceRerank, or None.

    None is returned when RERANK_MODEL is unset. Otherwise the reranker is
    built from RERANK_MODEL, RERANK_BASE_URL, RERANK_TOP_N and
    RERANK_THRESHOLD.
    """
    model_name = os.getenv("RERANK_MODEL")
    if model_name is None:
        return None
    # NOTE(review): top_n and threshold are forwarded as the raw environment
    # strings (or None); confirm XinferenceRerank accepts unconverted values.
    reranker = XinferenceRerank(
        model_name,
        os.getenv("RERANK_BASE_URL"),
        top_n=os.getenv("RERANK_TOP_N"),
        threshold=os.getenv("RERANK_THRESHOLD"),
    )
    return [reranker]
|
||||
|
||||
def init_dashscope():
|
||||
from llama_index.llms.dashscope import DashScope,DashScopeGenerationModels
|
||||
from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeBatchTextEmbeddingModels,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels
|
||||
@register(ModelPlateCategory,'openai')
|
||||
class OpenAIPlatform(ModelPlatform):
|
||||
def model(self):
|
||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
from llama_index.llms.openai import OpenAI
|
||||
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
Settings.llm = llm = DashScope(model_name=DashScopeGenerationModels.QWEN_MAX)
|
||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
config = {
|
||||
"model": os.getenv("MODEL"),
|
||||
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
|
||||
"max_tokens": int(max_tokens) if max_tokens is not None else None,
|
||||
}
|
||||
return OpenAI(**config)
|
||||
|
||||
def embedding(self):
    """Create the OpenAI embedding model from environment settings.

    Reads EMBEDDING_MODEL and the optional integer EMBEDDING_DIM.
    """
    from llama_index.embeddings.openai import OpenAIEmbedding

    raw_dim = os.getenv("EMBEDDING_DIM")
    return OpenAIEmbedding(
        model=os.getenv("EMBEDDING_MODEL"),
        dimensions=None if raw_dim is None else int(raw_dim),
    )
|
||||
|
||||
def rerank(self):
|
||||
pass
|
||||
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
Settings.embed_model = DashScopeEmbedding(model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2,
|
||||
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY)
|
||||
@register(ModelPlateCategory,'dashscope')
|
||||
class DashscopePlatform(ModelPlatform):
|
||||
def model(self):
    """Create the DashScope LLM named by MODEL, authenticated with DASHSCOPE_API_KEY."""
    return CustomDashScope(
        model_name=os.getenv('MODEL'),
        api_key=os.getenv('DASHSCOPE_API_KEY'),
    )
|
||||
|
||||
def init_azure_openai():
    """Placeholder for Azure OpenAI setup; does nothing and returns None.

    The earlier implementation is kept below as commented-out reference.
    """
    # from llama_index.core.constants import DEFAULT_TEMPERATURE
    # from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
    # from llama_index.llms.azure_openai import AzureOpenAI
    #
    # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
    # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
    # max_tokens = os.getenv("LLM_MAX_TOKENS")
    # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
    # dimensions = os.getenv("EMBEDDING_DIM")
    #
    # azure_config = {
    #     "api_key": os.environ["AZURE_OPENAI_KEY"],
    #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
    #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
    #     or os.getenv("OPENAI_API_VERSION"),
    # }
    #
    # Settings.llm = AzureOpenAI(
    #     model=os.getenv("MODEL"),
    #     max_tokens=int(max_tokens) if max_tokens is not None else None,
    #     temperature=float(temperature),
    #     deployment_name=llm_deployment,
    #     **azure_config,
    # )
    #
    # Settings.embed_model = AzureOpenAIEmbedding(
    #     model=os.getenv("EMBEDDING_MODEL"),
    #     dimensions=int(dimensions) if dimensions is not None else None,
    #     deployment_name=embedding_deployment,
    #     **azure_config,
    # )
    return None
|
||||
def embedding(self):
    """Create the DashScope embedding model named by EMBEDDING_MODEL.

    The model is configured with the query text type and authenticated with
    DASHSCOPE_API_KEY.
    """
    from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels

    return DashScopeEmbedding(
        model_name=os.getenv('EMBEDDING_MODEL'),
        text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,
        api_key=os.getenv('DASHSCOPE_API_KEY'),
    )
|
||||
|
||||
def init_fastembed():
    """Placeholder for using Qdrant Fastembed as the local embedding provider.

    Does nothing and returns None; the earlier implementation is kept below
    as commented-out reference.
    """
    # from llama_index.embeddings.fastembed import FastEmbedEmbedding
    #
    # embed_model_map: Dict[str, str] = {
    #     # Small and multilingual
    #     "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
    #     # Large and multilingual
    #     "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
    # }
    #
    # # This will download the model automatically if it is not already downloaded
    # Settings.embed_model = FastEmbedEmbedding(
    #     model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
    # )
    return None
|
||||
def rerank(self):
|
||||
pass
|
||||
|
||||
@register(ModelPlateCategory,'azure-openai')
|
||||
class AzureOpenaiPlatform(ModelPlatform):
|
||||
def model(self):
|
||||
# from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
# from llama_index.llms.azure_openai import AzureOpenAI
|
||||
#
|
||||
# llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
|
||||
# embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
|
||||
# max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
# temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
# dimensions = os.getenv("EMBEDDING_DIM")
|
||||
#
|
||||
# azure_config = {
|
||||
# "api_key": os.environ["AZURE_OPENAI_KEY"],
|
||||
# "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
|
||||
# "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
|
||||
# or os.getenv("OPENAI_API_VERSION"),
|
||||
# }
|
||||
#
|
||||
# return AzureOpenAI(
|
||||
# model=os.getenv("MODEL"),
|
||||
# max_tokens=int(max_tokens) if max_tokens is not None else None,
|
||||
# temperature=float(temperature),
|
||||
# deployment_name=llm_deployment,
|
||||
# **azure_config,
|
||||
# )
|
||||
pass
|
||||
|
||||
def init_groq():
    """Placeholder for Groq setup; does nothing and returns None.

    The earlier implementation is kept below as commented-out reference.
    """
    # from llama_index.llms.groq import Groq
    #
    # model_map: Dict[str, str] = {
    #     "llama3-8b": "llama3-8b-8192",
    #     "llama3-70b": "llama3-70b-8192",
    #     "mixtral-8x7b": "mixtral-8x7b-32768",
    # }
    #
    # Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
    # # Groq does not provide embeddings, so we use FastEmbed instead
    # init_fastembed()
    return None
|
||||
def embedding(self):
|
||||
# from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||
# from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
|
||||
#
|
||||
# llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
|
||||
# embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
|
||||
# max_tokens = os.getenv("LLM_MAX_TOKENS")
|
||||
# temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
|
||||
# dimensions = os.getenv("EMBEDDING_DIM")
|
||||
#
|
||||
# azure_config = {
|
||||
# "api_key": os.environ["AZURE_OPENAI_KEY"],
|
||||
# "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
|
||||
# "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
|
||||
# or os.getenv("OPENAI_API_VERSION"),
|
||||
# }
|
||||
# return AzureOpenAIEmbedding(
|
||||
# model=os.getenv("EMBEDDING_MODEL"),
|
||||
# dimensions=int(dimensions) if dimensions is not None else None,
|
||||
# deployment_name=embedding_deployment,
|
||||
# **azure_config,
|
||||
# )
|
||||
pass
|
||||
|
||||
def rerank(self):
|
||||
pass
|
||||
|
||||
def init_anthropic():
    """Placeholder for Anthropic setup; does nothing and returns None.

    The earlier implementation is kept below as commented-out reference.
    """
    # from llama_index.llms.anthropic import Anthropic
    #
    # model_map: Dict[str, str] = {
    #     "claude-3-opus": "claude-3-opus-20240229",
    #     "claude-3-sonnet": "claude-3-sonnet-20240229",
    #     "claude-3-haiku": "claude-3-haiku-20240307",
    #     "claude-2.1": "claude-2.1",
    #     "claude-instant-1.2": "claude-instant-1.2",
    # }
    #
    # Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
    # # Anthropic does not provide embeddings, so we use FastEmbed instead
    # init_fastembed()
    return None
|
||||
@register(ModelPlateCategory,'fastembed')
|
||||
class FastembedPlatform(ModelPlatform):
|
||||
@abstractmethod
|
||||
def model(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def embedding(self):
|
||||
# from llama_index.embeddings.fastembed import FastEmbedEmbedding
|
||||
#
|
||||
# embed_model_map: Dict[str, str] = {
|
||||
# # Small and multilingual
|
||||
# "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
||||
# # Large and multilingual
|
||||
# "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
|
||||
# }
|
||||
#
|
||||
# # This will download the model automatically if it is not already downloaded
|
||||
# Settings.embed_model = FastEmbedEmbedding(
|
||||
# model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
||||
# )
|
||||
pass
|
||||
|
||||
def init_gemini():
    """Placeholder for Gemini setup; does nothing and returns None.

    The earlier implementation is kept below as commented-out reference.
    """
    # from llama_index.embeddings.gemini import GeminiEmbedding
    # from llama_index.llms.gemini import Gemini
    #
    # model_name = f"models/{os.getenv('MODEL')}"
    # embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
    #
    # Settings.llm = Gemini(model=model_name)
    # Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
    return None
|
||||
@abstractmethod
|
||||
def rerank(self):
|
||||
pass
|
||||
|
||||
def init_mistral():
    """Placeholder for Mistral setup; does nothing and returns None.

    The earlier implementation is kept below as commented-out reference.
    """
    # from llama_index.embeddings.mistralai import MistralAIEmbedding
    # from llama_index.llms.mistralai import MistralAI
    #
    # Settings.llm = MistralAI(model=os.getenv("MODEL"))
    # Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
    return None
|
||||
@register(ModelPlateCategory,'groq')
class GroqPlatform(ModelPlatform):
    """Platform registered under 'groq'; all three factories are stubs.

    This is a concrete class that gets instantiated through the registry, so
    its methods are not marked ``@abstractmethod``. That matches the sibling
    Anthropic/Gemini/Mistral platforms and keeps the class instantiable even
    if the base class is later built on ``abc.ABC``.
    """

    def model(self):
        """Return None; the Groq LLM is not wired up yet."""
        # from llama_index.llms.groq import Groq
        #
        # model_map: Dict[str, str] = {
        #     "llama3-8b": "llama3-8b-8192",
        #     "llama3-70b": "llama3-70b-8192",
        #     "mixtral-8x7b": "mixtral-8x7b-32768",
        # }
        #
        # Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
        # # Groq does not provide embeddings, so we use FastEmbed instead
        # init_fastembed()
        pass

    def embedding(self):
        """Return None; no embedding model is provided by this platform."""
        pass

    def rerank(self):
        """Return None; no reranker is provided by this platform."""
        pass
|
||||
|
||||
@register(ModelPlateCategory,'anthropic')
class AnthropicPlatform(ModelPlatform):
    """Platform registered under 'anthropic'; all three factories are stubs."""

    def model(self):
        """Return None; the earlier Anthropic setup is kept as reference below."""
        # from llama_index.llms.anthropic import Anthropic
        #
        # model_map: Dict[str, str] = {
        #     "claude-3-opus": "claude-3-opus-20240229",
        #     "claude-3-sonnet": "claude-3-sonnet-20240229",
        #     "claude-3-haiku": "claude-3-haiku-20240307",
        #     "claude-2.1": "claude-2.1",
        #     "claude-instant-1.2": "claude-instant-1.2",
        # }
        #
        # Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
        # # Anthropic does not provide embeddings, so we use FastEmbed instead
        # init_fastembed()
        return None

    def embedding(self):
        """Return None; no embedding model is provided by this platform."""
        return None

    def rerank(self):
        """Return None; no reranker is provided by this platform."""
        return None
|
||||
|
||||
@register(ModelPlateCategory,'gemini')
class GeminiPlatform(ModelPlatform):
    """Platform registered under 'gemini'; all three factories are stubs."""

    def model(self):
        """Return None; the intended Gemini LLM setup is kept as reference below."""
        # from llama_index.llms.gemini import Gemini
        # model_name = f"models/{os.getenv('MODEL')}"
        # return Gemini(model=model_name)
        return None

    def embedding(self):
        """Return None; the intended Gemini embedding setup is kept as reference below."""
        # from llama_index.embeddings.gemini import GeminiEmbedding
        # embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
        # return GeminiEmbedding(model_name=embed_model_name)
        return None

    def rerank(self):
        """Return None; no reranker is provided by this platform."""
        return None
|
||||
|
||||
@register(ModelPlateCategory,'mistral')
class MistralPlatform(ModelPlatform):
    """Platform registered under 'mistral'; all three factories are stubs."""

    def model(self):
        """Return None; the intended Mistral LLM setup is kept as reference below."""
        # from llama_index.llms.mistralai import MistralAI
        # return MistralAI(model=os.getenv("MODEL"))
        return None

    def embedding(self):
        """Return None; the intended Mistral embedding setup is kept as reference below."""
        # from llama_index.embeddings.mistralai import MistralAIEmbedding
        # return MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
        return None

    def rerank(self):
        """Return None; no reranker is provided by this platform."""
        return None
|
||||
|
||||
def init_ProjectInfo():
|
||||
prjObj = ProjectInfo()
|
||||
@@ -239,3 +332,6 @@ def init_ProjectInfo():
|
||||
for prjInfo in prjInfos:
|
||||
prjObj.add(prjInfo['name'],prjInfo['flag'])
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user