优化模型初始化代码
This commit is contained in:
+37
-23
@@ -4,34 +4,48 @@ SQL_DATABASE_URL=mysql+pymysql://zjinfo1:Dy2Bcr53Hm5xRkba@110.42.234.166:3306/zj
|
|||||||
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
#SQL_DATABASE_URL=mysql+pymysql://zjinfo2:GSKcziSdBixDXwcd@110.42.234.166:3306/zjinfo2
|
||||||
SQLITE_DATABASE_URL=sqlite:///./source.db
|
SQLITE_DATABASE_URL=sqlite:///./source.db
|
||||||
|
|
||||||
DASHSCOPE_API_KEY=sk-02c8540e86d84b7ca0e6f4f51bac6e60
|
# The number of similar embeddings to return when retrieving documents.
|
||||||
# The provider for the AI models to use.
|
TOP_K=10
|
||||||
MODEL_PROVIDER=dashscope
|
#--------------------------
|
||||||
# The name of LLM model to use.
|
# 是否启用混合检索
|
||||||
MODEL=qwen-max
|
HYBRID_ENABLED = false
|
||||||
|
# 混合检索阈值
|
||||||
|
HYBRID_ALPHA = 0.6
|
||||||
# 是否启用检索重排功能
|
# 是否启用检索重排功能
|
||||||
ENABLE_RERANK=true
|
RERANK_ENABLED=true
|
||||||
# Name of the embedding model to use.
|
|
||||||
EMBEDDING_MODEL=text-embedding-v2
|
|
||||||
|
|
||||||
# Dimension of the embedding model to use.
|
#---------- rerank- Xinference ----------------
|
||||||
|
RERANK_PROVIDER=xinference
|
||||||
|
RERANK_MODEL=bge-reranker-v2-m3
|
||||||
|
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||||
|
RERANK_TOP_N=5
|
||||||
|
RERANK_THRESHOLD=0.3
|
||||||
|
|
||||||
|
#---------- model - Xinference ----------------
|
||||||
|
#MODEL_PROVIDER=xinference
|
||||||
|
#OPENAI_API_KEY=xinference
|
||||||
|
#BASE_URL=http://172.20.0.145:9995
|
||||||
|
#MODEL=Qwen2-72B-Instruct-GPTQ-Int8
|
||||||
|
## Temperature for sampling from the model.
|
||||||
|
#LLM_TEMPERATURE=0.1
|
||||||
|
|
||||||
|
#---------- model - dashscope ----------------
|
||||||
|
MODEL_PROVIDER=dashscope
|
||||||
|
DASHSCOPE_API_KEY=sk-221d2d202e104618a56002ce2e7dc0d0
|
||||||
|
MODEL=qwen-max
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
#---------- embedding - Xinference ----------------
|
||||||
|
EMBEDDING_PROVIDER=xinference
|
||||||
|
EMBEDDING_MODEL=bge-m3
|
||||||
|
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
||||||
EMBEDDING_DIM=1024
|
EMBEDDING_DIM=1024
|
||||||
|
|
||||||
|
|
||||||
# The questions to help users get started (multi-line).
|
# The questions to help users get started (multi-line).
|
||||||
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
|
CONVERSATION_STARTERS=本工程指什么?\n总算表有哪些费用?\n项目划分哪些内容构成?\n其他费用表有哪些内容?
|
||||||
|
|
||||||
# The OpenAI API key to use.
|
|
||||||
# OPENAI_API_KEY=
|
|
||||||
|
|
||||||
# Temperature for sampling from the model.
|
|
||||||
# LLM_TEMPERATURE=
|
|
||||||
|
|
||||||
# Maximum number of tokens to generate.
|
|
||||||
# LLM_MAX_TOKENS=
|
|
||||||
|
|
||||||
# The number of similar embeddings to return when retrieving documents.
|
|
||||||
TOP_K=5
|
|
||||||
|
|
||||||
# The time in milliseconds to wait for the stream to return a response.
|
# The time in milliseconds to wait for the stream to return a response.
|
||||||
STREAM_TIMEOUT=60000
|
STREAM_TIMEOUT=60000
|
||||||
|
|
||||||
@@ -53,9 +67,8 @@ VECTOR_STORE_PATH=./storage_vector
|
|||||||
BM_RETRIEVER_PATH =./storage_bm
|
BM_RETRIEVER_PATH =./storage_bm
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
PHOENIX_API_KEY=123456
|
PHOENIX_API_KEY=123456
|
||||||
PHOENIX_URL=http://localhost:6006/v1/traces
|
PHOENIX_URL=http://10.1.6.103:6006/v1/traces
|
||||||
PHOENIX_PROJECT_NAME=ly_zjapp
|
PHOENIX_PROJECT_NAME=ly_zjapp
|
||||||
#OTEL_SERVICE_NAME=ly_zjapp
|
#OTEL_SERVICE_NAME=ly_zjapp
|
||||||
#OTEL_RESOURCE_ATTRIBUTES=openinference.project.name=ly_zjapp
|
#OTEL_RESOURCE_ATTRIBUTES=openinference.project.name=ly_zjapp
|
||||||
@@ -82,4 +95,5 @@ SYSTEM_PROMPT="You are a weather forecast agent. You help users to get the weath
|
|||||||
|
|
||||||
PRJTOJSON_URL = 'http://10.1.6.60:8092'
|
PRJTOJSON_URL = 'http://10.1.6.60:8092'
|
||||||
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
PROJECT_TITLE = "您好,我是博微工程理解小助手,您可以问我有关[线路工程]工程数据的相关问题!"
|
||||||
|
|
||||||
CHAT_UPLOAD_FILECACHE = "./output/uploaded"
|
CHAT_UPLOAD_FILECACHE = "./output/uploaded"
|
||||||
+14
-13
@@ -14,27 +14,28 @@ HYBRID_ALPHA = 0.6
|
|||||||
#--------------------------
|
#--------------------------
|
||||||
# 是否启用检索重排功能
|
# 是否启用检索重排功能
|
||||||
RERANK_ENABLED=true
|
RERANK_ENABLED=true
|
||||||
# Rerank model
|
|
||||||
|
#---------- rerank- Xinference ----------------
|
||||||
|
RERANK_PROVIDER=xinference
|
||||||
RERANK_MODEL=bge-reranker-v2-m3
|
RERANK_MODEL=bge-reranker-v2-m3
|
||||||
RERANK_BASE_URL=http://10.1.16.39:9995
|
RERANK_BASE_URL=http://10.1.16.39:9995
|
||||||
RERANK_TOP_N=5
|
RERANK_TOP_N=5
|
||||||
RERANK_THRESHOLD=0.3
|
RERANK_THRESHOLD=0.3
|
||||||
#---------- Xinference ----------------
|
|
||||||
# The provider for the AI models to use.
|
#---------- model - Xinference ----------------
|
||||||
MODEL_PROVIDER=xinference
|
MODEL_PROVIDER=xinference # The provider for the AI models to use.
|
||||||
# The OpenAI API key to use.
|
OPENAI_API_KEY=xinference # The OpenAI API key to use.
|
||||||
OPENAI_API_KEY=xinference
|
|
||||||
BASE_URL=http://10.1.0.142:9995
|
BASE_URL=http://10.1.0.142:9995
|
||||||
MODEL=Qwen2-72B-Instruct-GPTQ-Int8
|
MODEL=Qwen2-72B-Instruct-GPTQ-Int8
|
||||||
# Temperature for sampling from the model.
|
LLM_TEMPERATURE=0.1 # Temperature for sampling from the model.
|
||||||
LLM_TEMPERATURE=0.1
|
#LLM_MAX_TOKENS= # Maximum number of tokens to generate.
|
||||||
# Maximum number of tokens to generate.
|
|
||||||
#LLM_MAX_TOKENS=
|
|
||||||
# Name of the embedding model to use.
|
#---------- embedding - Xinference ----------------
|
||||||
|
EMBEDDING_PROVIDER=xinference
|
||||||
EMBEDDING_MODEL=bge-m3
|
EMBEDDING_MODEL=bge-m3
|
||||||
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
EMBEDDING_BASE_URL=http://10.1.16.39:9995
|
||||||
# Dimension of the embedding model to use.
|
EMBEDDING_DIM=1024 # Dimension of the embedding model to use.
|
||||||
EMBEDDING_DIM=1024
|
|
||||||
|
|
||||||
##---------- OpenAI ----------------
|
##---------- OpenAI ----------------
|
||||||
## The provider for the AI models to use.
|
## The provider for the AI models to use.
|
||||||
|
|||||||
@@ -24,14 +24,11 @@ from app.api.routers.services.fileServices import PrjFileLoadService,ChatFileSer
|
|||||||
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
from app.api.routers.services.suggestion import NextQuestionSuggestion
|
||||||
import time
|
import time
|
||||||
from llama_index.core.settings import Settings
|
from llama_index.core.settings import Settings
|
||||||
from llama_index.core.callbacks import CallbackManager
|
|
||||||
|
|
||||||
logger = logging.getLogger("uvicorn")
|
logger = logging.getLogger("uvicorn")
|
||||||
|
|
||||||
v1_router = v = APIRouter()
|
v1_router = v = APIRouter()
|
||||||
|
|
||||||
Settings.llm.callback_manager = CallbackManager()
|
|
||||||
|
|
||||||
gEvent_handler = None
|
gEvent_handler = None
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
+286
-190
@@ -1,6 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
|
from abc import abstractmethod
|
||||||
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
from llama_index.core.constants import DEFAULT_TEMPERATURE
|
||||||
from llama_index.core.settings import Settings
|
from llama_index.core.settings import Settings
|
||||||
from llama_index.llms.xinference import Xinference
|
from llama_index.llms.xinference import Xinference
|
||||||
@@ -9,229 +9,322 @@ from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
|
|||||||
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
|
||||||
from app.engine.loaders import getProjectInfos
|
from app.engine.loaders import getProjectInfos
|
||||||
from app.api.routers.request.base import ProjectInfo
|
from app.api.routers.request.base import ProjectInfo
|
||||||
|
from util.register import *
|
||||||
|
from llama_index.core.callbacks import CallbackManager
|
||||||
|
from modelProvide.customDashScope import CustomDashScope
|
||||||
|
|
||||||
|
ModelPlateCategory = '模型平台'
|
||||||
|
|
||||||
def get_node_postprocessors():
    """Build the node post-processors (rerankers) selected by the environment.

    Reads ``RERANK_ENABLED`` to decide whether reranking is on, then looks up
    the platform class registered under ``RERANK_PROVIDER`` and asks it for
    its reranker(s).

    Returns:
        A list of post-processors. The list is empty when ``RERANK_ENABLED``
        is unset or spells "false" in any letter case, or when the selected
        platform provides no reranker.

    Raises:
        ValueError: if ``RERANK_PROVIDER`` does not name a registered platform.
    """
    rerank_enabled = os.getenv("RERANK_ENABLED")
    # Test for None *before* normalising the text: calling .title() on a
    # missing variable raised AttributeError and made the None check dead.
    if rerank_enabled is None or rerank_enabled.title() == 'False':
        return []

    rerank_provider = os.getenv("RERANK_PROVIDER")
    platform_cls = ClsRegister.get(ModelPlateCategory, rerank_provider)
    if platform_cls is None:
        raise ValueError(f"Invalid rerank provider: {rerank_provider}")

    # Platforms without a reranker return None from rerank(); normalise that
    # to the same empty list the "disabled" path returns.
    return platform_cls().rerank() or []
