优化了提示词
This commit is contained in:
@@ -0,0 +1,40 @@
|
||||
import logging
|
||||
|
||||
import yaml
|
||||
from app.engine.loaders.db import DBLoaderConfig, get_db_documents
|
||||
from app.engine.loaders.file import FileLoaderConfig, get_file_documents
|
||||
from app.engine.loaders.web import WebLoaderConfig, get_web_documents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
|
||||
with open("config/loaders.yaml") as f:
|
||||
configs = yaml.safe_load(f)
|
||||
return configs
|
||||
|
||||
|
||||
def get_documents():
|
||||
documents = []
|
||||
config = load_configs()
|
||||
if config is None or len(config.items()) == 0:
|
||||
return documents
|
||||
|
||||
for loader_type, loader_config in config.items():
|
||||
logger.info(
|
||||
f"Loading documents from loader: {loader_type}, config: {loader_config}"
|
||||
)
|
||||
|
||||
loader_config = loader_config or []
|
||||
match loader_type:
|
||||
case "file":
|
||||
document = get_file_documents(FileLoaderConfig(**loader_config))
|
||||
case "web":
|
||||
document = get_web_documents(WebLoaderConfig(**loader_config))
|
||||
case "db":
|
||||
document = get_db_documents(configs=[DBLoaderConfig(**cfg) for cfg in loader_config])
|
||||
case _:
|
||||
raise ValueError(f"Invalid loader type: {loader_type}")
|
||||
documents.extend(document)
|
||||
|
||||
return documents
|
||||
@@ -0,0 +1,140 @@
|
||||
import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(DatabaseReader):
|
||||
"""Simple Database reader.
|
||||
|
||||
Concatenates each row into Document used by LlamaIndex.
|
||||
|
||||
Args:
|
||||
sql_database (Optional[SQLDatabase]): SQL database to use,
|
||||
including table names to specify.
|
||||
See :ref:`Ref-Struct-Store` for more details.
|
||||
|
||||
OR
|
||||
|
||||
engine (Optional[Engine]): SQLAlchemy Engine object of the database connection.
|
||||
|
||||
OR
|
||||
|
||||
uri (Optional[str]): uri of the database connection.
|
||||
|
||||
OR
|
||||
|
||||
scheme (Optional[str]): scheme of the database connection.
|
||||
host (Optional[str]): host of the database connection.
|
||||
port (Optional[int]): port of the database connection.
|
||||
user (Optional[str]): user of the database connection.
|
||||
password (Optional[str]): password of the database connection.
|
||||
dbname (Optional[str]): dbname of the database connection.
|
||||
|
||||
Returns:
|
||||
DatabaseReader: A DatabaseReader object.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sql_database: Optional[SQLDatabase] = None,
|
||||
engine: Optional[Engine] = None,
|
||||
uri: Optional[str] = None,
|
||||
scheme: Optional[str] = None,
|
||||
host: Optional[str] = None,
|
||||
port: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
dbname: Optional[str] = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize with parameters."""
|
||||
if sql_database:
|
||||
self.sql_database = sql_database
|
||||
elif engine:
|
||||
self.sql_database = SQLDatabase(engine, *args, **kwargs)
|
||||
elif uri:
|
||||
self.uri = uri
|
||||
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
|
||||
elif scheme and host and port and user and password and dbname:
|
||||
uri = f"{scheme}://{user}:{password}@{host}:{port}/{dbname}"
|
||||
self.uri = uri
|
||||
self.sql_database = SQLDatabase.from_uri(uri, *args, **kwargs)
|
||||
else:
|
||||
raise ValueError(
|
||||
"You must provide either a SQLDatabase, "
|
||||
"a SQL Alchemy Engine, a valid connection URI, or a valid "
|
||||
"set of credentials."
|
||||
)
|
||||
|
||||
def load_data(self, query: str, explanation: str) -> List[Document]:
|
||||
"""Query and load data from the Database, returning a list of Documents.
|
||||
|
||||
Args:
|
||||
query (str): Query parameter to filter tables and rows.
|
||||
explanation (str): Explanation for the query to be included in the document.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
dco_str = explanation + "\n"
|
||||
|
||||
with self.sql_database.engine.connect() as connection:
|
||||
if query is None:
|
||||
raise ValueError("A query parameter is necessary to filter the data")
|
||||
else:
|
||||
result = connection.execute(text(query))
|
||||
|
||||
dco_str += ", ".join(
|
||||
[f"{entry}" for entry in result.keys()]
|
||||
) + "\n"
|
||||
|
||||
for item in result.fetchall():
|
||||
# Fetch each item
|
||||
record_str = ", ".join(
|
||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||
)
|
||||
dco_str += record_str + "\n"
|
||||
|
||||
doc = Document(text=dco_str)
|
||||
doc.metadata["name"] = query
|
||||
doc.metadata["context"] = query
|
||||
doc.metadata["file_type"] = "application/vnd.ms-excel"
|
||||
return [doc]
|
||||
|
||||
class DBLoaderConfig(BaseModel):
|
||||
uri: str
|
||||
queries: List[dict]
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
docs = []
|
||||
|
||||
if len(configs) == 0 or configs[0].uri == "":
|
||||
logger.warning(
|
||||
f"Failed to load database, error message: uri is empty. Return as empty document list."
|
||||
)
|
||||
return docs
|
||||
|
||||
metadata = {
|
||||
'file_type': 'application/booway.document.zj',
|
||||
}
|
||||
|
||||
for entry in configs:
|
||||
engine = create_engine(entry.uri)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
for query_dict in entry.queries:
|
||||
query = query_dict.get("sql", "")
|
||||
explanation = query_dict.get("explanation", "")
|
||||
logger.info(f"Loading data from database with query: {query}")
|
||||
documents = loader.load_data(query=query, explanation=explanation)
|
||||
|
||||
docs.extend(documents)
|
||||
return docs
|
||||
@@ -0,0 +1,88 @@
|
||||
import os
|
||||
import logging
|
||||
from typing import Dict
|
||||
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.core.readers.json import JSONReader
|
||||
from llama_parse import LlamaParse
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileLoaderConfig(BaseModel):
|
||||
data_dir: str = "data"
|
||||
use_llama_parse: bool = False
|
||||
|
||||
@validator("data_dir")
|
||||
def data_dir_must_exist(cls, v):
|
||||
if not os.path.isdir(v):
|
||||
raise ValueError(f"Directory '{v}' does not exist")
|
||||
return v
|
||||
|
||||
|
||||
def llama_parse_parser():
|
||||
if os.getenv("LLAMA_CLOUD_API_KEY") is None:
|
||||
raise ValueError(
|
||||
"LLAMA_CLOUD_API_KEY environment variable is not set. "
|
||||
"Please set it in .env file or in your shell environment then run again!"
|
||||
)
|
||||
parser = LlamaParse(
|
||||
result_type="markdown",
|
||||
verbose=True,
|
||||
language="en",
|
||||
ignore_errors=False,
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
def llama_parse_extractor() -> Dict[str, LlamaParse]:
|
||||
from llama_parse.utils import SUPPORTED_FILE_TYPES
|
||||
|
||||
parser = llama_parse_parser()
|
||||
return {file_type: parser for file_type in SUPPORTED_FILE_TYPES}
|
||||
|
||||
def llama_local_extractor() -> Dict[str, BaseReader]:
|
||||
return {".json" : JSONReader(clean_json=False,levels_back=0)}
|
||||
|
||||
|
||||
def get_file_documents(config: FileLoaderConfig):
|
||||
from llama_index.core.readers import SimpleDirectoryReader
|
||||
|
||||
try:
|
||||
file_extractor = None
|
||||
if config.use_llama_parse:
|
||||
# LlamaParse is async first,
|
||||
# so we need to use nest_asyncio to run it in sync mode
|
||||
import nest_asyncio
|
||||
|
||||
nest_asyncio.apply()
|
||||
|
||||
file_extractor = llama_parse_extractor()
|
||||
else:
|
||||
file_extractor = llama_local_extractor()
|
||||
|
||||
reader = SimpleDirectoryReader(
|
||||
config.data_dir,
|
||||
recursive=True,
|
||||
filename_as_id=True,
|
||||
raise_on_error=True,
|
||||
file_extractor=file_extractor,
|
||||
)
|
||||
return reader.load_data()
|
||||
except Exception as e:
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
# Catch the error if the data dir is empty
|
||||
# and return as empty document list
|
||||
_, _, exc_traceback = sys.exc_info()
|
||||
function_name = traceback.extract_tb(exc_traceback)[-1].name
|
||||
if function_name == "_add_files":
|
||||
logger.warning(
|
||||
f"Failed to load file documents, error message: {e} . Return as empty document list."
|
||||
)
|
||||
return []
|
||||
else:
|
||||
# Raise the error if it is not the case of empty data dir
|
||||
raise e
|
||||
@@ -0,0 +1,37 @@
|
||||
import os
|
||||
import json
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CrawlUrl(BaseModel):
|
||||
base_url: str
|
||||
prefix: str
|
||||
max_depth: int = Field(default=1, ge=0)
|
||||
|
||||
|
||||
class WebLoaderConfig(BaseModel):
|
||||
driver_arguments: list[str] = Field(default=None)
|
||||
urls: list[CrawlUrl] = []
|
||||
|
||||
|
||||
def get_web_documents(config: WebLoaderConfig):
|
||||
from llama_index.readers.web import WholeSiteReader
|
||||
from selenium import webdriver
|
||||
from selenium.webdriver.chrome.options import Options
|
||||
|
||||
options = Options()
|
||||
driver_arguments = config.driver_arguments or []
|
||||
for arg in driver_arguments:
|
||||
options.add_argument(arg)
|
||||
|
||||
docs = []
|
||||
urls = config.urls or []
|
||||
for url in config.urls:
|
||||
scraper = WholeSiteReader(
|
||||
prefix=url.prefix,
|
||||
max_depth=url.max_depth,
|
||||
driver=webdriver.Chrome(options=options),
|
||||
)
|
||||
docs.extend(scraper.load_data(url.base_url))
|
||||
|
||||
return docs
|
||||
@@ -9,7 +9,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def load_configs():
|
||||
with open("config/loaders.yaml",'r', encoding='utf-8') as f:
|
||||
with open("config/loaders.yaml") as f:
|
||||
configs = yaml.safe_load(f)
|
||||
return configs
|
||||
|
||||
|
||||
@@ -2,14 +2,17 @@ import logging
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_index.core import SQLDatabase, Document
|
||||
from llama_index.core.objects import SQLTableSchema
|
||||
from llama_index.core.readers.base import BaseReader
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class CustomDatabaseReader(DatabaseReader):
|
||||
class CustomDatabaseReader(BaseReader):
|
||||
"""Simple Database reader.
|
||||
|
||||
Concatenates each row into Document used by LlamaIndex.
|
||||
@@ -73,30 +76,28 @@ class CustomDatabaseReader(DatabaseReader):
|
||||
"set of credentials."
|
||||
)
|
||||
|
||||
def load_data(self, query: str, explanation: str) -> List[Document]:
|
||||
def load_data(self, query: str) -> List[Document]:
|
||||
"""Query and load data from the Database, returning a list of Documents.
|
||||
|
||||
Args:
|
||||
query (str): Query parameter to filter tables and rows.
|
||||
explanation (str): Explanation for the query to be included in the document.
|
||||
|
||||
Returns:
|
||||
List[Document]: A list of Document objects.
|
||||
"""
|
||||
dco_str = explanation + "\n"
|
||||
|
||||
dco_str = ""
|
||||
with self.sql_database.engine.connect() as connection:
|
||||
if query is None:
|
||||
raise ValueError("A query parameter is necessary to filter the data")
|
||||
else:
|
||||
result = connection.execute(text(query))
|
||||
|
||||
dco_str += ", ".join(
|
||||
dco_str = ", ".join(
|
||||
[f"{entry}" for entry in result.keys()]
|
||||
) + "\n"
|
||||
)
|
||||
|
||||
for item in result.fetchall():
|
||||
# Fetch each item
|
||||
# fetch each item
|
||||
record_str = ", ".join(
|
||||
[f"{entry}" for col, entry in zip(result.keys(), item)]
|
||||
)
|
||||
@@ -110,7 +111,7 @@ class CustomDatabaseReader(DatabaseReader):
|
||||
|
||||
class DBLoaderConfig(BaseModel):
|
||||
uri: str
|
||||
queries: List[dict]
|
||||
queries: List[str]
|
||||
|
||||
def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
docs = []
|
||||
@@ -122,19 +123,33 @@ def get_db_documents(configs: list[DBLoaderConfig]):
|
||||
return docs
|
||||
|
||||
metadata = {
|
||||
'file_type': 'application/booway.document.zj',
|
||||
#'file_name':'',
|
||||
'file_type':'application/booway.document.zj',
|
||||
#'file_path':'',
|
||||
#'file_size':'',
|
||||
#'creation_date':'',
|
||||
#'last_modified_date':'',
|
||||
}
|
||||
|
||||
#from llama_index.readers.database import DatabaseReader
|
||||
for entry in configs:
|
||||
engine = create_engine(entry.uri)
|
||||
sql_database = SQLDatabase(engine)
|
||||
|
||||
# table_schema_objs = makeDescriptionByEngine(sql_database)
|
||||
# table_node_mapping = SQLTableNodeMapping(sql_database)
|
||||
#
|
||||
# nodes = table_node_mapping.to_nodes(table_schema_objs)
|
||||
# for node in nodes:
|
||||
# node.metadata.update(metadata)
|
||||
#
|
||||
# docs.extend(nodes)
|
||||
|
||||
queries = entry.queries or []
|
||||
loader = CustomDatabaseReader(sql_database)
|
||||
for query_dict in entry.queries:
|
||||
query = query_dict.get("sql", "")
|
||||
explanation = query_dict.get("explanation", "")
|
||||
for query in queries:
|
||||
logger.info(f"Loading data from database with query: {query}")
|
||||
documents = loader.load_data(query=query, explanation=explanation)
|
||||
documents = loader.load_data(query=query)
|
||||
|
||||
docs.extend(documents)
|
||||
return docs
|
||||
|
||||
Reference in New Issue
Block a user