初始化提交

This commit is contained in:
2024-08-13 09:37:23 +08:00
parent 4923337038
commit e112fa4e44
50 changed files with 1649 additions and 259 deletions
+6 -3
View File
@@ -17,19 +17,22 @@ def load_configs():
def get_documents():
documents = []
config = load_configs()
if config is None or len(config.items()) == 0:
return documents
for loader_type, loader_config in config.items():
logger.info(
f"Loading documents from loader: {loader_type}, config: {loader_config}"
)
loader_config = loader_config or []
match loader_type:
case "file":
document = get_file_documents(FileLoaderConfig(**loader_config))
case "web":
document = get_web_documents(WebLoaderConfig(**loader_config))
case "db":
document = get_db_documents(
configs=[DBLoaderConfig(**cfg) for cfg in loader_config]
)
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
case _:
raise ValueError(f"Invalid loader type: {loader_type}")
documents.extend(document)
+167 -6
View File
@@ -1,26 +1,187 @@
import os
import logging
from typing import List
from typing import Any, List, Optional
from llama_index.core.readers.base import BaseReader
from llama_index.core.schema import Document
from llama_index.core.utilities.sql_wrapper import SQLDatabase
from sqlalchemy import text
from sqlalchemy.engine import Engine
from llama_index.core import SQLDatabase, Document
from llama_index.core.objects import SQLTableSchema, SQLTableNodeMapping
from llama_index.core.readers.base import BaseReader
from llama_index.readers.database import DatabaseReader
from pydantic import BaseModel, validator
from llama_index.core.indices.vector_store import VectorStoreIndex
from sqlalchemy import create_engine
logger = logging.getLogger(__name__)
class CustomDatabaseReader(BaseReader):
"""Simple Database reader.
Concatenates each row into Document used by LlamaIndex.
Args:
sql_database (Optional[SQLDatabase]): SQL database to use,
including table names to specify.
See :ref:`Ref-Struct-Store` for more details.
OR
engine (Optional[Engine]): SQLAlchemy Engine object of the database connection.
OR
uri (Optional[str]): uri of the database connection.
OR
scheme (Optional[str]): scheme of the database connection.
host (Optional[str]): host of the database connection.
port (Optional[int]): port of the database connection.
user (Optional[str]): user of the database connection.
password (Optional[str]): password of the database connection.
dbname (Optional[str]): dbname of the database connection.
Returns:
DatabaseReader: A DatabaseReader object.
"""
def __init__(
self,
sql_database: Optional[SQLDatabase] = None,
engine: Optional[Engine] = None,
uri: Optional[str] = None,
scheme: Optional[str] = None,
host: Optional[str] = None,
port: Optional[str] = None,
user: Optional[str] = None,
password: Optional[str] = None,
dbname: Optional[str] = None,
*args: Any,
**kwargs: Any,
) -> None:
"""Initialize with parameters."""
if sql_database:
self.sql_database = sql_database
elif engine:
self.sql_database = SQLDatabase(engine, *args, **kwargs)
elif uri:
self.uri = uri
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
elif scheme and host and port and user and password and dbname:
uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
self.uri = uri
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
else:
raise ValueError(
"You must provide either a SQLDatabase, "
"a SQL Alchemy Engine, a valid connection URI, or a valid "
"set of credentials."
)
def load_data(self, query: str) -> List[Document]:
"""Query and load data from the Database, returning a list of Documents.
Args:
query (str): Query parameter to filter tables and rows.
Returns:
List[Document]: A list of Document objects.
"""
dco_str = ""
with self.sql_database.engine.connect() as connection:
if query is None:
raise ValueError("A query parameter is necessary to filter the data")
else:
result = connection.execute(text(query))
dco_str = ", ".join(
[f"{entry}" for entry in result.keys()]
)
for item in result.fetchall():
# fetch each item
record_str = ", ".join(
[f"{entry}" for col, entry in zip(result.keys(), item)]
)
dco_str += record_str + "\n"
doc = Document(text=dco_str)
doc.metadata["name"] = query
doc.metadata["context"] = query
doc.metadata["file_type"] = "application/vnd.ms-excel"
return [doc]
class DBLoaderConfig(BaseModel):
uri: str
queries: List[str]
def makeDescriptionByEngine(sql_database:SQLDatabase):
reader = DatabaseReader(sql_database)
table_names = sql_database.get_usable_table_names()
table_schema_objs = []
for table_name in table_names:
columns = sql_database.get_table_columns(table_name)
if len(columns) > 150:
continue
stats_txt = ""
if table_name == 'gongchengshuxing':
stats_txt = '该表中有以下属性:'
documents = reader.load_data(query='select name from gongchengshuxing')
for index in range(len(documents) if len(documents) < 30 else 30):
if index == 0:
continue
elif index > 1:
stats_txt += ','
stats_txt += documents[index].text.split(':')[1]
tbSchema = (SQLTableSchema(table_name=table_name, context_str=stats_txt))
table_schema_objs.append(tbSchema)
return table_schema_objs
def get_db_documents(configs: list[DBLoaderConfig]):
from llama_index.readers.database import DatabaseReader
docs = []
if len(configs) == 0 or configs[0].uri == "":
logger.warning(
f"Failed to load database, error message: uri is empty. Return as empty document list."
)
return docs
metadata = {
#'file_name':'',
'file_type':'application/booway.document.zj',
#'file_path':'',
#'file_size':'',
#'creation_date':'',
#'last_modified_date':'',
}
#from llama_index.readers.database import DatabaseReader
for entry in configs:
loader = DatabaseReader(uri=entry.uri)
for query in entry.queries:
engine = create_engine(entry.uri)
sql_database = SQLDatabase(engine)
table_schema_objs = makeDescriptionByEngine(sql_database)
table_node_mapping = SQLTableNodeMapping(sql_database)
nodes = table_node_mapping.to_nodes(table_schema_objs)
for node in nodes:
node.metadata.update(metadata)
docs.extend(nodes)
queries = entry.queries or []
loader = CustomDatabaseReader(sql_database)
for query in queries:
logger.info(f"Loading data from database with query: {query}")
documents = loader.load_data(query=query)
docs.extend(documents)
return documents
docs.extend(documents)
return docs
+9
View File
@@ -1,6 +1,9 @@
import os
import logging
from typing import Dict
from llama_index.core.readers.base import BaseReader
from llama_index.core.readers.json import JSONReader
from llama_parse import LlamaParse
from pydantic import BaseModel, validator
@@ -39,6 +42,9 @@ def llama_parse_extractor() -> Dict[str, LlamaParse]:
parser = llama_parse_parser()
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
def llama_local_extractor() -> Dict[str, BaseReader]:
return {"json" : JSONReader}
def get_file_documents(config: FileLoaderConfig):
from llama_index.core.readers import SimpleDirectoryReader
@@ -53,6 +59,9 @@ def get_file_documents(config: FileLoaderConfig):
nest_asyncio.apply()
file_extractor = llama_parse_extractor()
else:
file_extractor = llama_local_extractor()
reader = SimpleDirectoryReader(
config.data_dir,
recursive=True,
+2 -1
View File
@@ -11,7 +11,7 @@ class CrawlUrl(BaseModel):
class WebLoaderConfig(BaseModel):
driver_arguments: list[str] = Field(default=None)
urls: list[CrawlUrl]
urls: list[CrawlUrl] = []
def get_web_documents(config: WebLoaderConfig):
@@ -25,6 +25,7 @@ def get_web_documents(config: WebLoaderConfig):
options.add_argument(arg)
docs = []
urls = config.urls or []
for url in config.urls:
scraper = WholeSiteReader(
prefix=url.prefix,