|
||||||
|
|
||||||
def init_settings():
    """Configure the global llama-index ``Settings`` from environment variables.

    ``MODEL_PROVIDER`` selects the platform that builds ``Settings.llm`` and
    ``EMBEDDING_PROVIDER`` selects the platform that builds
    ``Settings.embed_model``. Both are resolved through ``ClsRegister`` under
    the ``ModelPlateCategory`` catalog. ``CHUNK_SIZE`` and ``CHUNK_OVERLAP``
    default to 1024 and 20.

    Raises:
        ValueError: if either provider name is not registered.
    """
    model_provider = os.getenv("MODEL_PROVIDER")
    llm_platform_cls = ClsRegister.get(ModelPlateCategory, model_provider)
    if llm_platform_cls is None:
        raise ValueError(f"Invalid model provider: {model_provider}")
    Settings.llm = llm_platform_cls().model()

    embedding_provider = os.getenv("EMBEDDING_PROVIDER")
    embedding_platform_cls = ClsRegister.get(ModelPlateCategory, embedding_provider)
    # Check the class that was just looked up. The previous code tested the
    # LLM platform *instance*, which is never None, so an unknown embedding
    # provider failed with "'NoneType' object is not callable".
    if embedding_platform_cls is None:
        raise ValueError(f"Invalid embedding provider: {embedding_provider}")
    Settings.embed_model = embedding_platform_cls().embedding()

    Settings.llm.callback_manager = CallbackManager()
    Settings.chunk_size = int(os.getenv("CHUNK_SIZE", "1024"))
    Settings.chunk_overlap = int(os.getenv("CHUNK_OVERLAP", "20"))
