109 lines
3.4 KiB
Python
109 lines
3.4 KiB
Python
import os
|
|
import uuid
|
|
import logging
|
|
import requests
|
|
from typing import Optional
|
|
from pydantic import BaseModel, Field
|
|
from llama_index.core.tools import FunctionTool
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Result object for one image-generation attempt: a required success flag
# plus an optional image URL and an optional error message (both default to
# None). Documented with comments rather than a class docstring so the
# model's generated schema is left untouched.
class ImageGeneratorToolOutput(BaseModel):
    # Required: the caller must always state whether generation succeeded.
    is_success: bool = Field(
        default=...,
        description="Whether the image generation was successful.",
    )
    # Optional: URL of the generated image.
    image_url: Optional[str] = Field(
        default=None, description="The URL of the generated image."
    )
    # Optional: error text describing a failed generation.
    error_message: Optional[str] = Field(
        default=None,
        description="The error message if the image generation failed.",
    )
|
|
|
|
|
|
class ImageGeneratorTool:
    """Generate an image from a text prompt and publish it under a file-server URL.

    The tool POSTs the prompt to the Stability AI endpoint in ``_IMG_GEN_API``,
    writes the returned bytes to ``_IMG_OUTPUT_DIR`` under a random UUID
    filename, and reports a URL built from ``FILESERVER_URL_PREFIX``.
    """

    _IMG_OUTPUT_FORMAT = "webp"
    _IMG_OUTPUT_DIR = "output/tool"
    _IMG_GEN_API = "https://api.stability.ai/v2beta/stable-image/generate/core"
    # (connect, read) timeout in seconds. Without a timeout ``requests`` waits
    # indefinitely, which would hang the calling tool on a stalled connection.
    _REQUEST_TIMEOUT = (10, 120)

    def __init__(self, api_key: Optional[str] = None):
        """Read the credentials and file-server prefix.

        Args:
            api_key: Stability API key. When missing or empty, the
                ``STABILITY_API_KEY`` environment variable is used instead.

        Raises:
            ValueError: If no non-empty API key is available, or if
                ``FILESERVER_URL_PREFIX`` is not set.
        """
        self._api_key = api_key or os.getenv("STABILITY_API_KEY")
        self.fileserver_url_prefix = os.getenv("FILESERVER_URL_PREFIX")
        # An empty key can never authenticate, so reject it here instead of
        # letting every later API call fail.
        if not self._api_key:
            raise ValueError(
                "STABILITY_API_KEY key is required to run image generator. Get it here: https://platform.stability.ai/account/keys"
            )
        # Only an unset prefix is rejected: an empty prefix still yields a
        # usable relative URL.
        if self.fileserver_url_prefix is None:
            raise ValueError("FILESERVER_URL_PREFIX is required.")

    def _prepare_output_dir(self) -> None:
        """Create the output directory if it doesn't exist."""
        # exist_ok=True already makes this a no-op for an existing directory,
        # so no separate existence check is needed.
        os.makedirs(self._IMG_OUTPUT_DIR, exist_ok=True)

    def _save_image(self, image_data: bytes) -> str:
        """Write the image bytes to a new file and return its public URL.

        Args:
            image_data: Raw image bytes to store.

        Returns:
            The file-server prefix, the output directory and the generated
            filename joined with ``/``.
        """
        self._prepare_output_dir()
        filename = f"{uuid.uuid4()}.{self._IMG_OUTPUT_FORMAT}"
        output_path = os.path.join(self._IMG_OUTPUT_DIR, filename)
        with open(output_path, "wb") as f:
            f.write(image_data)
        # Use the prefix captured in __init__ rather than re-reading the
        # environment, and drop trailing slashes so the join never yields "//".
        prefix = self.fileserver_url_prefix.rstrip("/")
        url = f"{prefix}/{self._IMG_OUTPUT_DIR}/{filename}"
        logger.info("Saved image to %s.\nURL: %s", output_path, url)
        return url

    def _call_stability_api(self, prompt: str):
        """POST the prompt to the image-generation endpoint.

        Args:
            prompt: Text prompt sent as the ``prompt`` form field.

        Returns:
            The ``requests`` response object; its ``content`` is what
            ``generate_image`` saves as the image.

        Raises:
            requests.HTTPError: If the API answers with an error status
                (raised by ``raise_for_status``).
            requests.Timeout: If the request exceeds ``_REQUEST_TIMEOUT``.
        """
        headers = {
            "authorization": f"Bearer {self._api_key}",
            "accept": "image/*",
        }
        data = {
            "prompt": prompt,
            "output_format": self._IMG_OUTPUT_FORMAT,
        }

        response = requests.post(
            self._IMG_GEN_API,
            headers=headers,
            # A dummy ``files`` entry makes requests encode the body as
            # multipart/form-data (presumably what the endpoint expects).
            files={"none": ""},
            data=data,
            timeout=self._REQUEST_TIMEOUT,
        )
        response.raise_for_status()

        return response

    # NOTE: get_tools wraps this method with FunctionTool.from_defaults, which
    # presumably exposes the docstring below as the tool description, so its
    # text is kept as is. The method never raises: any exception is logged and
    # reported through a failed ImageGeneratorToolOutput.
    def generate_image(self, prompt: str) -> ImageGeneratorToolOutput:
        """
        Use this tool to generate an image based on the prompt.
        Args:
            prompt (str): The prompt to generate the image from.
        """
        try:
            # Call the Stability API
            response = self._call_stability_api(prompt)

            # Save the image and get the URL
            image_url = self._save_image(response.content)

            return ImageGeneratorToolOutput(
                is_success=True,
                image_url=image_url,
            )
        except Exception as e:
            # logger.exception already attaches the active traceback.
            logger.exception(e)
            return ImageGeneratorToolOutput(
                is_success=False,
                error_message=str(e),
            )
|
|
|
|
|
|
def get_tools(**kwargs):
    """Return this module's tools as a one-element list.

    All keyword arguments are forwarded unchanged to ``ImageGeneratorTool``;
    its ``generate_image`` method is wrapped with ``FunctionTool.from_defaults``.
    """
    generator = ImageGeneratorTool(**kwargs)
    image_tool = FunctionTool.from_defaults(generator.generate_image)
    return [image_tool]
|