Merge branch 'dev-web' of https://git.97id.com/ly/zjdataai-app into dev-web
This commit is contained in:
+12
-16
@@ -18,21 +18,6 @@ from llama_index.core.callbacks import CallbackManager
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
if rerank_enabled is None or rerank_enabled == 'False':
|
||||
return []
|
||||
|
||||
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
||||
postprocess = None
|
||||
if modelPaltCls is not None:
|
||||
modelPalt:ModelPlatform = modelPaltCls()
|
||||
postprocess = modelPalt.rerank()
|
||||
else:
|
||||
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
||||
return postprocess
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
|
||||
@@ -91,7 +76,18 @@ class OllamaPlatform(ModelPlatform):
|
||||
pass
|
||||
|
||||
def rerank(self):
|
||||
pass
|
||||
from app.engine.rerank.ollamRerank import OllamaRerank
|
||||
modelpath = os.getcwd() + os.getenv('RERANK_MODEL')
|
||||
top_n = os.getenv('RERANK_TOP_N',5)
|
||||
threshold = float(os.getenv('RERANK_THRESHOLD',0.3))
|
||||
rerank = OllamaRerank(
|
||||
model=modelpath,
|
||||
top_n=top_n,
|
||||
device="cpu",
|
||||
score_threshold= threshold
|
||||
)
|
||||
return [rerank]
|
||||
|
||||
|
||||
@register(ModelPlateCategory,'xinference')
|
||||
class XinferencePlatform(ModelPlatform):
|
||||
|
||||
Reference in New Issue
Block a user