|
||||||
|
|
||||||
def init_ollama():
|
class ModelPlatform:
    """Interface of a model platform: one factory method per model role.

    The methods are tagged ``@abstractmethod``, but the class does not use
    ``ABCMeta``, so the tags are advisory only: the class can still be
    instantiated and every default implementation returns ``None``.
    """

    @abstractmethod
    def model(self):
        """Return the platform's LLM object (``None`` in this base class)."""

    @abstractmethod
    def embedding(self):
        """Return the platform's embedding model (``None`` in this base class)."""

    @abstractmethod
    def rerank(self):
        """Return the platform's reranker(s) (``None`` in this base class)."""
|
||||||
|
|
||||||
embedding_base_url = os.getenv("EMBEDDING_BASE_URL")
|
@register(ModelPlateCategory, 'ollama')
class OllamaPlatform(ModelPlatform):
    """Placeholder platform registered under ``'ollama'``.

    The Ollama integration exists only as commented-out reference code, so
    every factory currently returns ``None``.
    """

    def model(self):
        """Return ``None``; the Ollama LLM setup below is disabled."""
        # from llama_index.embeddings.ollama import OllamaEmbedding
        # from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama
        #
        # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
        # request_timeout = float(
        #     os.getenv("OLLAMA_REQUEST_TIMEOUT", DEFAULT_REQUEST_TIMEOUT)
        # )
        # Settings.llm = Ollama(
        #     base_url=base_url, model=os.getenv("MODEL"), request_timeout=request_timeout
        # )
        return None

    def embedding(self):
        """Return ``None``; the Ollama embedding setup below is disabled."""
        # from llama_index.embeddings.ollama import OllamaEmbedding
        # base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434"
        # Settings.embed_model = OllamaEmbedding(
        #     base_url=base_url,
        #     model_name=os.getenv("EMBEDDING_MODEL"),
        # )
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
|
||||||
from llama_index.llms.openai import OpenAI
|
|
||||||
|
|
||||||
max_tokens = os.getenv("LLM_MAX_TOKENS")
|
@register(ModelPlateCategory, 'xinference')
class XinferencePlatform(ModelPlatform):
    """Platform registered under ``'xinference'``; builds models from env vars."""

    def model(self):
        """Build the Xinference LLM.

        Reads ``MODEL``, ``BASE_URL``, ``LLM_MAX_TOKENS`` (optional int) and
        ``LLM_TEMPERATURE`` (defaults to ``DEFAULT_XINFERENCE_TEMP``).
        """
        endpoint = os.getenv("BASE_URL")
        model_name = os.getenv("MODEL")
        raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
        token_limit = None if raw_max_tokens is None else int(raw_max_tokens)
        sampling_temp = float(os.getenv("LLM_TEMPERATURE", DEFAULT_XINFERENCE_TEMP))
        return Xinference(model_name, endpoint, sampling_temp, token_limit)

    def embedding(self):
        """Build the Xinference embedding model.

        Uses ``EMBEDDING_BASE_URL`` when it is set and non-empty, otherwise
        falls back to ``BASE_URL``. ``EMBEDDING_DIM`` is converted to ``int``
        when present.
        """
        fallback_url = os.getenv("BASE_URL")
        endpoint = os.getenv("EMBEDDING_BASE_URL") or fallback_url

        embed_name = os.getenv("EMBEDDING_MODEL")
        raw_dim = os.getenv("EMBEDDING_DIM")
        dim = None if raw_dim is None else int(raw_dim)
        return XinferenceEmbedding(embed_name, endpoint, dimensions=dim)

    def rerank(self):
        """Build a one-element list with the Xinference reranker.

        Returns ``None`` when ``RERANK_MODEL`` is unset.
        """
        reranker_name = os.getenv("RERANK_MODEL")
        reranker_url = os.getenv("RERANK_BASE_URL")
        # NOTE(review): top_n and threshold are forwarded as the raw env
        # strings; confirm XinferenceRerank converts them itself.
        top_n = os.getenv("RERANK_TOP_N")
        threshold = os.getenv("RERANK_THRESHOLD")
        if reranker_name is None:
            return None
        return [XinferenceRerank(reranker_name, reranker_url, top_n=top_n, threshold=threshold)]
