This commit is contained in:
2024-08-28 11:49:22 +08:00
4 changed files with 55 additions and 65 deletions
+18 -17
View File
@@ -1,5 +1,4 @@
import logging
import yaml
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
@@ -17,24 +16,26 @@ def load_configs():
def get_documents():
documents = []
config = load_configs()
if config is None or len(config.items()) == 0:
return documents
return documents
for loader_type, loader_config in config.items():
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
if loader_config.get('enable', True): # 检查 enable 字段
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
return documents
return documents
+21 -32
View File
@@ -2,17 +2,14 @@ import logging
from typing import Any, List, Optional
from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema
from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel
from sqlalchemy import create_engine
from sqlalchemy import text
from sqlalchemy import create_engine, text
from sqlalchemy.engine import Engine
logger = logging.getLogger(__name__)
class CustomDatabaseReader(BaseReader):
class CustomDatabaseReader(DatabaseReader):
"""Simple Database reader.
Concatenates each row into Document used by LlamaIndex.
@@ -85,19 +82,20 @@ class CustomDatabaseReader(BaseReader):
Returns:
List[Document]: A list of Document objects.
"""
dco_str = ""
dco_str = ""
with self.sql_database.engine.connect() as connection:
if query is None:
raise ValueError("A query parameter is necessary to filter the data")
else:
result = connection.execute(text(query))
dco_str = ", ".join(
dco_str += ", ".join(
[f"{entry}" for entry in result.keys()]
)
) + "\n"
for item in result.fetchall():
# fetch each item
# Fetch each item
record_str = ", ".join(
[f"{entry}" for col, entry in zip(result.keys(), item)]
)
@@ -111,45 +109,36 @@ class CustomDatabaseReader(BaseReader):
class DBLoaderConfig(BaseModel):
    """Settings for one database the DB loader should read from."""

    # Connection URI handed to sqlalchemy.create_engine.
    uri: str
    # One mapping per query. get_db_documents reads the "sql" key (the
    # statement to run) and the "explanation" key (stored in the metadata of
    # every document produced by that statement).
    queries: List[dict]
def get_db_documents(configs: List[DBLoaderConfig]) -> List[Document]:
    """Run the configured SQL queries and return the resulting documents.

    For every database config an engine is created from ``uri`` and each
    entry in ``queries`` is executed through ``CustomDatabaseReader``. Every
    document gets the query's ``explanation`` and a fixed ``file_type`` added
    to its metadata.

    Args:
        configs: Database loader configs, one per database.

    Returns:
        All documents loaded from all configured databases. The list is empty
        when ``configs`` is empty or no config has a usable URI.
    """
    docs: List[Document] = []
    if not configs:
        logger.warning(
            "Failed to load database, error message: uri is empty. Return as empty document list."
        )
        return docs

    metadata = {
        "file_type": "application/booway.document.zj",
    }

    for entry in configs:
        # Validate every entry, not only the first one, so a later config
        # with an empty URI cannot reach create_engine.
        if not entry.uri:
            logger.warning(
                "Failed to load database, error message: uri is empty. Skipping this database config."
            )
            continue

        engine = create_engine(entry.uri)
        try:
            loader = CustomDatabaseReader(SQLDatabase(engine))
            for query_dict in entry.queries or []:
                query = query_dict.get("sql", "")
                if not query:
                    # An entry without SQL would execute an empty statement.
                    logger.warning(
                        "Skipping database query entry without 'sql': %s", query_dict
                    )
                    continue
                explanation = query_dict.get("explanation", "")
                logger.info(f"Loading data from database with query: {query}")
                for doc in loader.load_data(query=query):
                    # Attach the explanation first, then add or overwrite
                    # with the shared metadata.
                    doc.metadata["explanation"] = explanation
                    doc.metadata.update(metadata)
                    docs.append(doc)
        finally:
            # Release the engine's pooled connections once this database
            # has been read.
            engine.dispose()

    return docs