200 lines
6.6 KiB
Python
200 lines
6.6 KiB
Python
import os
|
|
import logging
|
|
from typing import Dict
|
|
|
|
from llama_index.core.readers.base import BaseReader
|
|
from llama_index.core.readers.json import JSONReader
|
|
from llama_parse import LlamaParse
|
|
from pydantic import BaseModel, validator
|
|
from app.engine.loaders.markdownReader import ChunkMarkdownReader
|
|
from app.engine.loaders.projectJson import ProjectJson
|
|
from typing import Any, Callable, Dict, Generator, List, Optional, Type, Set
|
|
import fsspec,mimetypes
|
|
from fsspec.implementations.local import LocalFileSystem
|
|
from datetime import datetime
|
|
|
|
# Logger for this module; records are emitted under the module's dotted name.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class FileLoaderConfig(BaseModel):
    """Settings controlling how documents are loaded from disk."""

    # Root directory that contains the documents to load.
    data_dir: str = "data"
    # When True, LlamaParse readers are used instead of the local readers.
    use_llama_parse: bool = False

    @validator("data_dir")
    def data_dir_must_exist(cls, v):
        """Reject a ``data_dir`` value that is not an existing directory."""
        if os.path.isdir(v):
            return v
        raise ValueError(f"Directory '{v}' does not exist")
|
|
|
|
class CustomFileMetadataFunc:
    """
    Default file metadata function wrapper which stores the fs.

    Keeping the filesystem on the instance (rather than in a closure) allows
    the function to be pickled.
    """

    def __init__(self, fs: Optional["fsspec.AbstractFileSystem"] = None):
        # Fall back to the local filesystem when no filesystem is supplied.
        self.fs = fs or self._get_default_fs()

    def __call__(self, file_path: str) -> Dict:
        """Return metadata for ``file_path`` using the stored filesystem."""
        return self._default_file_metadata_func(file_path, self.fs)

    def _default_file_metadata_func(
        self, file_path: str, fs: Optional["fsspec.AbstractFileSystem"] = None
    ) -> Dict:
        """
        Get some handy metadata from the filesystem.

        Args:
            file_path: Path of the file to describe.
            fs: Filesystem to query; defaults to the local filesystem.

        Returns:
            A dict with ``file_name``, ``file_type``, ``file_size``,
            ``creation_date``, ``last_modified_date`` and
            ``last_accessed_date``; keys whose value is ``None`` are omitted.
        """
        fs = fs or self._get_default_fs()
        stat_result = fs.stat(file_path)

        try:
            file_name = os.path.basename(str(stat_result["name"]))
        except Exception:
            # Best effort: if the stat result has no usable "name", derive
            # the file name from the requested path instead.
            file_name = os.path.basename(file_path)

        creation_date = self._format_file_timestamp(stat_result.get("created"))
        last_modified_date = self._format_file_timestamp(stat_result.get("mtime"))
        last_accessed_date = self._format_file_timestamp(stat_result.get("atime"))

        default_meta = {
            "file_name": file_name,
            "file_type": mimetypes.guess_type(file_path)[0],
            "file_size": stat_result.get("size"),
            "creation_date": creation_date,
            "last_modified_date": last_modified_date,
            "last_accessed_date": last_accessed_date,
        }

        # Only keep entries that actually have a value.
        return {
            meta_key: meta_value
            for meta_key, meta_value in default_meta.items()
            if meta_value is not None
        }

    @staticmethod
    def _format_file_timestamp(
        timestamp: Optional[float], include_time: bool = False
    ) -> Optional[str]:
        """
        Format a file timestamp as a date string.

        This is a staticmethod because it uses no instance state but is
        invoked as ``self._format_file_timestamp(...)``. Without the
        decorator the instance would bind to ``timestamp``, the conversion
        would always fail and every date would be dropped from the metadata.

        Args:
            timestamp: Seconds since the epoch, or ``None`` when unknown.
            include_time: If true, return ``%Y-%m-%dT%H:%M:%SZ`` in UTC;
                otherwise return ``%Y-%m-%d`` in local time.

        Returns:
            The formatted string, or ``None`` if ``timestamp`` is missing or
            cannot be converted.
        """
        if timestamp is None:
            return None

        # Local import: the module only imports ``datetime`` itself.
        from datetime import timezone

        try:
            if include_time:
                # Timezone-aware replacement for the deprecated utcfromtimestamp().
                moment = datetime.fromtimestamp(timestamp, tz=timezone.utc)
                return moment.strftime("%Y-%m-%dT%H:%M:%SZ")
            return datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d")
        except (OverflowError, OSError, TypeError, ValueError):
            # Out-of-range or non-numeric timestamps are treated as unknown.
            return None

    def _get_default_fs(self) -> "fsspec.AbstractFileSystem":
        """Return a local filesystem used when none is supplied."""
        return LocalFileSystem()

    def _is_default_fs(self, fs: "fsspec.AbstractFileSystem") -> bool:
        """Return True if ``fs`` is a LocalFileSystem without ``auto_mkdir``."""
        return isinstance(fs, LocalFileSystem) and not fs.auto_mkdir
