89 lines
2.5 KiB
Python
89 lines
2.5 KiB
Python
import os
|
|
import logging
|
|
from typing import Dict
|
|
|
|
from llama_index.core.readers.base import BaseReader
|
|
from llama_index.core.readers.json import JSONReader
|
|
from llama_parse import LlamaParse
|
|
from pydantic import BaseModel, validator
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class FileLoaderConfig(BaseModel):
    """Configuration for loading documents from a local directory."""

    # Directory that files are read from.
    data_dir: str = "data"
    # When True, files are parsed with LlamaParse instead of the local readers.
    use_llama_parse: bool = False

    @validator("data_dir")
    def data_dir_must_exist(cls, v):
        """Return ``v`` unchanged if it names an existing directory.

        Raises:
            ValueError: If ``v`` is not an existing directory.
        """
        if os.path.isdir(v):
            return v
        raise ValueError(f"Directory '{v}' does not exist")
|
|
|
|
|
|
def llama_parse_parser():
    """Build a LlamaParse parser that emits markdown for English documents.

    Returns:
        A ``LlamaParse`` instance configured with ``result_type="markdown"``,
        ``verbose=True``, ``language="en"`` and ``ignore_errors=False``.

    Raises:
        ValueError: If the ``LLAMA_CLOUD_API_KEY`` environment variable is
            not set.
    """
    # Fail early with an actionable message when the API key is missing.
    if "LLAMA_CLOUD_API_KEY" not in os.environ:
        message = (
            "LLAMA_CLOUD_API_KEY environment variable is not set. "
            "Please set it in .env file or in your shell environment then run again!"
        )
        raise ValueError(message)

    return LlamaParse(
        result_type="markdown",
        verbose=True,
        language="en",
        ignore_errors=False,
    )
|
|
|
|
|
|
def llama_parse_extractor() -> Dict[str, LlamaParse]:
    """Map every file type in ``SUPPORTED_FILE_TYPES`` to one shared parser.

    Returns:
        A dict whose keys are the entries of
        ``llama_parse.utils.SUPPORTED_FILE_TYPES`` and whose values are all
        the same parser object returned by ``llama_parse_parser()``.

    Raises:
        ValueError: Propagated from ``llama_parse_parser()``.
    """
    from llama_parse.utils import SUPPORTED_FILE_TYPES

    # A single parser instance is reused for all file types.
    shared_parser = llama_parse_parser()
    return dict.fromkeys(SUPPORTED_FILE_TYPES, shared_parser)
|
|
|
|
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the file extractors used when LlamaParse is disabled.

    The mapping is passed as ``file_extractor`` to ``SimpleDirectoryReader``,
    which looks readers up by file suffix (including the leading dot) and
    calls ``load_data`` on the mapped object. The key must therefore be
    ``".json"`` and the value a reader *instance*, not the reader class.

    Returns:
        A dict mapping ``".json"`` to a ``JSONReader`` instance.
    """
    return {".json": JSONReader()}
|
|
|
|
|
|
def get_file_documents(config: FileLoaderConfig):
    """Load all files under ``config.data_dir`` as documents.

    Files are read recursively with ``SimpleDirectoryReader``. When
    ``config.use_llama_parse`` is set, the extractors come from
    ``llama_parse_extractor()``; otherwise from ``llama_local_extractor()``.

    Args:
        config: Loader settings (data directory and parser selection).

    Returns:
        The result of ``reader.load_data()``, or an empty list when the
        failure originated in a function named ``_add_files`` (treated as
        the "data dir is empty" case).

    Raises:
        Exception: Any other error raised while building the reader or
            loading the data is re-raised unchanged.
    """
    from llama_index.core.readers import SimpleDirectoryReader

    try:
        if config.use_llama_parse:
            # LlamaParse is async first,
            # so we need to use nest_asyncio to run it in sync mode
            import nest_asyncio

            nest_asyncio.apply()

            file_extractor = llama_parse_extractor()
        else:
            file_extractor = llama_local_extractor()

        reader = SimpleDirectoryReader(
            config.data_dir,
            recursive=True,
            filename_as_id=True,
            raise_on_error=True,
            file_extractor=file_extractor,
        )
        return reader.load_data()
    except Exception as e:
        import traceback

        # Catch the error if the data dir is empty
        # and return as empty document list.
        # The innermost traceback frame tells us which function raised.
        function_name = traceback.extract_tb(e.__traceback__)[-1].name
        if function_name == "_add_files":
            logger.warning(
                f"Failed to load file documents, error message: {e} . Return as empty document list."
            )
            return []

        # Not the empty-data-dir case: re-raise with the original traceback.
        raise
|