将项目划分表按照业务拆分

This commit is contained in:
chentianrui
2024-08-23 15:05:48 +08:00
parent 5fc8375a06
commit d1117c73c4
5 changed files with 29 additions and 43 deletions
+15 -30
View File
@@ -2,17 +2,14 @@ import logging
from typing import Any, List, Optional
from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema
from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
class CustomDatabaseReader(BaseReader):
class CustomDatabaseReader(DatabaseReader):
"""Simple Database reader.
Concatenates each row into Document used by LlamaIndex.
@@ -76,28 +73,30 @@ class CustomDatabaseReader(BaseReader):
"set of credentials."
)
def load_data(self, query: str) -> List[Document]:
def load_data(self, query: str, explanation: str) -> List[Document]:
"""Query and load data from the Database, returning a list of Documents.
Args:
query (str): Query parameter to filter tables and rows.
explanation (str): Explanation for the query to be included in the document.
Returns:
List[Document]: A list of Document objects.
"""
dco_str = ""
dco_str = explanation + "\n"
with self.sql_database.engine.connect() as connection:
if query is None:
raise ValueError("A query parameter is necessary to filter the data")
else:
result = connection.execute(text(query))
dco_str = ", ".join(
dco_str += ", ".join(
[f"{entry}" for entry in result.keys()]
)
) + "\n"
for item in result.fetchall():
# fetch each item
# Fetch each item
record_str = ", ".join(
[f"{entry}" for col, entry in zip(result.keys(), item)]
)
@@ -111,7 +110,7 @@ class CustomDatabaseReader(BaseReader):
class DBLoaderConfig(BaseModel):
uri: str
queries: List[str]
queries: List[dict]
def get_db_documents(configs: list[DBLoaderConfig]):
docs = []
@@ -123,33 +122,19 @@ def get_db_documents(configs: list[DBLoaderConfig]):
return docs
metadata = {
#'file_name':'',
'file_type':'application/booway.document.zj',
#'file_path':'',
#'file_size':'',
#'creation_date':'',
#'last_modified_date':'',
'file_type': 'application/booway.document.zj',
}
#from llama_index.readers.database import DatabaseReader
for entry in configs:
engine = create_engine(entry.uri)
sql_database = SQLDatabase(engine)
# table_schema_objs = makeDescriptionByEngine(sql_database)
# table_node_mapping = SQLTableNodeMapping(sql_database)
#
# nodes = table_node_mapping.to_nodes(table_schema_objs)
# for node in nodes:
# node.metadata.update(metadata)
#
# docs.extend(nodes)
queries = entry.queries or []
loader = CustomDatabaseReader(sql_database)
for query in queries:
for query_dict in entry.queries:
query = query_dict.get("sql", "")
explanation = query_dict.get("explanation", "")
logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query)
documents = loader.load_data(query=query, explanation=explanation)
docs.extend(documents)
return docs