|
||||||
|
|
||||||
def init_dashscope():
|
@register(ModelPlateCategory, 'openai')
class OpenAIPlatform(ModelPlatform):
    """Platform registered under ``'openai'``; builds models from env vars."""

    def model(self):
        """Build the OpenAI LLM from ``MODEL``, ``LLM_TEMPERATURE`` and ``LLM_MAX_TOKENS``."""
        from llama_index.core.constants import DEFAULT_TEMPERATURE
        from llama_index.llms.openai import OpenAI

        raw_max_tokens = os.getenv("LLM_MAX_TOKENS")
        return OpenAI(
            model=os.getenv("MODEL"),
            temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)),
            max_tokens=None if raw_max_tokens is None else int(raw_max_tokens),
        )

    def embedding(self):
        """Build the OpenAI embedding model from ``EMBEDDING_MODEL`` and ``EMBEDDING_DIM``."""
        from llama_index.embeddings.openai import OpenAIEmbedding

        raw_dim = os.getenv("EMBEDDING_DIM")
        return OpenAIEmbedding(
            model=os.getenv("EMBEDDING_MODEL"),
            dimensions=None if raw_dim is None else int(raw_dim),
        )

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
|
|
||||||
dimensions = os.getenv("EMBEDDING_DIM")
|
@register(ModelPlateCategory, 'dashscope')
class DashscopePlatform(ModelPlatform):
    """Platform registered under ``'dashscope'``; builds models from env vars."""

    def model(self):
        """Build a ``CustomDashScope`` LLM from ``MODEL`` and ``DASHSCOPE_API_KEY``."""
        return CustomDashScope(
            model_name=os.getenv('MODEL'),
            api_key=os.getenv('DASHSCOPE_API_KEY'),
        )

    def embedding(self):
        """Build a DashScope embedding model from ``EMBEDDING_MODEL`` and ``DASHSCOPE_API_KEY``.

        The embedding text type is fixed to ``TEXT_TYPE_QUERY``.
        """
        from llama_index.embeddings.dashscope import DashScopeEmbedding,DashScopeTextEmbeddingType,DashScopeTextEmbeddingModels

        credentials = os.getenv('DASHSCOPE_API_KEY')
        embed_name = os.getenv('EMBEDDING_MODEL')
        return DashScopeEmbedding(
            model_name=embed_name,
            text_type=DashScopeTextEmbeddingType.TEXT_TYPE_QUERY,
            api_key=credentials,
        )

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
Use Qdrant Fastembed as the local embedding provider.
|
|
||||||
"""
|
|
||||||
# from llama_index.embeddings.fastembed import FastEmbedEmbedding
|
|
||||||
#
|
|
||||||
# embed_model_map: Dict[str, str] = {
|
|
||||||
# # Small and multilingual
|
|
||||||
# "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
|
|
||||||
# # Large and multilingual
|
|
||||||
# "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501
|
|
||||||
# }
|
|
||||||
#
|
|
||||||
# # This will download the model automatically if it is not already downloaded
|
|
||||||
# Settings.embed_model = FastEmbedEmbedding(
|
|
||||||
# model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
|
|
||||||
# )
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
@register(ModelPlateCategory, 'azure-openai')
class AzureOpenaiPlatform(ModelPlatform):
    """Placeholder platform registered under ``'azure-openai'``.

    The Azure OpenAI integration exists only as commented-out reference
    code, so every factory currently returns ``None``.
    """

    def model(self):
        """Return ``None``; the Azure OpenAI LLM setup below is disabled."""
        # from llama_index.core.constants import DEFAULT_TEMPERATURE
        # from llama_index.llms.azure_openai import AzureOpenAI
        #
        # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
        # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
        # max_tokens = os.getenv("LLM_MAX_TOKENS")
        # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
        # dimensions = os.getenv("EMBEDDING_DIM")
        #
        # azure_config = {
        #     "api_key": os.environ["AZURE_OPENAI_KEY"],
        #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        #     or os.getenv("OPENAI_API_VERSION"),
        # }
        #
        # return AzureOpenAI(
        #     model=os.getenv("MODEL"),
        #     max_tokens=int(max_tokens) if max_tokens is not None else None,
        #     temperature=float(temperature),
        #     deployment_name=llm_deployment,
        #     **azure_config,
        # )
        return None

    def embedding(self):
        """Return ``None``; the Azure OpenAI embedding setup below is disabled."""
        # from llama_index.core.constants import DEFAULT_TEMPERATURE
        # from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
        #
        # llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"]
        # embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"]
        # max_tokens = os.getenv("LLM_MAX_TOKENS")
        # temperature = os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)
        # dimensions = os.getenv("EMBEDDING_DIM")
        #
        # azure_config = {
        #     "api_key": os.environ["AZURE_OPENAI_KEY"],
        #     "azure_endpoint": os.environ["AZURE_OPENAI_ENDPOINT"],
        #     "api_version": os.getenv("AZURE_OPENAI_API_VERSION")
        #     or os.getenv("OPENAI_API_VERSION"),
        # }
        # return AzureOpenAIEmbedding(
        #     model=os.getenv("EMBEDDING_MODEL"),
        #     dimensions=int(dimensions) if dimensions is not None else None,
        #     deployment_name=embedding_deployment,
        #     **azure_config,
        # )
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
|
|
||||||
def init_anthropic():
|
@register(ModelPlateCategory, 'fastembed')
class FastembedPlatform(ModelPlatform):
    """Placeholder platform registered under ``'fastembed'``.

    The FastEmbed integration exists only as commented-out reference code,
    so every factory currently returns ``None``.

    The methods are deliberately NOT marked ``@abstractmethod``: this is a
    concrete, registered platform like the others, and abstract markers
    would make it uninstantiable if the base class ever became a real ABC.
    """

    def model(self):
        """Return ``None``; no LLM is defined for this platform."""
        return None

    def embedding(self):
        """Return ``None``; the FastEmbed embedding setup below is disabled."""
        # from llama_index.embeddings.fastembed import FastEmbedEmbedding
        #
        # embed_model_map: Dict[str, str] = {
        #     # Small and multilingual
        #     "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2",
        #     # Large and multilingual
        #     "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2",  # noqa: E501
        # }
        #
        # # This will download the model automatically if it is not already downloaded
        # Settings.embed_model = FastEmbedEmbedding(
        #     model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")]
        # )
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
#
|
|
||||||
# model_name = f"models/{os.getenv('MODEL')}"
|
|
||||||
# embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
|
|
||||||
#
|
|
||||||
# Settings.llm = Gemini(model=model_name)
|
|
||||||
# Settings.embed_model = GeminiEmbedding(model_name=embed_model_name)
|
|
||||||
pass
|
|
||||||
|
|
||||||
def init_mistral():
|
@register(ModelPlateCategory, 'groq')
class GroqPlatform(ModelPlatform):
    """Placeholder platform registered under ``'groq'``.

    The Groq integration exists only as commented-out reference code, so
    every factory currently returns ``None``.

    The methods are deliberately NOT marked ``@abstractmethod``: this is a
    concrete, registered platform like the others, and abstract markers
    would make it uninstantiable if the base class ever became a real ABC.
    """

    def model(self):
        """Return ``None``; the Groq LLM setup below is disabled."""
        # from llama_index.llms.groq import Groq
        #
        # model_map: Dict[str, str] = {
        #     "llama3-8b": "llama3-8b-8192",
        #     "llama3-70b": "llama3-70b-8192",
        #     "mixtral-8x7b": "mixtral-8x7b-32768",
        # }
        #
        # Settings.llm = Groq(model=model_map[os.getenv("MODEL")])
        # # Groq does not provide embeddings, so we use FastEmbed instead
        # init_fastembed()
        return None

    def embedding(self):
        """Return ``None``; no embedding model is defined for this platform."""
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
|
|
||||||
|
@register(ModelPlateCategory, 'anthropic')
class AnthropicPlatform(ModelPlatform):
    """Placeholder platform registered under ``'anthropic'``.

    The Anthropic integration exists only as commented-out reference code,
    so every factory currently returns ``None``.
    """

    def model(self):
        """Return ``None``; the Anthropic LLM setup below is disabled."""
        # from llama_index.llms.anthropic import Anthropic
        #
        # model_map: Dict[str, str] = {
        #     "claude-3-opus": "claude-3-opus-20240229",
        #     "claude-3-sonnet": "claude-3-sonnet-20240229",
        #     "claude-3-haiku": "claude-3-haiku-20240307",
        #     "claude-2.1": "claude-2.1",
        #     "claude-instant-1.2": "claude-instant-1.2",
        # }
        #
        # Settings.llm = Anthropic(model=model_map[os.getenv("MODEL")])
        # # Anthropic does not provide embeddings, so we use FastEmbed instead
        # init_fastembed()
        return None

    def embedding(self):
        """Return ``None``; no embedding model is defined for this platform."""
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
|
|
||||||
|
@register(ModelPlateCategory, 'gemini')
class GeminiPlatform(ModelPlatform):
    """Placeholder platform registered under ``'gemini'``.

    The Gemini integration exists only as commented-out reference code, so
    every factory currently returns ``None``.
    """

    def model(self):
        """Return ``None``; the Gemini LLM setup below is disabled."""
        # from llama_index.llms.gemini import Gemini
        # model_name = f"models/{os.getenv('MODEL')}"
        # return Gemini(model=model_name)
        return None

    def embedding(self):
        """Return ``None``; the Gemini embedding setup below is disabled."""
        # from llama_index.embeddings.gemini import GeminiEmbedding
        # embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}"
        # return GeminiEmbedding(model_name=embed_model_name)
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
|
|
||||||
|
@register(ModelPlateCategory, 'mistral')
class MistralPlatform(ModelPlatform):
    """Placeholder platform registered under ``'mistral'``.

    The Mistral integration exists only as commented-out reference code, so
    every factory currently returns ``None``.
    """

    def model(self):
        """Return ``None``; the Mistral LLM setup below is disabled."""
        # from llama_index.llms.mistralai import MistralAI
        # return MistralAI(model=os.getenv("MODEL"))
        return None

    def embedding(self):
        """Return ``None``; the Mistral embedding setup below is disabled."""
        # from llama_index.embeddings.mistralai import MistralAIEmbedding
        # return MistralAIEmbedding(model_name=os.getenv("EMBEDDING_MODEL"))
        return None

    def rerank(self):
        """Return ``None``; no reranker is defined for this platform."""
        return None
|
||||||
|
|
||||||
def init_ProjectInfo():
|
def init_ProjectInfo():
|
||||||
prjObj = ProjectInfo()
|
prjObj = ProjectInfo()
|
||||||
@@ -239,3 +332,6 @@ def init_ProjectInfo():
|
|||||||
for prjInfo in prjInfos:
|
for prjInfo in prjInfos:
|
||||||
prjObj.add(prjInfo['name'],prjInfo['flag'])
|
prjObj.add(prjInfo['name'],prjInfo['flag'])
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
from llama_index.llms.dashscope import DashScope
|
||||||
|
from llama_index.core.base.llms.types import LLMMetadata
|
||||||
|
|
||||||
|
class DashScopeGenerationModels:
    """Model-name constants for the DashScope Qwen family.

    Plain string attributes (not an ``Enum``), so each value can be used
    directly as a model name or dictionary key.
    """

    # General-purpose tiers.
    QWEN_TURBO = "qwen-turbo"
    QWEN_PLUS = "qwen-plus"

    # "max" tier and its variants.
    QWEN_MAX = "qwen-max"
    QWEN_MAX_1201 = "qwen-max-1201"
    QWEN_MAX_LONGCONTEXT = "qwen-max-longcontext"

    # Math-specialised instruct model.
    QWEN2_MATH_72B_INSTRUCT = 'qwen2-math-72b-instruct'
|
||||||
|
|
||||||
|
# Per-model metadata keyed by model name. For every model the context window
# and the output budget are the same number of tokens (N * 1024), and every
# model is flagged as a chat model.
DASHSCOPE_MODEL_META = {
    model_id: {
        "context_window": 1024 * kilo_tokens,
        "num_output": 1024 * kilo_tokens,
        "is_chat_model": True,
    }
    for model_id, kilo_tokens in (
        (DashScopeGenerationModels.QWEN_TURBO, 8),
        (DashScopeGenerationModels.QWEN_PLUS, 32),
        (DashScopeGenerationModels.QWEN_MAX, 8),
        (DashScopeGenerationModels.QWEN_MAX_1201, 8),
        (DashScopeGenerationModels.QWEN_MAX_LONGCONTEXT, 30),
        (DashScopeGenerationModels.QWEN2_MATH_72B_INSTRUCT, 8),
    )
}
|
||||||
|
|
||||||
|
|
||||||
|
class CustomDashScope(DashScope):
    """DashScope LLM whose metadata comes from this module's model table."""

    @property
    def metadata(self) -> LLMMetadata:
        """Return the metadata for ``self.model_name``.

        ``num_output`` is ``self.max_tokens`` when that is set (truthy),
        otherwise the table default for the model.

        Raises:
            KeyError: if ``self.model_name`` is not in ``DASHSCOPE_MODEL_META``.
        """
        # Work on a copy. Writing num_output back into the shared table made
        # one instance's max_tokens leak into every later instance of the
        # same model.
        model_meta = dict(DASHSCOPE_MODEL_META[self.model_name])
        if self.max_tokens:
            model_meta["num_output"] = self.max_tokens
        return LLMMetadata(model_name=self.model_name, **model_meta)
