解决模块嵌套问题

This commit is contained in:
wanyaokun
2024-09-06 16:35:34 +08:00
parent 60b0f11ca2
commit 1c773924db
2 changed files with 18 additions and 17 deletions
+18 -2
View File
@@ -7,10 +7,26 @@ from llama_index.core.query_engine import RetrieverQueryEngine
from llama_index.core.response_synthesizers import ResponseMode from llama_index.core.response_synthesizers import ResponseMode
from llama_index.readers.database import DatabaseReader from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine from sqlalchemy import create_engine
from util.register import *
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
from app.engine.retriever.HybridRetriever import HybridRetriever from app.engine.retriever.HybridRetriever import HybridRetriever
from app.settings import get_node_postprocessors
ModelPlateCategory = '模型平台'
def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED").title()
if rerank_enabled is None or rerank_enabled == 'False':
return []
Rerank_provider = os.getenv("RERANK_PROVIDER")
modelPaltCls = ClsRegister.get(ModelPlateCategory,Rerank_provider)
postprocess = None
if modelPaltCls is not None:
modelPalt = modelPaltCls()
postprocess = modelPalt.rerank()
else:
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
return postprocess
def makeDescriptionByEngine(sql_database:SQLDatabase): def makeDescriptionByEngine(sql_database:SQLDatabase):
reader = DatabaseReader(sql_database) reader = DatabaseReader(sql_database)
-15
View File
@@ -15,21 +15,6 @@ from modelProvide.customDashScope import CustomDashScope
ModelPlateCategory = '模型平台' ModelPlateCategory = '模型平台'
def get_node_postprocessors():
rerank_enabled = os.getenv("RERANK_ENABLED").title()
if rerank_enabled is None or rerank_enabled == 'False':
return []
Rerank_provider = os.getenv("RERANK_PROVIDER")
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider)
postprocess = None
if modelPaltCls is not None:
modelPalt:ModelPlatform = modelPaltCls()
postprocess = modelPalt.rerank()
else:
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
return postprocess
def init_settings(): def init_settings():
model_provider = os.getenv("MODEL_PROVIDER") model_provider = os.getenv("MODEL_PROVIDER")
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider) modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)