更新xinference支持

This commit is contained in:
2024-09-10 08:42:12 +08:00
parent 7875e2cbcc
commit adce2a3809
+6 -3
View File
@@ -3,15 +3,18 @@ from typing import Dict
from abc import abstractmethod from abc import abstractmethod
from llama_index.core.constants import DEFAULT_TEMPERATURE from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings from llama_index.core.settings import Settings
from llama_index.embeddings.xinference import XinferenceEmbedding
from llama_index.llms.xinference import Xinference from llama_index.llms.xinference import Xinference
#from llama_index.embeddings.xinference import XinferenceEmbedding
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
from app.engine.loaders import getProjectInfos from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo from app.api.routers.request.base import ProjectInfo
from modelProvide.customDashScope import CustomDashScope
from util.register import * from util.register import *
from llama_index.core.callbacks import CallbackManager from llama_index.core.callbacks import CallbackManager
from modelProvide.customDashScope import CustomDashScope
ModelPlateCategory = '模型平台' ModelPlateCategory = '模型平台'
@@ -107,7 +110,7 @@ class XinferencePlatform(ModelPlatform):
embed_model_name = os.getenv("EMBEDDING_MODEL") embed_model_name = os.getenv("EMBEDDING_MODEL")
dimensions = os.getenv("EMBEDDING_DIM") dimensions = os.getenv("EMBEDDING_DIM")
dimensions = int(dimensions) if dimensions is not None else None dimensions = int(dimensions) if dimensions is not None else None
return XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions) return XinferenceEmbedding(embed_model_name, embedding_base_url)
def rerank(self): def rerank(self):
rerank_model = os.getenv("RERANK_MODEL") rerank_model = os.getenv("RERANK_MODEL")