优化模型初始化代码

This commit is contained in:
wanyaokun
2024-09-04 15:00:38 +08:00
parent 728ee06c5a
commit 97a486e631
6 changed files with 438 additions and 229 deletions
+37 -23
View File
@@ -4,34 +4,48 @@ SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zj
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2 #SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
SQLITE_DATABASE_URL=sqlite:///./source.db SQLITE_DATABASE_URL=sqlite:///./source.db
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60 # The number of similar embeddings to return when retrieving documents.
# The provider for the AI models to use. TOP_K=10
MODEL_PROVIDER=dashscope #--------------------------
# The name of LLM model to use. # 是否启用混合检索
MODEL=qwen-max HYBRID_ENABLED = false
# 混合检索阈值
HYBRID_ALPHA = 0.6
# 是否启用检索重排功能 # 是否启用检索重排功能
ENABLE_RERANK=true RERANK_ENABLED=true
# Name of the embedding model to use.
EMBEDDING_MODEL=text-embedding-v2
# Dimension of the embedding model to use. #---------- rerank- Xinference ----------------
RERANK_PROVIDER=xinference
RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995
RERANK_TOP_N=5
RERANK_THRESHOLD=0.3
#---------- model - Xinference ----------------
#MODEL_PROVIDER=xinference
#OPENAI_API_KEY=xinference
#BASE_URL=http://172.20.0.145:9995
#MODEL=Qwen2-72B-Instruct-GPTQ-Int8
## Temperature for sampling from the model.
#LLM_TEMPERATURE=0.1
#---------- model - dashscope ----------------
MODEL_PROVIDER=dashscope
DASHSCOPE_API_KEY=sk-221d2d202e104618a56002ce2e7dc0d0
MODEL=qwen-max
#---------- embedding - Xinference ----------------
EMBEDDING_PROVIDER=xinference
EMBEDDING_MODEL=bge-m3
EMBEDDING_BASE_URL=http://10.1.16.39:9995
EMBEDDING_DIM=1024 EMBEDDING_DIM=1024
# The questions to help users get started (multi-line). # The questions to help users get started (multi-line).
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容? CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
# The OpenAI API key to use.
# OPENAI_API_KEY=
# Temperature for sampling from the model.
# LLM_TEMPERATURE=
# Maximum number of tokens to generate.
# LLM_MAX_TOKENS=
# The number of similar embeddings to return when retrieving documents.
TOP_K=5
# The time in milliseconds to wait for the stream to return a response. # The time in milliseconds to wait for the stream to return a response.
STREAM_TIMEOUT=60000 STREAM_TIMEOUT=60000
@@ -53,9 +67,8 @@ VECTOR_STORE_PATH=./storage_vector
BM_RETRIEVER_PATH =./storage_bm BM_RETRIEVER_PATH =./storage_bm
PHOENIX_API_KEY=123456 PHOENIX_API_KEY=123456
PHOENIX_URL=http://localhost:6006/v1/traces PHOENIX_URL=http://10.1.6.103:6006/v1/traces
PHOENIX_PROJECT_NAME=ly_zjapp PHOENIX_PROJECT_NAME=ly_zjapp
#OTEL_SERVICE_NAME=ly_zjapp #OTEL_SERVICE_NAME=ly_zjapp
#OTEL_RESOURCE_ATTRIBUTES=openinference.project.name=ly_zjapp #OTEL_RESOURCE_ATTRIBUTES=openinference.project.name=ly_zjapp
@@ -82,4 +95,5 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
PRJTOJSON_URL = 'http://10.1.6.60:8092' PRJTOJSON_URL = 'http://10.1.6.60:8092'
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!" PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
CHAT_UPLOAD_FILECACHE = "./output/uploaded" CHAT_UPLOAD_FILECACHE = "./output/uploaded"
+14 -13
View File
@@ -14,27 +14,28 @@ HYBRID_ALPHA = 0.6
#-------------------------- #--------------------------
# 是否启用检索重排功能 # 是否启用检索重排功能
RERANK_ENABLED=true RERANK_ENABLED=true
# Rerank model
#---------- rerank- Xinference ----------------
RERANK_PROVIDER=xinference
RERANK_MODEL=bge-reranker-v2-m3 RERANK_MODEL=bge-reranker-v2-m3
RERANK_BASE_URL=http://10.1.16.39:9995 RERANK_BASE_URL=http://10.1.16.39:9995
RERANK_TOP_N=5 RERANK_TOP_N=5
RERANK_THRESHOLD=0.3 RERANK_THRESHOLD=0.3
#---------- Xinference ----------------
# The provider for the AI models to use. #---------- model - Xinference ----------------
MODEL_PROVIDER=xinference MODEL_PROVIDER=xinference # The provider for the AI models to use.
# The OpenAI API key to use. OPENAI_API_KEY=xinference # The OpenAI API key to use.
OPENAI_API_KEY=xinference
BASE_URL=http://10.1.0.142:9995 BASE_URL=http://10.1.0.142:9995
MODEL=Qwen2-72B-Instruct-GPTQ-Int8 MODEL=Qwen2-72B-Instruct-GPTQ-Int8
# Temperature for sampling from the model. LLM_TEMPERATURE=0.1 # Temperature for sampling from the model.
LLM_TEMPERATURE=0.1 #LLM_MAX_TOKENS= # Maximum number of tokens to generate.
# Maximum number of tokens to generate.
#LLM_MAX_TOKENS=
# Name of the embedding model to use. #---------- embedding - Xinference ----------------
EMBEDDING_PROVIDER=xinference
EMBEDDING_MODEL=bge-m3 EMBEDDING_MODEL=bge-m3
EMBEDDING_BASE_URL=http://10.1.16.39:9995 EMBEDDING_BASE_URL=http://10.1.16.39:9995
# Dimension of the embedding model to use. EMBEDDING_DIM=1024 # Dimension of the embedding model to use.
EMBEDDING_DIM=1024
##---------- OpenAI ---------------- ##---------- OpenAI ----------------
## The provider for the AI models to use. ## The provider for the AI models to use.
-3
View File
@@ -24,14 +24,11 @@ from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileSer
from app.api.routers.services.suggestion import NextQuestionSuggestion from app.api.routers.services.suggestion import NextQuestionSuggestion
import time import time
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
from llama_index.core.callbacks import CallbackManager
logger = logging.getLogger("uvicorn") logger = logging.getLogger("uvicorn")
v1_router = v = APIRouter() v1_router = v = APIRouter()
Settings.llm.callback_manager = CallbackManager()
gEvent_handler = None gEvent_handler = None
+286 -190
View File
@@ -1,6 +1,6 @@
import os import os
from typing import Dict from typing import Dict
from abc import abstractmethod
from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
from llama_index.llms.xinference import Xinference from llama_index.llms.xinference import Xinference
@@ -9,229 +9,322 @@ from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from app.xinference.base import XinferenceEmbedding, XinferenceRerank from app.xinference.base import XinferenceEmbedding, XinferenceRerank
from app.engine.loaders import getProjectInfos from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo from app.api.routers.request.base import ProjectInfo
from util.register import *
from llama_index.core.callbacks import CallbackManager
from modelProvide.customDashScope import CustomDashScope
ModelPlateCategory = '模型平台'
def get_node_postprocessors(): def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED").title() rerank_enabled = os.getenv("RERANK_ENABLED").title()
if rerank_enabled is None or rerank_enabled == 'False': if rerank_enabled is None or rerank_enabled == 'False':
return [] return []
rerank_model = os.getenv("RERANK_MODEL") Rerank_provider = os.getenv("RERANK_PROVIDER")
rerank_url = os.getenv("RERANK_BASE_URL") modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider)
rerank_top_n = os.getenv("RERANK_TOP_N")
rerank_threshold = os.getenv("RERANK_THRESHOLD")
postprocess = None postprocess = None
if rerank_model is not None: if modelPaltCls is not None:
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)] modelPalt:ModelPlatform = modelPaltCls()
postprocess = modelPalt.rerank()
else:
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
return postprocess return postprocess
def init_settings(): def init_settings():
model_provider = os.getenv("MODEL_PROVIDER") model_provider = os.getenv("MODEL_PROVIDER")
match model_provider: modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
case "openai": if modelPaltCls is not None:
init_openai() modelPalt:ModelPlatform = modelPaltCls()
case "dashscope": Settings.llm = modelPalt.model()
init_dashscope() else:
case "groq": raise ValueError(f"Invalid model provider: {model_provider}")
init_groq()
case "ollama": embedding_provider = os.getenv("EMBEDDING_PROVIDER")
init_ollama() modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,embedding_provider)
case "anthropic": if modelPalt is not None:
init_anthropic() modelPalt:ModelPlatform = modelPaltCls()
case "gemini": Settings.embed_model = modelPalt.embedding()
init_gemini() else:
case "mistral": raise ValueError(f"Invalid embedding provider: {embedding_provider}")
init_mistral()
case "azure-openai":
init_azure_openai()
case "t-systems":
from .llmhub import init_llmhub
init_llmhub()
case "xinference":
init_xinference()
case _:
raise ValueError(f"Invalid model provider: {model_provider}")
Settings.llm.callback_manager = CallbackManager()
Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024")) Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20")) Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
def init_ollama(): class ModelPlatform:
# from llama_index.embeddings.ollama import OllamaEmbedding @abstractmethod
# from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama def model(self):
# pass
# base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
# request_timeout = float(
# os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
# )
# Settings.embed_model = OllamaEmbedding(
# base_url=base_url,
# model_name=os.getenv("EMBEDDING_MODEL"),
# )
# Settings.llm = Ollama(
# base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
# )
pass
def init_xinference(): @abstractmethod
base_url = os.getenv("BASE_URL") def embedding(self):
model = os.getenv("MODEL") pass
max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
Settings.llm = Xinference(model, base_url, temperature, max_tokens) @abstractmethod
def rerank(self):
pass
embedding_base_url = os.getenv("EMBEDDING_BASE_URL") @register(ModelPlateCategory,'ollama')
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url class OllamaPlatform(ModelPlatform):
def model(self):
#from llama_index.embeddings.ollama import OllamaEmbedding
#from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
#
# base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
# request_timeout = float(
# os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
# )
# Settings.llm = Ollama(
# base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
# )
pass
embed_model_name = os.getenv("EMBEDDING_MODEL") def embedding(self):
dimensions = os.getenv("EMBEDDING_DIM") #from llama_index.embeddings.ollama import OllamaEmbedding
dimensions = int(dimensions) if dimensions is not None else None # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
Settings.embed_model = XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions) # Settings.embed_model = OllamaEmbedding(
# base_url=base_url,
# model_name=os.getenv("EMBEDDING_MODEL"),
# )
pass
def init_openai(): def rerank(self):
from llama_index.core.constants import DEFAULT_TEMPERATURE pass
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
max_tokens = os.getenv("LLM_MAX_TOKENS") @register(ModelPlateCategory,'xinference')
config = { class XinferencePlatform(ModelPlatform):
"model": os.getenv("MODEL"), def model(self):
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)), base_url = os.getenv("BASE_URL")
"max_tokens": int(max_tokens) if max_tokens is not None else None, model = os.getenv("MODEL")
} max_tokens = int(os.getenv("LLM_MAX_TOKENS")) if os.getenv("LLM_MAX_TOKENS") is not None else None
Settings.llm = OpenAI(**config) temperature = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
return Xinference(model, base_url, temperature, max_tokens)
def embedding(self):
base_url = os.getenv("BASE_URL")
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
embedding_base_url = embedding_base_url if embedding_base_url != None and embedding_base_url != "" else base_url
dimensions = os.getenv("EMBEDDING_DIM") embed_model_name = os.getenv("EMBEDDING_MODEL")
config = { dimensions = os.getenv("EMBEDDING_DIM")
"model": os.getenv("EMBEDDING_MODEL"), dimensions = int(dimensions) if dimensions is not None else None
"dimensions": int(dimensions) if dimensions is not None else None, return XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
}
Settings.embed_model = OpenAIEmbedding(**config) def rerank(self):
rerank_model = os.getenv("RERANK_MODEL")
rerank_url = os.getenv("RERANK_BASE_URL")
rerank_top_n = os.getenv("RERANK_TOP_N")
rerank_threshold = os.getenv("RERANK_THRESHOLD")
postprocess = None
if rerank_model is not None:
postprocess = [XinferenceRerank(rerank_model, rerank_url, top_n=rerank_top_n, threshold=rerank_threshold)]
return postprocess
def init_dashscope(): @register(ModelPlateCategory,'openai')
from llama_index.llms.dashscope import DashScope,DashScopeGenerationModels class OpenAIPlatform(ModelPlatform):
from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeBatchTextEmbeddingModels,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels def model(self):
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.llms.openai import OpenAI
max_tokens = os.getenv("LLM_MAX_TOKENS") max_tokens = os.getenv("LLM_MAX_TOKENS")
config = { config = {
"model": os.getenv("MODEL"), "model": os.getenv("MODEL"),
"temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)), "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
"max_tokens": int(max_tokens) if max_tokens is not None else None, "max_tokens": int(max_tokens) if max_tokens is not None else None,
} }
Settings.llm = llm = DashScope(model_name=DashScopeGenerationModels.QWEN_MAX) return OpenAI(**config)
def embedding(self):
from llama_index.embeddings.openai import OpenAIEmbedding
dimensions = os.getenv("EMBEDDING_DIM")
config = {
"model": os.getenv("EMBEDDING_MODEL"),
"dimensions": int(dimensions) if dimensions is not None else None,
}
return OpenAIEmbedding(**config)
def rerank(self):
pass
dimensions = os.getenv("EMBEDDING_DIM") @register(ModelPlateCategory,'dashscope')
config = { class DashscopePlatform(ModelPlatform):
"model": os.getenv("EMBEDDING_MODEL"), def model(self):
"dimensions": int(dimensions) if dimensions is not None else None, apikey = os.getenv('DASHSCOPE_API_KEY')
} modelName = os.getenv('MODEL')
Settings.embed_model = DashScopeEmbedding(model_name=DashScopeTextEmbeddingModels.TEXT_EMBEDDING_V2, return CustomDashScope(model_name=modelName,api_key = apikey)
text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY)
def init_azure_openai(): def embedding(self):
# from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels
# from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding api_key = os.getenv('DASHSCOPE_API_KEY')
# from llama_index.llms.azure_openai import AzureOpenAI modelName = os.getenv('EMBEDDING_MODEL')
# return DashScopeEmbedding(model_name=modelName,
# llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"] text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,api_key = api_key)
# embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
# max_tokens = os.getenv("LLM_MAX_TOKENS")
# temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
# dimensions = os.getenv("EMBEDDING_DIM")
#
# azure_config = {
# "api_key": os.environ["AZURE_OPENAI_KEY"],
# "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
# "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
# or os.getenv("OPENAI_API_VERSION"),
# }
#
# Settings.llm = AzureOpenAI(
# model=os.getenv("MODEL"),
# max_tokens=int(max_tokens) if max_tokens is not None else None,
# temperature=float(temperature),
# deployment_name=llm_deployment,
# **azure_config,
# )
#
# Settings.embed_model = AzureOpenAIEmbedding(
# model=os.getenv("EMBEDDING_MODEL"),
# dimensions=int(dimensions) if dimensions is not None else None,
# deployment_name=embedding_deployment,
# **azure_config,
# )
pass
def init_fastembed(): def rerank(self):
""" pass
Use Qdrant Fastembed as the local embedding provider.
"""
# from llama_index.embeddings.fastembed import FastEmbedEmbedding
#
# embed_model_map: Dict[str, str] = {
# # Small and multilingual
# "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
# # Large and multilingual
# "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
# }
#
# # This will download the model automatically if it is not already downloaded
# Settings.embed_model = FastEmbedEmbedding(
# model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
# )
pass
@register(ModelPlateCategory,'azure-openai')
class AzureOpenaiPlatform(ModelPlatform):
def model(self):
# from llama_index.core.constants import DEFAULT_TEMPERATURE
# from llama_index.llms.azure_openai import AzureOpenAI
#
# llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
# embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
# max_tokens = os.getenv("LLM_MAX_TOKENS")
# temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
# dimensions = os.getenv("EMBEDDING_DIM")
#
# azure_config = {
# "api_key": os.environ["AZURE_OPENAI_KEY"],
# "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
# "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
# or os.getenv("OPENAI_API_VERSION"),
# }
#
# return AzureOpenAI(
# model=os.getenv("MODEL"),
# max_tokens=int(max_tokens) if max_tokens is not None else None,
# temperature=float(temperature),
# deployment_name=llm_deployment,
# **azure_config,
# )
pass
def init_groq(): def embedding(self):
# from llama_index.llms.groq import Groq # from llama_index.core.constants import DEFAULT_TEMPERATURE
# # from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
# model_map: Dict[str, str] = { #
# "llama3-8b": "llama3-8b-8192", # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
# "llama3-70b": "llama3-70b-8192", # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
# "mixtral-8x7b": "mixtral-8x7b-32768", # max_tokens = os.getenv("LLM_MAX_TOKENS")
# } # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
# # dimensions = os.getenv("EMBEDDING_DIM")
# Settings.llm = Groq(model=model_map[os.getenv("MODEL")]) #
# # Groq does not provide embeddings, so we use FastEmbed instead # azure_config = {
# init_fastembed() # "api_key": os.environ["AZURE_OPENAI_KEY"],
pass # "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
# "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
# or os.getenv("OPENAI_API_VERSION"),
# }
# return AzureOpenAIEmbedding(
# model=os.getenv("EMBEDDING_MODEL"),
# dimensions=int(dimensions) if dimensions is not None else None,
# deployment_name=embedding_deployment,
# **azure_config,
# )
pass
def rerank(self):
pass
def init_anthropic(): @register(ModelPlateCategory,'fastembed')
# from llama_index.llms.anthropic import Anthropic class FastembedPlatform(ModelPlatform):
# @abstractmethod
# model_map: Dict[str, str] = { def model(self):
# "claude-3-opus": "claude-3-opus-20240229", pass
# "claude-3-sonnet": "claude-3-sonnet-20240229",
# "claude-3-haiku": "claude-3-haiku-20240307",
# "claude-2.1": "claude-2.1",
# "claude-instant-1.2": "claude-instant-1.2",
# }
#
# Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
# # Anthropic does not provide embeddings, so we use FastEmbed instead
# init_fastembed()
pass
@abstractmethod
def embedding(self):
# from llama_index.embeddings.fastembed import FastEmbedEmbedding
#
# embed_model_map: Dict[str, str] = {
# # Small and multilingual
# "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
# # Large and multilingual
# "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
# }
#
# # This will download the model automatically if it is not already downloaded
# Settings.embed_model = FastEmbedEmbedding(
# model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
# )
pass
def init_gemini(): @abstractmethod
# from llama_index.embeddings.gemini import GeminiEmbedding def rerank(self):
# from llama_index.llms.gemini import Gemini pass
#
# model_name = f"models/{os.getenv('MODEL')}"
# embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
#
# Settings.llm = Gemini(model=model_name)
# Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
pass
def init_mistral(): @register(ModelPlateCategory,'groq')
# from llama_index.embeddings.mistralai import MistralAIEmbedding class GroqPlatform(ModelPlatform):
# from llama_index.llms.mistralai import MistralAI @abstractmethod
# def model(self):
# Settings.llm = MistralAI(model=os.getenv("MODEL")) # from llama_index.llms.groq import Groq
# Settings.embed_model = MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL")) #
pass # model_map: Dict[str, str] = {
# "llama3-8b": "llama3-8b-8192",
# "llama3-70b": "llama3-70b-8192",
# "mixtral-8x7b": "mixtral-8x7b-32768",
# }
#
# Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
# # Groq does not provide embeddings, so we use FastEmbed instead
# init_fastembed()
pass
@abstractmethod
def embedding(self):
pass
@abstractmethod
def rerank(self):
pass
@register(ModelPlateCategory,'anthropic')
class AnthropicPlatform(ModelPlatform):
def model(self):
# from llama_index.llms.anthropic import Anthropic
#
# model_map: Dict[str, str] = {
# "claude-3-opus": "claude-3-opus-20240229",
# "claude-3-sonnet": "claude-3-sonnet-20240229",
# "claude-3-haiku": "claude-3-haiku-20240307",
# "claude-2.1": "claude-2.1",
# "claude-instant-1.2": "claude-instant-1.2",
# }
#
# Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
# # Anthropic does not provide embeddings, so we use FastEmbed instead
# init_fastembed()
pass
def embedding(self):
pass
def rerank(self):
pass
@register(ModelPlateCategory,'gemini')
class GeminiPlatform(ModelPlatform):
def model(self):
# from llama_index.llms.gemini import Gemini
# model_name = f"models/{os.getenv('MODEL')}"
# return Gemini(model=model_name)
pass
def embedding(self):
# from llama_index.embeddings.gemini import GeminiEmbedding
# embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
# return GeminiEmbedding(model_name=embed_model_name)
pass
def rerank(self):
pass
@register(ModelPlateCategory,'mistral')
class MistralPlatform(ModelPlatform):
def model(self):
# from llama_index.llms.mistralai import MistralAI
# return MistralAI(model=os.getenv("MODEL"))
pass
def embedding(self):
# from llama_index.embeddings.mistralai import MistralAIEmbedding
# return MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
pass
def rerank(self):
pass
def init_ProjectInfo(): def init_ProjectInfo():
prjObj = ProjectInfo() prjObj = ProjectInfo()
@@ -239,3 +332,6 @@ def init_ProjectInfo():
for prjInfo in prjInfos: for prjInfo in prjInfos:
prjObj.add(prjInfo['name'],prjInfo['flag']) prjObj.add(prjInfo['name'],prjInfo['flag'])
+58
View File
@@ -0,0 +1,58 @@
from llama_index.llms.dashscope import DashScope
from llama_index.core.base.llms.types import LLMMetadata
class DashScopeGenerationModels:
"""DashScope Qwen serial models."""
QWEN_TURBO = "qwen-turbo"
QWEN_PLUS = "qwen-plus"
QWEN_MAX = "qwen-max"
QWEN_MAX_1201 = "qwen-max-1201"
QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext"
QWEN2_MATH_72B_INSTRUCT = 'qwen2-math-72b-instruct'
DASHSCOPE_MODEL_META = {
DashScopeGenerationModels.QWEN_TURBO: {
"context_window": 1024 * 8,
"num_output": 1024 * 8,
"is_chat_model": True,
},
DashScopeGenerationModels.QWEN_PLUS: {
"context_window": 1024 * 32,
"num_output": 1024 * 32,
"is_chat_model": True,
},
DashScopeGenerationModels.QWEN_MAX: {
"context_window": 1024 * 8,
"num_output": 1024 * 8,
"is_chat_model": True,
},
DashScopeGenerationModels.QWEN_MAX_1201: {
"context_window": 1024 * 8,
"num_output": 1024 * 8,
"is_chat_model": True,
},
DashScopeGenerationModels.QWEN_MAX_LONGCONTEXT: {
"context_window": 1024 * 30,
"num_output": 1024 * 30,
"is_chat_model": True,
},
DashScopeGenerationModels.QWEN2_MATH_72B_INSTRUCT: {
"context_window": 1024 * 8,
"num_output": 1024 * 8,
"is_chat_model": True,
},
}
class CustomDashScope(DashScope):
@property
def metadata(self) -> LLMMetadata:
DASHSCOPE_MODEL_META[self.model_name]["num_output"] = (
self.max_tokens or DASHSCOPE_MODEL_META[self.model_name]["num_output"]
)
return LLMMetadata(
model_name=self.model_name, **DASHSCOPE_MODEL_META[self.model_name]
)
+43
View File
@@ -0,0 +1,43 @@
from typing import Dict, List
class ClsRegister:
clsLst:Dict[str,Dict[str,str]] = {}
@classmethod
def add(cls,catalog,name,obj) -> None:
if catalog in cls.clsLst:
registry = cls.clsLst[catalog]
registry[name] = obj
else:
registry:Dict[str,str] = {}
registry[name] = obj
cls.clsLst[catalog] = registry
@classmethod
def get(cls,catalog,name,fuzzy:bool=False) -> None:
if catalog in cls.clsLst:
registry = cls.clsLst[catalog]
for key,value in registry.items():
if fuzzy:
if key in name:
return value
else:
if key == name:
return value
return None
@classmethod
def getClsList(cls,catalog) -> None:
res_Lst = []
if catalog in cls.clsLst:
registry = cls.clsLst[catalog]
for key,value in registry.items():
res_Lst.append(value)
return res_Lst
def register(catalog,name):
def decorator(className):
ClsRegister.add(catalog,name,className)
return className
return decorator