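# Listing header captured from the repository web view:
#   path: zjdataai-app/backend/tests/query.py
#   last commit: 2024-08-28 19:58:37 +08:00 (author shown as "T")
#   size: 72 lines, 2.3 KiB; language: Python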

import os
from ctypes import cast
from llama_index.core import VectorStoreIndex, SQLDatabase
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
from llama_index.readers.database import DatabaseReader
from sqlalchemy import create_engine
from app.api.routers.chat import generate_filters
from app.engine import get_index, makeDescriptionByEngine
from app.engine.loaders.db import CustomDatabaseReader
from app.engine.vectordb import get_vector_store
from app.observability import init_observability
from app.settings import init_settings
def main():
    """Smoke-test the app's retrieval stack with one sample question.

    Prints the answer from three query paths for the same question:

    1. the first index returned by ``get_index()``;
    2. a vector index built from the table schemas that
       ``makeDescriptionByEngine`` produces for ``SQL_DATABASE_URL``;
    3. a ``SQLTableRetrieverQueryEngine`` that retrieves the single most
       relevant table schema from that index and queries the database.

    Prints a message and returns early when ``get_index()`` yields nothing
    or ``SQL_DATABASE_URL`` is unset.
    """
    init_settings()
    init_observability()

    indexes = get_index()
    if not indexes:
        # Without this guard `index` would be unbound and the first query
        # below would raise.
        print("No index available from get_index(); nothing to query.")
        return
    # Only the first index is exercised.
    index = next(iter(indexes.values()))

    top_k = 5
    filters = generate_filters([])

    # Alternative sample questions (the underlying data is Chinese):
    #   "从工程属性表中查找工程名称"
    #       -> find the project name in the project-attributes table
    #   "总算表中名称等于架空输电线路本体工程的金额?"
    #       -> amount in the summary-estimate table for the row named
    #          "overhead transmission line main works"
    # Active question: "What is the amount of the project supervision fee?"
    question = "工程监理费的金额是多少?"

    # Query 1: vector retrieval over the app's index.
    query_engine = index.as_query_engine(
        similarity_top_k=top_k, filters=filters
    )
    print(query_engine.query(question))

    db_url = os.getenv("SQL_DATABASE_URL")
    if not db_url:
        # create_engine("") would fail with an opaque SQLAlchemy error.
        print("SQL_DATABASE_URL is not set; skipping the SQL queries.")
        return
    engine = create_engine(db_url)
    sql_database = SQLDatabase(engine)
    # NOTE(review): assumes makeDescriptionByEngine returns SQLTableSchema
    # objects (the type SQLTableNodeMapping converts) — confirm.
    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)

    # Build the vector index from the table-schema nodes. A bare
    # VectorStoreIndex() raises ValueError because none of nodes / objects /
    # index_struct is supplied; building it from the mapped nodes also makes
    # sure both queries below actually search the table schemas.
    vector_index = VectorStoreIndex(
        nodes=table_node_mapping.to_nodes(table_schema_objs)
    )
    sql_obj_index = ObjectIndex(vector_index, table_node_mapping)

    # Query 2: vector retrieval over the table-schema nodes.
    # NOTE(review): `filters` come from the chat router and presumably target
    # the app's document metadata; confirm they match these schema nodes.
    schema_query_engine = vector_index.as_query_engine(
        similarity_top_k=top_k, filters=filters
    )
    print(schema_query_engine.query(question))

    # Query 3: text-to-SQL, using only the single most relevant table.
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database, sql_obj_index.as_retriever(similarity_top_k=1)
    )
    print(sql_query_engine.query(question))
if __name__ == "__main__":
    # Imported here so that importing this module does not require Phoenix.
    from phoenix.trace import using_project

    # Run main() inside Phoenix's `using_project` context so its traces are
    # attributed to the "ly_zjapp_test" project (per Phoenix docs).
    with using_project("ly_zjapp_test"):
        main()