更新xinference支持

This commit is contained in:
2024-09-10 08:42:12 +08:00
parent 7875e2cbcc
commit adce2a3809
+6 -3
View File
@@ -3,15 +3,18 @@ from typing import Dict
from abc import abstractmethod
from llama_index.core.constants import DEFAULT_TEMPERATURE
from llama_index.core.settings import Settings
from llama_index.embeddings.xinference import XinferenceEmbedding
from llama_index.llms.xinference import Xinference
#from llama_index.embeddings.xinference import XinferenceEmbedding
from llama_index.llms.xinference.base import DEFAULT_XINFERENCE_TEMP
from llama_index.postprocessor.xinference_rerank import XinferenceRerank
from app.xinference.base import XinferenceEmbedding, XinferenceRerank
from app.engine.loaders import getProjectInfos
from app.api.routers.request.base import ProjectInfo
from modelProvide.customDashScope import CustomDashScope
from util.register import *
from llama_index.core.callbacks import CallbackManager
from modelProvide.customDashScope import CustomDashScope
ModelPlateCategory = '模型平台'
@@ -107,7 +110,7 @@ class XinferencePlatform(ModelPlatform):
embed_model_name = os.getenv("EMBEDDING_MODEL")
dimensions = os.getenv("EMBEDDING_DIM")
dimensions = int(dimensions) if dimensions is not None else None
return XinferenceEmbedding(embed_model_name, embedding_base_url, dimensions=dimensions)
return XinferenceEmbedding(embed_model_name, embedding_base_url)
def rerank(self):
rerank_model = os.getenv("RERANK_MODEL")