Files
zjdataai-app/backend/app/settings.py
T

332 lines
12 KiB
Python

import os
from typing import Dict
from abc import abstractmethod
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from llama_index.embeddings.xinference import XinferenceEmbedding
#from llama_index.llms.xinference import Xinference
from app.engine.model.xinference import XinferenceModel
from app.engine.rerank.xinferenceRerank import CustomXinFerenceRerank
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo
from modelProvide.customDashScope import CustomDashScope
from util.register import *
from llama_index.core.callbacks import CallbackManager
# Registry category key shared by the @register(...) decorators on the platform
# classes in this module and by the ClsRegister.get(...) lookups in
# init_settings(). The literal is Chinese for "model platform".
ModelPlateCategory = '模型平台'
def init_settings():
    """Configure the global llama-index ``Settings`` from environment variables.

    Reads ``MODEL_PROVIDER`` and ``EMBEDDING_PROVIDER``, looks each name up
    with ``ClsRegister.get`` under ``ModelPlateCategory``, instantiates the
    class that comes back and stores the result of its ``model()`` /
    ``embedding()`` call in ``Settings.llm`` / ``Settings.embed_model``.
    Afterwards a fresh ``CallbackManager`` is attached to the LLM and the
    chunking parameters are taken from ``CHUNK_SIZE`` (default ``1024``) and
    ``CHUNK_OVERLAP`` (default ``20``).

    Raises:
        ValueError: when ``ClsRegister.get`` returns ``None`` for either
            provider name.
    """
    model_provider = os.getenv("MODEL_PROVIDER")
    llm_platform_cls = ClsRegister.get(ModelPlateCategory, model_provider)
    if llm_platform_cls is None:
        raise ValueError(f"Invalid model provider: {model_provider}")
    Settings.llm = llm_platform_cls().model()

    embedding_provider = os.getenv("EMBEDDING_PROVIDER")
    embed_platform_cls = ClsRegister.get(ModelPlateCategory, embedding_provider)
    # The guard must inspect the class that was just looked up. Testing the
    # previously created LLM platform instance (as the old code did) is always
    # true, which turned an unknown embedding provider into
    # "TypeError: 'NoneType' object is not callable" instead of this ValueError.
    if embed_platform_cls is None:
        raise ValueError(f"Invalid embedding provider: {embedding_provider}")
    Settings.embed_model = embed_platform_cls().embedding()

    Settings.llm.callback_manager = CallbackManager()
    Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
    Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
class ModelPlatform:
    """Common interface of the model platforms registered in this module.

    Subclasses provide three factories: ``model()``, ``embedding()`` and
    ``rerank()``.

    NOTE: this class does not use ``ABCMeta`` / ``ABC``, so the
    ``@abstractmethod`` markers below are not enforced; the class and its
    subclasses can still be instantiated, and these base implementations
    simply return ``None``.
    """

    @abstractmethod
    def model(self):
        """Return the platform's LLM object (``None`` in this base class)."""
        return None

    @abstractmethod
    def embedding(self):
        """Return the platform's embedding model (``None`` in this base class)."""
        return None

    @abstractmethod
    def rerank(self):
        """Return the platform's rerank post-processors (``None`` in this base class)."""
        return None
