新增openai的向量,重排模型
This commit is contained in:
+11
-8
@@ -16,7 +16,6 @@ from modelProvide.customDashScope import CustomDashScope
|
||||
from util.register import *
|
||||
from llama_index.core.callbacks import CallbackManager
|
||||
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def init_settings():
|
||||
@@ -130,15 +129,19 @@ class OpenAIPlatform(ModelPlatform):
|
||||
|
||||
def embedding(self):
|
||||
from llama_index.embeddings.openai import OpenAIEmbedding
|
||||
dimensions = os.getenv("EMBEDDING_DIM")
|
||||
config = {
|
||||
"model": os.getenv("EMBEDDING_MODEL"),
|
||||
"dimensions": int(dimensions) if dimensions is not None else None,
|
||||
}
|
||||
return OpenAIEmbedding(**config)
|
||||
return OpenAIEmbedding(api_key=os.getenv('OPENAI_API_KEY'),
|
||||
api_base= os.getenv('EMBEDDING_BASE_URL'),
|
||||
model_name = os.getenv('EMBEDDING_MODEL'),
|
||||
dimensions= int(os.getenv("EMBEDDING_DIM")))
|
||||
|
||||
def rerank(self):
|
||||
pass
|
||||
from app.engine.rerank.siliconCloudRerank import SiliconCloudRerank
|
||||
postprocess = [SiliconCloudRerank(top_n = int(os.getenv('RERANK_TOP_N',5)),
|
||||
model = os.getenv('RERANK_MODEL'),
|
||||
base_url = os.getenv('RERANK_BASE_URL'),
|
||||
api_key = os.getenv('OPENAI_API_KEY')
|
||||
)]
|
||||
return postprocess
|
||||
|
||||
@register(ModelPlateCategory,'dashscope')
|
||||
class DashscopePlatform(ModelPlatform):
|
||||
|
||||
Reference in New Issue
Block a user