优化了提示词
This commit is contained in:
@@ -2,14 +2,17 @@ import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.core.objects import SQLTableSchema
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(DatabaseReader):
|
||||
class CustomDatabaseReader(BaseReader):
|
||||
"""Simple Database reader.
|
||||
|
||||
Concatenates each row into Document used by LlamaIndex.
|
||||
@@ -73,30 +76,28 @@ class CustomDatabaseReader(DatabaseReader):
|
||||
"set of credentials."
|
||||
)
|
||||
|
||||
def load_data(self, query: str, explanation: str) -> List[Document]:
|
||||
def load_data(self, query: str) -> List[Document]:
|
||||
"""Query and load data from the Database, returning a list of Documents.
|
||||
|
||||
Args:
|
||||
query (str): Query parameter to filter tables and rows.
|
||||
explanation (str): Explanation for the query to be included in the document.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
dco_str = explanation + "\n"
|
||||
|
||||
dco_str = ""
|
||||
with self.sql_database.engine.connect() as connection:
|
||||
if query is None:
|
||||
raise ValueError("A query parameter is necessary to filter the data")
|
||||
else:
|
||||
result = connection.execute(text(query))
|
||||
|
||||
dco_str += ", ".join(
|
||||
dco_str = ", ".join(
|
||||
[f"{entry}" for entry in result.keys()]
|
||||
) + "\n"
|
||||
)
|
||||
|
||||
for item in result.fetchall():
|
||||
# Fetch each item
|
||||
# fetch each item
|
||||
record_str = ", ".join(
|
||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||
)
|
||||
@@ -110,7 +111,7 @@ class CustomDatabaseReader(DatabaseReader):
|
||||
|
||||
class DBLoaderConfig(BaseModel):
|
||||
uri: str
|
||||
queries: List[dict]
|
||||
queries: List[str]
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
docs = []
|
||||
@@ -122,19 +123,33 @@ def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
return docs
|
||||
|
||||
metadata = {
|
||||
'file_type': 'application/booway.document.zj',
|
||||
#'file_name':'',
|
||||
'file_type':'application/booway.document.zj',
|
||||
#'file_path':'',
|
||||
#'file_size':'',
|
||||
#'creation_date':'',
|
||||
#'last_modified_date':'',
|
||||
}
|
||||
|
||||
#from llama_index.readers.database import DatabaseReader
|
||||
for entry in configs:
|
||||
engine = create_engine(entry.uri)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
#
|
||||
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
||||
# for node in nodes:
|
||||
# node.metadata.update(metadata)
|
||||
#
|
||||
# docs.extend(nodes)
|
||||
|
||||
queries = entry.queries or []
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
for query_dict in entry.queries:
|
||||
query = query_dict.get("sql", "")
|
||||
explanation = query_dict.get("explanation", "")
|
||||
for query in queries:
|
||||
logger.info(f"Loading data from database with query: {query}")
|
||||
documents = loader.load_data(query=query, explanation=explanation)
|
||||
documents = loader.load_data(query=query)
|
||||
|
||||
docs.extend(documents)
|
||||
return docs
|
||||
|
||||
Reference in New Issue
Block a user