import os
from typing import Dict, Optional
from abc import abstractmethod
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from llama_index.embeddings.xinference import XinferenceEmbedding
#from llama_index.llms.xinference import Xinference
from app.engine.model.xinference import XinferenceModel
from app.engine.rerank.xinferenceRerank import CustomXinFerenceRerank
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo
from modelProvide.customDashScope import CustomDashScope
from util.register import *
from llama_index.core.callbacks import CallbackManager

# Registry category under which every ModelPlatform subclass is registered
# (the literal means "model platform").
ModelPlateCategory = '模型平台'


class ModelPlatform:
    """Base class for a provider that builds the LLM, embedding model and rerankers.

    Subclasses register themselves with ``@register(ModelPlateCategory, <name>)``
    and are looked up by that name in ``init_settings``.

    NOTE: this class does not use ``ABCMeta``, so ``@abstractmethod`` is not
    enforced at instantiation time; it only marks what subclasses should override.
    """

    @abstractmethod
    def model(self):
        """Return the LLM to install as ``Settings.llm``."""
        pass

    @abstractmethod
    def embedding(self):
        """Return the embedding model to install as ``Settings.embed_model``."""
        pass

    @abstractmethod
    def rerank(self):
        """Return a list of node postprocessors (rerankers), or None."""
        pass


def _resolve_platform(provider: Optional[str], kind: str) -> "ModelPlatform":
    """Instantiate the ModelPlatform registered under ``provider``.

    Args:
        provider: registry name, typically read from an environment variable.
        kind: label used in the error message ("model" or "embedding").

    Raises:
        ValueError: if no platform is registered under ``provider``.
    """
    platform_cls = ClsRegister.get(ModelPlateCategory, provider)
    if platform_cls is None:
        raise ValueError(f"Invalid {kind} provider: {provider}")
    return platform_cls()


def init_settings():
    """Configure the global llama_index ``Settings`` from environment variables.

    Reads MODEL_PROVIDER and EMBEDDING_PROVIDER to select registered platforms,
    and CHUNK_SIZE / CHUNK_OVERLAP (defaults 1024 / 20) for chunking.

    Raises:
        ValueError: if either provider name is not registered.
    """
    model_provider = os.getenv("MODEL_PROVIDER")
    Settings.llm = _resolve_platform(model_provider, "model").model()

    # The embedding platform is resolved and validated independently of the
    # LLM platform, so an unknown EMBEDDING_PROVIDER raises ValueError.
    embedding_provider = os.getenv("EMBEDDING_PROVIDER")
    Settings.embed_model = _resolve_platform(embedding_provider, "embedding").embedding()

    Settings.llm.callback_manager = CallbackManager()
    Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
    Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))


@register(ModelPlateCategory, 'ollama')
class OllamaPlatform(ModelPlatform):
    """Ollama server for the LLM, plus a local OllamaRerank reranker."""

    def model(self):
        from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama

        base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
        request_timeout = float(
            os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
        )
        # Must return the instance: init_settings assigns the return value to
        # Settings.llm, so setting Settings.llm here would be overwritten.
        return Ollama(
            base_url=base_url,
            model=os.getenv("MODEL"),
            request_timeout=request_timeout,
        )

    def embedding(self):
        # Not implemented: returns None. Reference implementation kept below.
        # from llama_index.embeddings.ollama import OllamaEmbedding
        # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
        # return OllamaEmbedding(
        #     base_url=base_url,
        #     model_name=os.getenv("EMBEDDING_MODEL"),
        # )
        return None

    def rerank(self):
        from app.engine.rerank.ollamRerank import OllamaRerank

        # RERANK_MODEL is appended directly to the working directory, so it is
        # presumably given with a leading path separator (e.g. "/models/x") — confirm.
        model_path = os.getcwd() + os.getenv('RERANK_MODEL')
        # Environment values are strings; convert so the reranker gets numbers.
        top_n = int(os.getenv('RERANK_TOP_N', 5))
        threshold = float(os.getenv('RERANK_THRESHOLD', 0.3))
        return [
            OllamaRerank(
                model=model_path,
                top_n=top_n,
                device="cpu",
                score_threshold=threshold,
            )
        ]