@register(ModelPlateCategory, 'ollama')
class OllamaPlatform(ModelPlatform):
    """Model platform that talks to an Ollama server."""

    def model(self):
        """Build the Ollama LLM, publish it on ``Settings.llm`` and return it.

        Environment variables read:
            OLLAMA_BASE_URL: server URL, defaults to ``http://127.0.0.1:11434``
                when unset or empty.
            OLLAMA_REQUEST_TIMEOUT: request timeout, defaults to the library's
                ``DEFAULT_REQUEST_TIMEOUT``.
            MODEL: model name passed to ``Ollama``.

        Returns:
            The configured ``Ollama`` instance.
        """
        from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama

        base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
        request_timeout = float(
            os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
        )
        llm = Ollama(
            base_url=base_url,
            model=os.getenv("MODEL"),
            request_timeout=request_timeout,
        )
        # Kept for backward compatibility with callers that relied on the
        # side effect of the previous implementation.
        Settings.llm = llm
        # init_settings() assigns this method's return value to Settings.llm;
        # returning nothing (as before) replaced the LLM configured above
        # with None.
        return llm

    def embedding(self):
        """Return ``None``; the Ollama embedding integration is disabled.

        Reference implementation (disabled):
        """
        # from llama_index.embeddings.ollama import OllamaEmbedding
        # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
        # Settings.embed_model = OllamaEmbedding(
        #     base_url=base_url,
        #     model_name=os.getenv("EMBEDDING_MODEL"),
        # )
        return None

    def rerank(self):
        """Build the rerank post-processor list, or ``None`` if unconfigured.

        Environment variables read:
            RERANK_MODEL: path appended verbatim to the current working
                directory to locate the model; when unset, ``None`` is
                returned (the xinference platform does the same).
            RERANK_TOP_N: number of results to keep, default ``5``.
            RERANK_THRESHOLD: score threshold, default ``0.3``.

        Returns:
            A one-element list holding an ``OllamaRerank``, or ``None``.
        """
        model_suffix = os.getenv('RERANK_MODEL')
        if model_suffix is None:
            # Previously this crashed with TypeError on `str + None`.
            return None

        from app.engine.rerank.ollamRerank import OllamaRerank

        # Plain concatenation is intentional: RERANK_MODEL is appended to the
        # working directory exactly as given.
        model_path = os.getcwd() + model_suffix
        # os.getenv returns a str when the variable is set, so normalise to
        # int the same way the threshold is normalised to float.
        top_n = int(os.getenv('RERANK_TOP_N', 5))
        threshold = float(os.getenv('RERANK_THRESHOLD', 0.3))
        reranker = OllamaRerank(
            model=model_path,
            top_n=top_n,
            device="cpu",
            score_threshold=threshold,
        )
        return [reranker]
@register(ModelPlateCategory, 'xinference')
class XinferencePlatform(ModelPlatform):
    """Model platform backed by an Xinference endpoint."""

    def model(self):
        """Return an ``XinferenceModel`` configured from the environment.

        Reads ``BASE_URL``, ``MODEL``, ``LLM_MAX_TOKENS`` (optional, parsed as
        int) and ``LLM_TEMPERATURE`` (defaults to ``DEFAULT_XINFERENCE_TEMP``).
        """
        endpoint = os.getenv("BASE_URL")
        model_uid = os.getenv("MODEL")
        raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
        token_limit = None if raw_max_tokens is None else int(raw_max_tokens)
        sampling_temp = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
        return XinferenceModel(
            model_uid=model_uid,
            endpoint=endpoint,
            temperature=sampling_temp,
            max_tokens=token_limit,
        )

    def embedding(self):
        """Return an ``XinferenceEmbedding`` for ``EMBEDDING_MODEL``.

        The endpoint is ``EMBEDDING_BASE_URL`` when that variable is set and
        non-empty, otherwise ``BASE_URL``.
        """
        fallback_url = os.getenv("BASE_URL")
        endpoint = os.getenv("EMBEDDING_BASE_URL")
        if not endpoint:
            endpoint = fallback_url
        model_name = os.getenv("EMBEDDING_MODEL")
        raw_dim = os.getenv("EMBEDDING_DIM")
        if raw_dim is not None:
            # The parsed value is not forwarded to the embedding; the
            # conversion is kept so a malformed EMBEDDING_DIM still raises.
            int(raw_dim)
        return XinferenceEmbedding(model_name, endpoint)

    def rerank(self):
        """Return ``[CustomXinFerenceRerank(...)]`` or ``None``.

        ``None`` is returned when ``RERANK_MODEL`` is not set.
        """
        model_name = os.getenv("RERANK_MODEL")
        endpoint = os.getenv("RERANK_BASE_URL")
        keep_n = os.getenv("RERANK_TOP_N")
        min_score = os.getenv("RERANK_THRESHOLD")
        if model_name is None:
            return None
        # NOTE(review): top_n and score_threshold are handed over as the raw
        # environment strings (or None) - confirm CustomXinFerenceRerank
        # converts them itself.
        return [
            CustomXinFerenceRerank(
                model=model_name,
                base_url=endpoint,
                top_n=keep_n,
                score_threshold=min_score,
            )
        ]