|
||||||
|
|
||||||
|
|
||||||
@@ -0,0 +1,43 @@
|
|||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
|
class ClsRegister:
    """Process-wide registry of classes, grouped by catalog and keyed by name."""

    # catalog -> {name -> registered object (normally a class)}
    clsLst: Dict[str, Dict[str, type]] = {}

    @classmethod
    def add(cls, catalog, name, obj) -> None:
        """Register ``obj`` under ``name`` in ``catalog``.

        The catalog is created on first use; an existing entry with the same
        name is overwritten.
        """
        cls.clsLst.setdefault(catalog, {})[name] = obj

    @classmethod
    def get(cls, catalog, name, fuzzy: bool = False) -> "type | None":
        """Look up a registered object.

        Args:
            catalog: catalog to search.
            name: exact key to find, or (with ``fuzzy=True``) a string that
                must contain a registered key.
            fuzzy: when true, return the first entry, in registration order,
                whose key is a substring of ``name``.

        Returns:
            The registered object, or ``None`` when the catalog or the entry
            does not exist.
        """
        registry = cls.clsLst.get(catalog)
        if registry is None:
            return None
        if not fuzzy:
            return registry.get(name)
        for key, value in registry.items():
            if key in name:
                return value
        return None

    @classmethod
    def getClsList(cls, catalog) -> List[type]:
        """Return every object registered in ``catalog``, in registration order.

        An unknown catalog yields an empty list.
        """
        return list(cls.clsLst.get(catalog, {}).values())
|
||||||
|
|
||||||
|
|
||||||
|
def register(catalog,name):
    """Return a class decorator that records the class in ``ClsRegister``.

    The decorated class is stored under ``name`` in ``catalog`` and handed
    back unchanged, so the decorator has no effect on the class itself.
    """
    def _bind(target_cls):
        # Registration happens once, when the class statement is executed.
        ClsRegister.add(catalog, name, target_cls)
        return target_cls

    return _bind
|
||||||
Reference in New Issue
Block a user