@register(ModelPlateCategory, 'xinference')
class XinferencePlatform(ModelPlatform):
    """Xinference server for LLM, embeddings and (optionally) reranking."""

    def model(self):
        max_tokens = os.getenv("LLM_MAX_TOKENS")
        temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
        return XinferenceModel(
            model_uid=os.getenv("MODEL"),
            endpoint=os.getenv("BASE_URL"),
            temperature=temperature,
            max_tokens=int(max_tokens) if max_tokens is not None else None,
        )

    def embedding(self):
        # Fall back to the LLM endpoint when EMBEDDING_BASE_URL is unset or empty.
        embedding_base_url = os.getenv("EMBEDDING_BASE_URL") or os.getenv("BASE_URL")
        return XinferenceEmbedding(os.getenv("EMBEDDING_MODEL"), embedding_base_url)

    def rerank(self):
        """Return a one-element reranker list, or None when RERANK_MODEL is unset."""
        rerank_model = os.getenv("RERANK_MODEL")
        if rerank_model is None:
            return None
        top_n = os.getenv("RERANK_TOP_N")
        threshold = os.getenv("RERANK_THRESHOLD")
        return [
            CustomXinFerenceRerank(
                model=rerank_model,
                base_url=os.getenv("RERANK_BASE_URL"),
                # Environment values are strings; pass numbers like the other platforms do.
                top_n=int(top_n) if top_n is not None else None,
                score_threshold=float(threshold) if threshold is not None else None,
            )
        ]


@register(ModelPlateCategory, 'openai')
class OpenAIPlatform(ModelPlatform):
    """OpenAI-compatible endpoints (SiliconCloud) for LLM, embeddings and reranking."""

    def model(self):
        from app.engine.model.siliconCloudOpenAI import SiliconCloudOpenAI

        return SiliconCloudOpenAI(
            api_key=os.getenv('OPENAI_API_KEY'),
            api_base=os.getenv('BASE_URL'),
            model=os.getenv('MODEL'),
            temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
        )

    def embedding(self):
        from llama_index.embeddings.openai import OpenAIEmbedding

        dimensions = os.getenv("EMBEDDING_DIM")
        return OpenAIEmbedding(
            api_key=os.getenv('OPENAI_API_KEY'),
            api_base=os.getenv('EMBEDDING_BASE_URL'),
            model_name=os.getenv('EMBEDDING_MODEL'),
            # EMBEDDING_DIM is optional; int(None) would raise TypeError.
            dimensions=int(dimensions) if dimensions is not None else None,
        )

    def rerank(self):
        from app.engine.rerank.siliconCloudRerank import SiliconCloudRerank

        return [
            SiliconCloudRerank(
                top_n=int(os.getenv('RERANK_TOP_N', 5)),
                model=os.getenv('RERANK_MODEL'),
                base_url=os.getenv('RERANK_BASE_URL'),
                api_key=os.getenv('OPENAI_API_KEY'),
            )
        ]


@register(ModelPlateCategory, 'dashscope')
class DashscopePlatform(ModelPlatform):
    """Alibaba DashScope for LLM and embeddings; no reranker."""

    def model(self):
        return CustomDashScope(
            model_name=os.getenv('MODEL'),
            api_key=os.getenv('DASHSCOPE_API_KEY'),
        )

    def embedding(self):
        from llama_index.embeddings.dashscope import (
            DashScopeEmbedding,
            DashScopeTextEmbeddingType,
        )

        # NOTE(review): the query text type is used for every embedding,
        # including documents — confirm this is intended.
        return DashScopeEmbedding(
            model_name=os.getenv('EMBEDDING_MODEL'),
            text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,
            api_key=os.getenv('DASHSCOPE_API_KEY'),
        )

    def rerank(self):
        return None


@register(ModelPlateCategory, 'azure-openai')
class AzureOpenaiPlatform(ModelPlatform):
    """Placeholder: every method returns None. Reference code kept in comments."""

    def model(self):
        # from llama_index.llms.azure_openai import AzureOpenAI
        #
        # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
        # max_tokens = os.getenv("LLM_MAX_TOKENS")
        # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
        #
        # azure_config = {
        #     "api_key": os.environ["AZURE_OPENAI_KEY"],
        #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        #     or os.getenv("OPENAI_API_VERSION"),
        # }
        #
        # return AzureOpenAI(
        #     model=os.getenv("MODEL"),
        #     max_tokens=int(max_tokens) if max_tokens is not None else None,
        #     temperature=float(temperature),
        #     deployment_name=llm_deployment,
        #     **azure_config,
        # )
        return None

    def embedding(self):
        # from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
        #
        # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
        # dimensions = os.getenv("EMBEDDING_DIM")
        #
        # azure_config = {
        #     "api_key": os.environ["AZURE_OPENAI_KEY"],
        #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        #     or os.getenv("OPENAI_API_VERSION"),
        # }
        # return AzureOpenAIEmbedding(
        #     model=os.getenv("EMBEDDING_MODEL"),
        #     dimensions=int(dimensions) if dimensions is not None else None,
        #     deployment_name=embedding_deployment,
        #     **azure_config,
        # )
        return None

    def rerank(self):
        return None