@register(ModelPlateCategory, 'openai')
class OpenAIPlatform(ModelPlatform):
    """Model platform for OpenAI-style endpoints."""

    def model(self):
        """Return a ``SiliconCloudOpenAI`` LLM configured from the environment.

        Reads ``OPENAI_API_KEY``, ``BASE_URL``, ``MODEL`` and
        ``LLM_TEMPERATURE`` (defaults to ``DEFAULT_TEMPERATURE``).
        """
        from llama_index.core.constants import DEFAULT_TEMPERATURE
        from app.engine.model.siliconCloudOpenAI import SiliconCloudOpenAI

        llm_kwargs = {
            "api_key": os.getenv('OPENAI_API_KEY'),
            "api_base": os.getenv('BASE_URL'),
            "model": os.getenv('MODEL'),
            "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        }
        return SiliconCloudOpenAI(**llm_kwargs)

    def embedding(self):
        """Return an ``OpenAIEmbedding`` for ``EMBEDDING_MODEL``.

        ``EMBEDDING_DIM`` is parsed as an int when set; otherwise
        ``dimensions`` is passed as ``None``.
        """
        from llama_index.embeddings.openai import OpenAIEmbedding

        raw_dim = os.getenv("EMBEDDING_DIM")
        return OpenAIEmbedding(
            model=os.getenv("EMBEDDING_MODEL"),
            dimensions=None if raw_dim is None else int(raw_dim),
        )

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'dashscope')
class DashscopePlatform(ModelPlatform):
    """Model platform backed by DashScope."""

    def model(self):
        """Return a ``CustomDashScope`` LLM for ``MODEL`` using ``DASHSCOPE_API_KEY``."""
        return CustomDashScope(
            model_name=os.getenv('MODEL'),
            api_key=os.getenv('DASHSCOPE_API_KEY'),
        )

    def embedding(self):
        """Return a ``DashScopeEmbedding`` for ``EMBEDDING_MODEL``.

        The embedding is created with the ``TEXT_TYPE_QUERY`` text type.
        """
        from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels

        embed_kwargs = {
            "model_name": os.getenv('EMBEDDING_MODEL'),
            "text_type": DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,
            "api_key": os.getenv('DASHSCOPE_API_KEY'),
        }
        return DashScopeEmbedding(**embed_kwargs)

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'azure-openai')
class AzureOpenaiPlatform(ModelPlatform):
    """Placeholder for Azure OpenAI; every factory currently returns ``None``."""

    def model(self):
        """Return ``None``; the Azure OpenAI LLM integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.core.constants import DEFAULT_TEMPERATURE
        # from llama_index.llms.azure_openai import AzureOpenAI
        #
        # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
        # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
        # max_tokens = os.getenv("LLM_MAX_TOKENS")
        # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
        # dimensions = os.getenv("EMBEDDING_DIM")
        #
        # azure_config = {
        #     "api_key": os.environ["AZURE_OPENAI_KEY"],
        #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        #     or os.getenv("OPENAI_API_VERSION"),
        # }
        #
        # return AzureOpenAI(
        #     model=os.getenv("MODEL"),
        #     max_tokens=int(max_tokens) if max_tokens is not None else None,
        #     temperature=float(temperature),
        #     deployment_name=llm_deployment,
        #     **azure_config,
        # )
        return None

    def embedding(self):
        """Return ``None``; the Azure OpenAI embedding integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.core.constants import DEFAULT_TEMPERATURE
        # from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
        #
        # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
        # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
        # max_tokens = os.getenv("LLM_MAX_TOKENS")
        # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
        # dimensions = os.getenv("EMBEDDING_DIM")
        #
        # azure_config = {
        #     "api_key": os.environ["AZURE_OPENAI_KEY"],
        #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        #     or os.getenv("OPENAI_API_VERSION"),
        # }
        # return AzureOpenAIEmbedding(
        #     model=os.getenv("EMBEDDING_MODEL"),
        #     dimensions=int(dimensions) if dimensions is not None else None,
        #     deployment_name=embedding_deployment,
        #     **azure_config,
        # )
        return None

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'fastembed')
class FastembedPlatform(ModelPlatform):
    """Placeholder for FastEmbed; every factory currently returns ``None``.

    The methods are deliberately not marked ``@abstractmethod``: this is a
    concrete class that gets registered and instantiated, like the other
    platform classes in this module.
    """

    def model(self):
        """Return ``None``; this platform provides no LLM."""
        return None

    def embedding(self):
        """Return ``None``; the FastEmbed integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.embeddings.fastembed import FastEmbedEmbedding
        #
        # embed_model_map: Dict[str, str] = {
        #     # Small and multilingual
        #     "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
        #     # Large and multilingual
        #     "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
        # }
        #
        # # This will download the model automatically if it is not already downloaded
        # Settings.embed_model = FastEmbedEmbedding(
        #     model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
        # )
        return None

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'groq')
class GroqPlatform(ModelPlatform):
    """Placeholder for Groq; every factory currently returns ``None``.

    The methods are deliberately not marked ``@abstractmethod``: this is a
    concrete class that gets registered and instantiated, like the other
    platform classes in this module.
    """

    def model(self):
        """Return ``None``; the Groq LLM integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.llms.groq import Groq
        #
        # model_map: Dict[str, str] = {
        #     "llama3-8b": "llama3-8b-8192",
        #     "llama3-70b": "llama3-70b-8192",
        #     "mixtral-8x7b": "mixtral-8x7b-32768",
        # }
        #
        # Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
        # # Groq does not provide embeddings, so we use FastEmbed instead
        # init_fastembed()
        return None

    def embedding(self):
        """Return ``None``; no embedding model is provided."""
        return None

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'anthropic')
class AnthropicPlatform(ModelPlatform):
    """Placeholder for Anthropic; every factory currently returns ``None``."""

    def model(self):
        """Return ``None``; the Anthropic LLM integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.llms.anthropic import Anthropic
        #
        # model_map: Dict[str, str] = {
        #     "claude-3-opus": "claude-3-opus-20240229",
        #     "claude-3-sonnet": "claude-3-sonnet-20240229",
        #     "claude-3-haiku": "claude-3-haiku-20240307",
        #     "claude-2.1": "claude-2.1",
        #     "claude-instant-1.2": "claude-instant-1.2",
        # }
        #
        # Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
        # # Anthropic does not provide embeddings, so we use FastEmbed instead
        # init_fastembed()
        return None

    def embedding(self):
        """Return ``None``; no embedding model is provided."""
        return None

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'gemini')
class GeminiPlatform(ModelPlatform):
    """Placeholder for Gemini; every factory currently returns ``None``."""

    def model(self):
        """Return ``None``; the Gemini LLM integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.llms.gemini import Gemini
        # model_name = f"models/{os.getenv('MODEL')}"
        # return Gemini(model=model_name)
        return None

    def embedding(self):
        """Return ``None``; the Gemini embedding integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.embeddings.gemini import GeminiEmbedding
        # embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
        # return GeminiEmbedding(model_name=embed_model_name)
        return None

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
@register(ModelPlateCategory, 'mistral')
class MistralPlatform(ModelPlatform):
    """Placeholder for Mistral; every factory currently returns ``None``."""

    def model(self):
        """Return ``None``; the Mistral LLM integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.llms.mistralai import MistralAI
        # return MistralAI(model=os.getenv("MODEL"))
        return None

    def embedding(self):
        """Return ``None``; the Mistral embedding integration is disabled."""
        # Reference implementation (disabled):
        #
        # from llama_index.embeddings.mistralai import MistralAIEmbedding
        # return MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
        return None

    def rerank(self):
        """No reranker is provided for this platform; returns ``None``."""
        return None
def init_ProjectInfo():
    """Feed every entry from ``getProjectInfos()`` into a new ``ProjectInfo``.

    For each entry, its ``'name'`` and ``'flag'`` items are passed to
    ``ProjectInfo.add``. Nothing is returned.
    """
    registry = ProjectInfo()
    # NOTE(review): entries are indexed by string key, so they are presumably
    # mapping-like rows rather than plain tuples - confirm against
    # getProjectInfos().
    entries: list[tuple] = getProjectInfos()
    for entry in entries:
        registry.add(entry['name'], entry['flag'])