解决模块嵌套问题
This commit is contained in:
@@ -7,10 +7,26 @@ from llama_index.core.query_engine import RetrieverQueryEngine
|
||||
from llama_index.core.response_synthesizers import ResponseMode
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
from util.register import *
|
||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||
from app.settings import get_node_postprocessors
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
if rerank_enabled is None or rerank_enabled == 'False':
|
||||
return []
|
||||
|
||||
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
||||
modelPaltCls = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
||||
postprocess = None
|
||||
if modelPaltCls is not None:
|
||||
modelPalt = modelPaltCls()
|
||||
postprocess = modelPalt.rerank()
|
||||
else:
|
||||
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
||||
return postprocess
|
||||
|
||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||
reader = DatabaseReader(sql_database)
|
||||
|
||||
@@ -15,21 +15,6 @@ from modelProvide.customDashScope import CustomDashScope
|
||||
|
||||
ModelPlateCategory = '模型平台'
|
||||
|
||||
def get_node_postprocessors():
|
||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||
if rerank_enabled is None or rerank_enabled == 'False':
|
||||
return []
|
||||
|
||||
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
||||
postprocess = None
|
||||
if modelPaltCls is not None:
|
||||
modelPalt:ModelPlatform = modelPaltCls()
|
||||
postprocess = modelPalt.rerank()
|
||||
else:
|
||||
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
||||
return postprocess
|
||||
|
||||
def init_settings():
|
||||
model_provider = os.getenv("MODEL_PROVIDER")
|
||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
|
||||
|
||||
Reference in New Issue
Block a user