@register(ModelPlateCategory, 'fastembed')
class FastembedPlatform(ModelPlatform):
    """Placeholder: every method returns None. Reference code kept in comments."""

    def model(self):
        return None

    def embedding(self):
        # from llama_index.embeddings.fastembed import FastEmbedEmbedding
        #
        # embed_model_map: Dict[str, str] = {
        #     # Small and multilingual
        #     "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
        #     # Large and multilingual
        #     "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
        # }
        #
        # # Downloads the model automatically if it is not already present.
        # return FastEmbedEmbedding(
        #     model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
        # )
        return None

    def rerank(self):
        return None


@register(ModelPlateCategory, 'groq')
class GroqPlatform(ModelPlatform):
    """Placeholder: every method returns None. Reference code kept in comments."""

    def model(self):
        # from llama_index.llms.groq import Groq
        #
        # model_map: Dict[str, str] = {
        #     "llama3-8b": "llama3-8b-8192",
        #     "llama3-70b": "llama3-70b-8192",
        #     "mixtral-8x7b": "mixtral-8x7b-32768",
        # }
        #
        # return Groq(model=model_map[os.getenv("MODEL")])
        # NOTE: Groq provides no embeddings; pair it with EMBEDDING_PROVIDER=fastembed.
        return None

    def embedding(self):
        return None

    def rerank(self):
        return None


@register(ModelPlateCategory, 'anthropic')
class AnthropicPlatform(ModelPlatform):
    """Placeholder: every method returns None. Reference code kept in comments."""

    def model(self):
        # from llama_index.llms.anthropic import Anthropic
        #
        # model_map: Dict[str, str] = {
        #     "claude-3-opus": "claude-3-opus-20240229",
        #     "claude-3-sonnet": "claude-3-sonnet-20240229",
        #     "claude-3-haiku": "claude-3-haiku-20240307",
        #     "claude-2.1": "claude-2.1",
        #     "claude-instant-1.2": "claude-instant-1.2",
        # }
        #
        # return Anthropic(model=model_map[os.getenv("MODEL")])
        # NOTE: Anthropic provides no embeddings; pair it with EMBEDDING_PROVIDER=fastembed.
        return None

    def embedding(self):
        return None

    def rerank(self):
        return None


@register(ModelPlateCategory, 'gemini')
class GeminiPlatform(ModelPlatform):
    """Placeholder: every method returns None. Reference code kept in comments."""

    def model(self):
        # from llama_index.llms.gemini import Gemini
        # model_name = f"models/{os.getenv('MODEL')}"
        # return Gemini(model=model_name)
        return None

    def embedding(self):
        # from llama_index.embeddings.gemini import GeminiEmbedding
        # embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
        # return GeminiEmbedding(model_name=embed_model_name)
        return None

    def rerank(self):
        return None


@register(ModelPlateCategory, 'mistral')
class MistralPlatform(ModelPlatform):
    """Placeholder: every method returns None. Reference code kept in comments."""

    def model(self):
        # from llama_index.llms.mistralai import MistralAI
        # return MistralAI(model=os.getenv("MODEL"))
        return None

    def embedding(self):
        # from llama_index.embeddings.mistralai import MistralAIEmbedding
        # return MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
        return None

    def rerank(self):
        return None


def init_ProjectInfo():
    """Add every project returned by ``getProjectInfos()`` to a ``ProjectInfo``.

    NOTE(review): the ProjectInfo object is not returned, so this presumably
    relies on ProjectInfo sharing state across instances — confirm.
    """
    project_info = ProjectInfo()
    # Entries are indexed by 'name' and 'flag', so they are mappings, not tuples.
    project_rows: list[dict] = getProjectInfos()
    for row in project_rows:
        project_info.add(row['name'], row['flag'])