增加了判断是否使用数据库

This commit is contained in:
chentianrui
2024-08-28 09:45:01 +08:00
parent 0f7c900c1e
commit 8a5facb5b6
3 changed files with 91 additions and 59 deletions
+59 -22
View File
@@ -1,12 +1,15 @@
import logging
from typing import Any, List, Optional
from llama_index.core import Document
from llama_index.core import SQLDatabase, Document
from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
class CustomDatabaseReader:
class CustomDatabaseReader(DatabaseReader):
"""Simple Database reader.
Concatenates each row into Document used by LlamaIndex.
@@ -39,8 +42,8 @@ class CustomDatabaseReader:
def __init__(
self,
sql_database: Optional[Any] = None,
engine: Optional[Any] = None,
sql_database: Optional[SQLDatabase] = None,
engine: Optional[Engine] = None,
uri: Optional[str] = None,
scheme: Optional[str] = None,
host: Optional[str] = None,
@@ -52,24 +55,51 @@ class CustomDatabaseReader:
**kwargs: Any,
) -> None:
"""Initialize with parameters."""
# Setting the database-related properties to None
self.sql_database = None
self.uri = None
if sql_database:
self.sql_database = sql_database
elif engine:
self.sql_database = SQLDatabase(engine, *args, **kwargs)
elif uri:
self.uri = uri
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
elif scheme and host and port and user and password and dbname:
uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
self.uri = uri
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
else:
raise ValueError(
"You must provide either a SQLDatabase, "
"a SQL Alchemy Engine, a valid connection URI, or a valid "
"set of credentials."
)
def load_data(self, query: str, explanation: str) -> List[Document]:
"""Simulate loading data without a database connection.
def load_data(self, query: str) -> List[Document]:
"""Query and load data from the Database, returning a list of Documents.
Args:
query (str): Query parameter (not used).
explanation (str): Explanation to be included in the document.
query (str): Query parameter to filter tables and rows.
Returns:
List[Document]: A list of Document objects.
"""
dco_str = explanation + "\n"
# Simulate data without querying a real database
dco_str += "Simulated column1, Simulated column2\n"
dco_str += "Simulated data1, Simulated data2\n"
dco_str = ""
with self.sql_database.engine.connect() as connection:
if query is None:
raise ValueError("A query parameter is necessary to filter the data")
else:
result = connection.execute(text(query))
dco_str += ", ".join(
[f"{entry}" for entry in result.keys()]
) + "\n"
for item in result.fetchall():
# Fetch each item
record_str = ", ".join(
[f"{entry}" for col, entry in zip(result.keys(), item)]
)
dco_str += record_str + "\n"
doc = Document(text=dco_str)
doc.metadata["name"] = query
@@ -81,10 +111,10 @@ class DBLoaderConfig(BaseModel):
uri: str
queries: List[dict]
def get_db_documents(configs: list[DBLoaderConfig]):
def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
docs = []
if len(configs) == 0 or configs[0].uri == "":
if not configs or not configs[0].uri:
logger.warning(
f"Failed to load database, error message: uri is empty. Return as empty document list."
)
@@ -95,13 +125,20 @@ def get_db_documents(configs: list[DBLoaderConfig]):
}
for entry in configs:
# Skipping the database connection part
loader = CustomDatabaseReader()
engine = create_engine(entry.uri)
sql_database = SQLDatabase(engine)
loader = CustomDatabaseReader(sql_database)
for query_dict in entry.queries:
query = query_dict.get("sql", "")
explanation = query_dict.get("explanation", "")
logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query, explanation=explanation)
documents = loader.load_data(query=query)
docs.extend(documents)
return docs
# 添加解释到元数据中
for doc in documents:
doc.metadata["explanation"] = explanation
doc.metadata.update(metadata) # 更新或添加额外的元数据
docs.append(doc)
return docs