|
|
|
|
|
|
def llama_parse_parser():
    """Build a LlamaParse client that returns English markdown.

    Raises:
        ValueError: if the ``LLAMA_CLOUD_API_KEY`` environment variable is
            not set.
    """
    if "LLAMA_CLOUD_API_KEY" not in os.environ:
        raise ValueError(
            "LLAMA_CLOUD_API_KEY environment variable is not set. "
            "Please set it in .env file or in your shell environment then run again!"
        )

    return LlamaParse(
        result_type="markdown",
        verbose=True,
        language="en",
        ignore_errors=False,
    )
|
|
|
|
def llama_parse_extractor() -> Dict[str, LlamaParse]:
    """Map every file type supported by LlamaParse to one shared parser."""
    from llama_parse.utils import SUPPORTED_FILE_TYPES

    shared_parser = llama_parse_parser()
    return dict.fromkeys(SUPPORTED_FILE_TYPES, shared_parser)
|
|
|
|
def llama_local_extractor() -> Dict[str, BaseReader]:
    """Return the readers used when LlamaParse is disabled, keyed by extension."""
    return {
        ".json": JSONReader(clean_json=False, levels_back=0),
        ".md": ChunkMarkdownReader(),
    }
|
|
|
|
def get_file_documents(config: FileLoaderConfig, childPath: str):
    """Load every document below ``config.data_dir`` / ``childPath``.

    Args:
        config: Loader settings; ``use_llama_parse`` selects the LlamaParse
            extractors instead of the local readers.
        childPath: Sub-directory of ``config.data_dir`` in which ``"_"``
            separates nested directory names (``"a_b"`` -> ``a<sep>b``).

    Returns:
        The documents produced by ``SimpleDirectoryReader.load_data()``, or an
        empty list when the failure originates in ``_add_files`` (treated as
        the empty-data-directory case).

    Raises:
        Any other exception raised while building the extractors or reading
        the files is re-raised unchanged.
    """
    from llama_index.core.readers import SimpleDirectoryReader

    try:
        if config.use_llama_parse:
            # LlamaParse is async first,
            # so we need to use nest_asyncio to run it in sync mode
            import nest_asyncio

            nest_asyncio.apply()
            file_extractor = llama_parse_extractor()
        else:
            file_extractor = llama_local_extractor()

        # "_" in childPath stands for a directory separator. Map it to the
        # platform separator; a hard-coded "\\" only works on Windows.
        input_dir = os.path.join(config.data_dir, childPath.replace("_", os.sep))

        reader = SimpleDirectoryReader(
            input_dir,
            recursive=True,
            filename_as_id=True,
            raise_on_error=True,
            file_extractor=file_extractor,
            file_metadata=CustomFileMetadataFunc(),
        )
        return reader.load_data()
    except Exception as e:
        import traceback

        # An error whose innermost frame is `_add_files` is treated as the
        # empty-data-dir case and turned into an empty document list.
        function_name = traceback.extract_tb(e.__traceback__)[-1].name
        if function_name == "_add_files":
            logger.warning(
                "Failed to load file documents, error message: %s . "
                "Return as empty document list.",
                e,
            )
            return []
        # Any other failure is propagated with its original traceback.
        raise
|
|
|
|
def prjFileSuffix(dir: str):
    """Return the extension of a representative regular file directly in ``dir``.

    The chosen file is the lexicographically smallest non-hidden regular file,
    so the result no longer depends on the arbitrary order of ``os.listdir``.
    Subdirectories and names starting with ``"."`` are ignored.

    Returns:
        The extension including its leading dot (e.g. ``".json"``), or ``''``
        when there is no such file or it has no extension.
    """
    candidates = (
        name
        for name in os.listdir(dir)
        if not name.startswith(".") and os.path.isfile(os.path.join(dir, name))
    )
    first = min(candidates, default=None)
    if first is None:
        return ''
    return os.path.splitext(first)[1]
|
|
|
|
def getProjectName(dir: str):
    """Return the project name recorded in the project files under ``dir``.

    The lookup strategy depends on ``prjFileSuffix(dir)``:

    * ``.json``: parse the project with ``ProjectJson``. In its "工程属性"
      (project properties) table, return the "值" (value) of the first record
      whose "名称" (name) is "工程名称" (project name).
    * ``.md``: load the ``工程属性.md`` file with ``ChunkMarkdownReader`` and
      return the value it finds for the same name/value pair.

    Returns ``''`` for any other suffix, or when no matching record or file
    is found.
    """
    suffix = prjFileSuffix(dir)

    if suffix == '.json':
        project = ProjectJson(dir)
        project.parse()
        for row in project.table('工程属性').records():
            if row.value('名称') == '工程名称':
                return row.value('值')
    elif suffix == '.md':
        for entry in os.listdir(dir):
            if not entry.endswith('.md'):
                continue
            if os.path.splitext(entry)[0] != '工程属性':
                continue
            reader = ChunkMarkdownReader()
            reader.load_data(os.path.join(dir, entry))
            return reader.findValue("名称=='工程名称'", '值')

    return ''