dev #1
@@ -168,14 +168,14 @@ def get_db_documents(configs: list[DBLoaderConfig]):
|
|||||||
engine = create_engine(entry.uri)
|
engine = create_engine(entry.uri)
|
||||||
sql_database = SQLDatabase(engine)
|
sql_database = SQLDatabase(engine)
|
||||||
|
|
||||||
table_schema_objs = makeDescriptionByEngine(sql_database)
|
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||||
table_node_mapping = SQLTableNodeMapping(sql_database)
|
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||||
|
#
|
||||||
nodes = table_node_mapping.to_nodes(table_schema_objs)
|
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
||||||
for node in nodes:
|
# for node in nodes:
|
||||||
node.metadata.update(metadata)
|
# node.metadata.update(metadata)
|
||||||
|
#
|
||||||
docs.extend(nodes)
|
# docs.extend(nodes)
|
||||||
|
|
||||||
queries = entry.queries or []
|
queries = entry.queries or []
|
||||||
loader = CustomDatabaseReader(sql_database)
|
loader = CustomDatabaseReader(sql_database)
|
||||||
|
|||||||
@@ -155,8 +155,8 @@ class XinferenceRerank(BaseNodePostprocessor):
|
|||||||
description="The model description from Xinference."
|
description="The model description from Xinference."
|
||||||
)
|
)
|
||||||
_generator: Any = PrivateAttr()
|
_generator: Any = PrivateAttr()
|
||||||
_model_uid: str
|
_model_uid: str = Field(description="The Xinference model to use.")
|
||||||
_endpoint: str
|
_endpoint: str = Field(description="The Xinference endpoint URL to use.")
|
||||||
model: str = Field(description="Dashscope rerank model name.")
|
model: str = Field(description="Dashscope rerank model name.")
|
||||||
top_n: int = Field(description="Top N nodes to return.")
|
top_n: int = Field(description="Top N nodes to return.")
|
||||||
threshold: float = Field(description="threshold nodes to return.")
|
threshold: float = Field(description="threshold nodes to return.")
|
||||||
|
|||||||
Reference in New Issue
Block a user