解决模块嵌套问题
This commit is contained in:
@@ -7,10 +7,26 @@ from llama_index.core.query_engine import RetrieverQueryEngine
|
|||||||
from llama_index.core.response_synthesizers import ResponseMode
|
from llama_index.core.response_synthesizers import ResponseMode
|
||||||
from llama_index.readers.database import DatabaseReader
|
from llama_index.readers.database import DatabaseReader
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine
|
||||||
|
from util.register import *
|
||||||
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
from app.engine.prompt import text_qa_template, refine_template, summary_template, simple_template
|
||||||
from app.engine.retriever.HybridRetriever import HybridRetriever
|
from app.engine.retriever.HybridRetriever import HybridRetriever
|
||||||
from app.settings import get_node_postprocessors
|
|
||||||
|
ModelPlateCategory = '模型平台'
|
||||||
|
|
||||||
|
def get_node_postprocessors():
|
||||||
|
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
||||||
|
if rerank_enabled is None or rerank_enabled == 'False':
|
||||||
|
return []
|
||||||
|
|
||||||
|
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
||||||
|
modelPaltCls = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
||||||
|
postprocess = None
|
||||||
|
if modelPaltCls is not None:
|
||||||
|
modelPalt = modelPaltCls()
|
||||||
|
postprocess = modelPalt.rerank()
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
||||||
|
return postprocess
|
||||||
|
|
||||||
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
def makeDescriptionByEngine(sql_database:SQLDatabase):
|
||||||
reader = DatabaseReader(sql_database)
|
reader = DatabaseReader(sql_database)
|
||||||
|
|||||||
@@ -15,21 +15,6 @@ from modelProvide.customDashScope import CustomDashScope
|
|||||||
|
|
||||||
ModelPlateCategory = '模型平台'
|
ModelPlateCategory = '模型平台'
|
||||||
|
|
||||||
def get_node_postprocessors():
|
|
||||||
rerank_enabled = os.getenv("RERANK_ENABLED").title()
|
|
||||||
if rerank_enabled is None or rerank_enabled == 'False':
|
|
||||||
return []
|
|
||||||
|
|
||||||
Rerank_provider = os.getenv("RERANK_PROVIDER")
|
|
||||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,Rerank_provider)
|
|
||||||
postprocess = None
|
|
||||||
if modelPaltCls is not None:
|
|
||||||
modelPalt:ModelPlatform = modelPaltCls()
|
|
||||||
postprocess = modelPalt.rerank()
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid rerank provider: {Rerank_provider}")
|
|
||||||
return postprocess
|
|
||||||
|
|
||||||
def init_settings():
|
def init_settings():
|
||||||
model_provider = os.getenv("MODEL_PROVIDER")
|
model_provider = os.getenv("MODEL_PROVIDER")
|
||||||
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
|
modelPaltCls:ModelPlatform = ClsRegister.get(ModelPlateCategory,model_provider)
|
||||||
|
|||||||
Reference in New Issue
Block a user