调整测试代码

This commit is contained in:
2024-08-19 08:59:45 +08:00
parent 2942730c9a
commit 176b49983a
+9 -6
View File
@@ -24,7 +24,8 @@ def main():
top_k = 5
filters = generate_filters([])
#question = "从工程属性表中查找工程名称"
question = "总算表中名称等于架空输电线路本体工程的金额?"
#question = "总算表中名称等于架空输电线路本体工程的金额?"
question = "工程监理费的金额是多少?"
# 创建向量检索查询工具
query_engine = index.as_query_engine(
similarity_top_k=top_k, filters=filters
@@ -35,18 +36,20 @@ def main():
engine = create_engine(os.getenv("SQL_DATABASE_URL", ""))
sql_database = SQLDatabase(engine)
loader = CustomDatabaseReader(sql_database)
documents = loader.load_data(query="select * from ProjectProperties")
table_schema_objs = makeDescriptionByEngine(sql_database)
table_node_mapping = SQLTableNodeMapping(sql_database)
vectorIndex = VectorStoreIndex()
# 创建SQL查询工具
sql_obj_index = ObjectIndex.from_objects(
# sql_obj_index = ObjectIndex.from_objects(
# table_schema_objs,
# table_node_mapping,
# index_cls=VectorStoreIndex,
# )
sql_obj_index = ObjectIndex.from_objects_and_index(
table_schema_objs,
vectorIndex,
table_node_mapping,
index_cls=VectorStoreIndex,
)
query_result =vectorIndex.as_query_engine(