Skip to content

Commit

Permalink
Merge pull request #19 from ericmjl:imagebot
Browse files Browse the repository at this point in the history
feat: Added ImageBot to generate images
  • Loading branch information
ericmjl authored Nov 11, 2023
2 parents 0385994 + 16ee108 commit 4337998
Show file tree
Hide file tree
Showing 6 changed files with 229 additions and 3 deletions.
65 changes: 65 additions & 0 deletions docs/examples/imagebot.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ImageBot\n",
"\n",
"This notebook shows how to use the ImageBot API to generate images from text.\n",
"Underneath the hood, it uses the OpenAI API.\n",
"This bot can be combined with other bots (e.g. `SimpleBot`) to create rich content."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from llamabot.bot.imagebot import ImageBot\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"bot = ImageBot()\n",
"bot(\"A siamese cat.\")\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "llamabot",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 2 additions & 2 deletions llamabot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,10 @@
# Ensure that ~/.llamabotrc exists.
from pathlib import Path

from .bot import ChatBot, QueryBot, SimpleBot
from .bot import ChatBot, QueryBot, SimpleBot, ImageBot
from .recorder import PromptRecorder

__all__ = ["ChatBot", "SimpleBot", "QueryBot", "PromptRecorder"]
__all__ = ["ChatBot", "ImageBot", "SimpleBot", "QueryBot", "PromptRecorder"]


(Path.home() / ".llamabot").mkdir(parents=True, exist_ok=True)
3 changes: 2 additions & 1 deletion llamabot/bot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .chatbot import ChatBot
from .querybot import QueryBot
from .simplebot import SimpleBot
from .imagebot import ImageBot

pn.extension()
load_dotenv()
Expand All @@ -28,4 +29,4 @@
openai.api_key = api_key


__all__ = ["SimpleBot", "ChatBot", "QueryBot"]
__all__ = ["SimpleBot", "ChatBot", "QueryBot", "ImageBot"]
77 changes: 77 additions & 0 deletions llamabot/bot/imagebot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""ImageBot module for generating images."""

from openai import OpenAI
from IPython.display import display, Image
import requests
from pathlib import Path
from typing import Union


class ImageBot:
"""ImageBot class for generating images.
:param model: The model to use. Defaults to "dall-e-3".
:param size: The size of the image to generate. Defaults to "1024x1024".
:param quality: The quality of the image to generate. Defaults to "standard".
:param n: The number of images to generate. Defaults to 1.
"""

def __init__(self, model="dall-e-3", size="1024x1024", quality="hd", n=1):
self.client = OpenAI()
self.model = model
self.size = size
self.quality = quality
self.n = n

def __call__(self, prompt: str, save_path: Path = None) -> Union[str, Path]:
"""Generate an image from a prompt.
:param prompt: The prompt to generate an image from.
:param save_path: The path to save the generated image to.
:return: The URL of the generated image if running in a Jupyter notebook (str),
otherwise a pathlib.Path object pointing to the generated image.
"""
response = self.client.images.generate(
model=self.model,
prompt=prompt,
size=self.size,
quality=self.quality,
n=self.n,
)
image_url = response.data[0].url

# Check if running in a Jupyter notebook
if is_running_in_jupyter():
display(Image(url=image_url))
return image_url

image_data = requests.get(image_url).content

from llamabot import SimpleBot

bot = SimpleBot(
"You are a helpful filenaming assistant. "
"Filenames should use underscores instead of spaces, "
"and should be all lowercase. "
"Exclude the file extension. "
"Give me a compact filename for the following prompt:"
)
response = bot(prompt)
filename = response.message
if not save_path:
save_path = Path(f"{filename}.jpg")
with open(save_path, "wb") as file:
file.write(image_data)
return save_path


def is_running_in_jupyter() -> bool:
"""Check if running in a Jupyter notebook.
:return: True if running in a Jupyter notebook, otherwise False.
"""
try:
get_ipython()
return True
except NameError:
return False
82 changes: 82 additions & 0 deletions tests/bot/test_imagebot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""Test the ImageBot class."""
from llamabot import ImageBot, SimpleBot
import requests


def test_initialization_defaults():
"""Test the initialization of the ImageBot class with default parameters."""
bot = ImageBot()
assert bot.model == "dall-e-3"
assert bot.size == "1024x1024"
assert bot.quality == "hd"
assert bot.n == 1


def test_initialization_custom():
"""Test the initialization of the ImageBot class with custom parameters."""
bot = ImageBot(model="custom-model", size="800x800", quality="standard", n=2)
assert bot.model == "custom-model"
assert bot.size == "800x800"
assert bot.quality == "standard"
assert bot.n == 2


def test_call_in_jupyter(mocker):
"""Test the call method when running in a Jupyter notebook.
This test tests that the call method returns the URL of the generated image
when running in a Jupyter notebook and displays the image.
:param mocker: The pytest-mock fixture.
"""
mocker.patch("llamabot.bot.imagebot.is_running_in_jupyter", return_value=True)
mock_display = mocker.patch("llamabot.bot.imagebot.display")

bot = ImageBot()

# Mock the client's generate method on the instance to return the desired URL
mock_response = mocker.MagicMock()
mock_response.data = [mocker.MagicMock(url="http://image.url")]
bot.client = mocker.MagicMock()
bot.client.images.generate.return_value = mock_response

result = bot("test prompt")
assert result == "http://image.url"
mock_display.assert_called_once()


def test_call_outside_jupyter(mocker, tmp_path):
"""Test the call method when not running in a Jupyter notebook.
This test tests that the call method returns the path to the generated image
when not running in a Jupyter notebook.
:param mocker: The pytest-mock fixture.
:param tmp_path: The pytest tmp_path fixture.
"""
# Mock the is_running_in_jupyter method
mocker.patch("llamabot.bot.imagebot.is_running_in_jupyter", return_value=False)

# Instantiate ImageBot
bot = ImageBot()

# Mock the client's generate method on the instance to return the desired URL
mock_response = mocker.MagicMock()
mock_response.data = [mocker.MagicMock(url="http://image.url")]
bot.client = mocker.MagicMock()
bot.client.images.generate.return_value = mock_response

# Mock requests.get to return a mock response with content
mock_get_response = mocker.MagicMock()
mock_get_response.content = b"image_data"
mocker.patch("requests.get", return_value=mock_get_response)

# Mock the SimpleBot's __call__ method
mocker.patch.object(
SimpleBot, "__call__", return_value=mocker.MagicMock(message="test_prompt")
)

# Call the method and perform the assertion
result = bot("test prompt", tmp_path / "test_prompt.jpg")
assert result == tmp_path / "test_prompt.jpg"
requests.get.assert_called_with("http://image.url")
1 change: 1 addition & 0 deletions tests/bot/test_simplebot.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
temperature=st.floats(min_value=0, max_value=1),
model_name=st.text(),
)
@settings(deadline=None)
def test_simple_bot_init(system_prompt, temperature, model_name):
"""Test that the SimpleBot is initialized correctly.
Expand Down

0 comments on commit 4337998

Please sign in to comment.