diff --git a/backend/app/engine/engine.py b/backend/app/engine/engine.py index bde2b9d..138ad53 100644 --- a/backend/app/engine/engine.py +++ b/backend/app/engine/engine.py @@ -7,10 +7,26 @@ from llama_index.core.query_engine import RetrieverQueryEngine from llama_index.core.response_synthesizers import ResponseMode from llama_index.readers.database import DatabaseReader from sqlalchemy import create_engine - +from util.register import * from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template from app.engine.retriever.HybridRetriever import HybridRetriever -from app.settings import get_node_postprocessors + +ModelPlateCategory = '模型平台' + +def get_node_postprocessors(): + rerank_enabled = os.getenv("RERANK_ENABLED").title() + if rerank_enabled is None or rerank_enabled == 'False': + return [] + + Rerank_provider = os.getenv("RERANK_PROVIDER") + modelPaltCls = ClsRegister.get(ModelPlateCategory,Rerank_provider) + postprocess = None + if modelPaltCls is not None: + modelPalt = modelPaltCls() + postprocess = modelPalt.rerank() + else: + raise ValueError(f"Invalid rerank provider: {Rerank_provider}") + return postprocess def makeDescriptionByEngine(sql_database:SQLDatabase): reader = DatabaseReader(sql_database) diff --git a/backend/app/settings.py b/backend/app/settings.py index 57f098c..12e200c 100644 --- a/backend/app/settings.py +++ b/backend/app/settings.py @@ -15,21 +15,6 @@ from modelProvide.customDashScope import CustomDashScope ModelPlateCategory = '模型平台' -def get_node_postprocessors(): - rerank_enabled = os.getenv("RERANK_ENABLED").title() - if rerank_enabled is None or rerank_enabled == 'False': - return [] - - Rerank_provider = os.getenv("RERANK_PROVIDER") - modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider) - postprocess = None - if modelPaltCls is not None: - modelPalt:ModelPlatform = modelPaltCls() - postprocess = modelPalt.rerank() - else: - raise ValueError(f"Invalid rerank provider: {Rerank_provider}") - return postprocess - def init_settings(): model_provider = os.getenv("MODEL_PROVIDER") modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)