67 lines
2.1 KiB
Python
67 lines
2.1 KiB
Python
import os
|
|
from ctypes import cast
|
|
|
|
from llama_index.core import VectorStoreIndex, SQLDatabase
|
|
from llama_index.core.indices.struct_store import SQLTableRetrieverQueryEngine
|
|
from llama_index.core.objects import SQLTableNodeMapping, ObjectIndex
|
|
from llama_index.readers.database import DatabaseReader
|
|
from sqlalchemy import create_engine
|
|
|
|
from app.api.routers.chat import generate_filters
|
|
from app.engine import get_index, makeDescriptionByEngine
|
|
from app.engine.loaders.db import CustomDatabaseReader
|
|
from app.engine.vectordb import get_vector_store
|
|
from app.observability import init_observability
|
|
from app.settings import init_settings
|
|
|
|
|
|
def main():
    """Run a manual smoke test of the vector and text-to-SQL query paths.

    Asks the same hard-coded question three ways and prints each answer:

    1. against the application index returned by ``get_index()``;
    2. against an ad-hoc ``VectorStoreIndex`` built from the rows of the
       ``ProjectProperties`` table;
    3. through ``SQLTableRetrieverQueryEngine``, using an ``ObjectIndex``
       over the table schema objects from ``makeDescriptionByEngine``.

    Raises:
        RuntimeError: if the ``SQL_DATABASE_URL`` environment variable is
            unset or empty.
    """
    init_settings()
    init_observability()

    index = get_index()

    top_k = 5
    filters = generate_filters([])
    # Alternative sample question ("look up the project name in the project
    # properties table"): "从工程属性表中查找工程名称"
    question = "总算表中名称等于架空输电线路本体工程的金额?"

    # Vector-retrieval query engine over the application's index.
    query_engine = index.as_query_engine(
        similarity_top_k=top_k,
        filters=filters,
    )
    query_result = query_engine.query(question)
    print(query_result)

    # Fail with a clear message instead of letting SQLAlchemy choke on "".
    database_url = os.getenv("SQL_DATABASE_URL", "")
    if not database_url:
        raise RuntimeError(
            "SQL_DATABASE_URL environment variable is not set"
        )
    engine = create_engine(database_url)
    sql_database = SQLDatabase(engine)

    loader = CustomDatabaseReader(sql_database)
    # Assumes load_data returns llama_index Document objects -- TODO confirm
    # against CustomDatabaseReader.
    documents = loader.load_data(query="select * from ProjectProperties")

    table_schema_objs = makeDescriptionByEngine(sql_database)
    table_node_mapping = SQLTableNodeMapping(sql_database)

    # A bare VectorStoreIndex() has nothing to index and is rejected by the
    # constructor; build the index from the rows loaded above instead.
    vector_index = VectorStoreIndex.from_documents(documents)

    # SQL query tool: an object index over the table schema objects.
    sql_obj_index = ObjectIndex.from_objects(
        table_schema_objs,
        table_node_mapping,
        index_cls=VectorStoreIndex,
    )

    query_result = vector_index.as_query_engine(
        similarity_top_k=top_k,
        filters=filters,
    ).query(question)
    print(query_result)

    # The table retriever is limited to one table per question.
    sql_query_engine = SQLTableRetrieverQueryEngine(
        sql_database,
        sql_obj_index.as_retriever(similarity_top_k=1),
    )
    sql_query_result = sql_query_engine.query(question)
    print(sql_query_result)
|
|
|
|
|
|
if __name__ == "__main__":
    # Imported inside the guard, so phoenix is only needed when this file is
    # executed as a script.
    from phoenix.trace import using_project

    # Presumably groups the traces emitted by main() under the
    # "ly_zjapp_test" Phoenix project -- confirm against the phoenix docs.
    # The context manager's value is not needed, so it is not bound.
    with using_project("ly_zjapp_test"):
        main()