Files
zjdataai-app/backend/app/engine/tools/__init__.py
T
2024-08-30 10:49:05 +08:00

60 lines
2.3 KiB
Python

import importlib
import os
import yaml
from llama_index.core.tools.function_tool import FunctionTool
from llama_index.core.tools.tool_spec.base import BaseToolSpec
class ToolType:
    """String identifiers for the supported tool sources.

    These values are the keys of ``ToolFactory.TOOL_SOURCE_PACKAGE_MAP`` and,
    via ``ToolFactory.from_env``, the top-level keys of the tools YAML file.
    """

    # Tools imported from the ``llama_index.tools`` package namespace.
    LLAMAHUB: str = "llamahub"
    # Tools implemented in this application under ``app.engine.tools``.
    LOCAL: str = "local"
class ToolFactory:
    """Build llama_index ``FunctionTool`` objects from tool configuration.

    A tool is identified by a *tool type*, which selects the package to import
    from (see ``TOOL_SOURCE_PACKAGE_MAP``), and a *tool name* inside that
    package.
    """

    # Maps each tool type to the package its tool modules are imported from.
    TOOL_SOURCE_PACKAGE_MAP = {
        ToolType.LLAMAHUB: "llama_index.tools",
        ToolType.LOCAL: "app.engine.tools",
    }

    @staticmethod
    def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]:
        """Import and instantiate the tools named by ``tool_name``.

        Two naming forms are supported:

        * A name containing ``"ToolSpec"``, written ``"<module>.<ClassName>"``.
          ``<ClassName>`` is read from ``<source package>.<module>``. It is
          instantiated with ``config`` as keyword arguments and expanded with
          ``to_tool_list()``. ``<module>`` may itself be a dotted path; the
          class name is always the last component.
        * Any other name, ``"<module>"``. ``<source package>.<module>`` is
          imported and its ``get_tools(**config)`` result is returned.

        Args:
            tool_type: A key of ``TOOL_SOURCE_PACKAGE_MAP``.
            tool_name: Tool identifier in one of the forms above.
            config: Keyword arguments forwarded to the tool spec constructor
                or to ``get_tools``. ``None`` (what YAML yields for a key with
                no value) is treated as "no arguments".

        Returns:
            The loaded ``FunctionTool`` objects, as a list.

        Raises:
            KeyError: If ``tool_type`` is not in ``TOOL_SOURCE_PACKAGE_MAP``.
            ValueError: If the module cannot be imported, or the class or
                ``get_tools`` cannot be found. Also raised when a ToolSpec name
                has no ``"<module>.<ClassName>"`` form, or when ``get_tools``
                returns an item that is not a ``FunctionTool``.
        """
        try:
            source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type]
        except KeyError:
            known = ", ".join(ToolFactory.TOOL_SOURCE_PACKAGE_MAP)
            raise KeyError(
                f"Unknown tool type {tool_type!r}; expected one of: {known}"
            ) from None

        # ``name:`` with no value in YAML loads as None; ``**None`` would raise.
        kwargs = {} if config is None else config

        try:
            if "ToolSpec" in tool_name:
                # Split on the LAST dot so nested module paths keep working.
                tool_package, sep, tool_cls_name = tool_name.rpartition(".")
                if not sep or not tool_package or not tool_cls_name:
                    raise ValueError(
                        f"Invalid tool spec name {tool_name!r}: "
                        "expected '<module>.<ClassName>'"
                    )
                module_name = f"{source_package}.{tool_package}"
                module = importlib.import_module(module_name)
                tool_class = getattr(module, tool_cls_name)
                tool_spec: BaseToolSpec = tool_class(**kwargs)
                return tool_spec.to_tool_list()
            else:
                module = importlib.import_module(f"{source_package}.{tool_name}")
                # Materialize first: if get_tools returns a generator, the
                # validation pass below would otherwise exhaust it.
                tools = list(module.get_tools(**kwargs))
                if not all(isinstance(tool, FunctionTool) for tool in tools):
                    raise ValueError(
                        f"The module {module} does not contain valid tools"
                    )
                return tools
        except ImportError as e:
            raise ValueError(f"Failed to import tool {tool_name}: {e}") from e
        except AttributeError as e:
            raise ValueError(f"Failed to load tool {tool_name}: {e}") from e

    @staticmethod
    def from_env(config_path: str = "config/tools.yaml") -> list[FunctionTool]:
        """Load every tool declared in the YAML tool configuration file.

        The file maps each tool type to a mapping of tool name -> config,
        for example::

            local:
              my_tool: {}
            llamahub:
              wikipedia.WikipediaToolSpec:

        Args:
            config_path: Path of the YAML file. Relative paths resolve against
                the current working directory. Defaults to
                ``"config/tools.yaml"``.

        Returns:
            All tools returned by ``load_tools``, concatenated. The list is
            empty when the file does not exist, is empty, or has only empty
            sections.

        Raises:
            KeyError, ValueError: Propagated from ``load_tools``.
        """
        if not os.path.exists(config_path):
            return []

        # Read and close the file before loading tools; tool import and
        # instantiation may be slow and does not need the handle.
        with open(config_path, "r", encoding="UTF-8") as f:
            tool_configs = yaml.safe_load(f)

        # An empty YAML document loads as None.
        if not tool_configs:
            return []

        tools = []
        for tool_type, config_entries in tool_configs.items():
            # A tool type key with nothing beneath it loads as None; skip it.
            if not config_entries:
                continue
            for tool_name, config in config_entries.items():
                tools.extend(ToolFactory.load_tools(tool_type, tool_name, config))
        return tools