-
-
Notifications
You must be signed in to change notification settings - Fork 26
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #19 from ericmjl:imagebot
feat: Added ImageBot to generate images
- Loading branch information
Showing
6 changed files
with
229 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext autoreload\n", | ||
"%autoreload 2\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# ImageBot\n", | ||
"\n", | ||
"This notebook shows how to use the ImageBot API to generate images from text.\n", | ||
"Underneath the hood, it uses the OpenAI API.\n", | ||
"This bot can be combined with other bots (e.g. `SimpleBot`) to create rich content." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from llamabot.bot.imagebot import ImageBot\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"bot = ImageBot()\n", | ||
"bot(\"A siamese cat.\")\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "llamabot", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
"""ImageBot module for generating images.""" | ||
|
||
from openai import OpenAI | ||
from IPython.display import display, Image | ||
import requests | ||
from pathlib import Path | ||
from typing import Union | ||
|
||
|
||
class ImageBot: | ||
"""ImageBot class for generating images. | ||
:param model: The model to use. Defaults to "dall-e-3". | ||
:param size: The size of the image to generate. Defaults to "1024x1024". | ||
:param quality: The quality of the image to generate. Defaults to "standard". | ||
:param n: The number of images to generate. Defaults to 1. | ||
""" | ||
|
||
def __init__(self, model="dall-e-3", size="1024x1024", quality="hd", n=1): | ||
self.client = OpenAI() | ||
self.model = model | ||
self.size = size | ||
self.quality = quality | ||
self.n = n | ||
|
||
def __call__(self, prompt: str, save_path: Path = None) -> Union[str, Path]: | ||
"""Generate an image from a prompt. | ||
:param prompt: The prompt to generate an image from. | ||
:param save_path: The path to save the generated image to. | ||
:return: The URL of the generated image if running in a Jupyter notebook (str), | ||
otherwise a pathlib.Path object pointing to the generated image. | ||
""" | ||
response = self.client.images.generate( | ||
model=self.model, | ||
prompt=prompt, | ||
size=self.size, | ||
quality=self.quality, | ||
n=self.n, | ||
) | ||
image_url = response.data[0].url | ||
|
||
# Check if running in a Jupyter notebook | ||
if is_running_in_jupyter(): | ||
display(Image(url=image_url)) | ||
return image_url | ||
|
||
image_data = requests.get(image_url).content | ||
|
||
from llamabot import SimpleBot | ||
|
||
bot = SimpleBot( | ||
"You are a helpful filenaming assistant. " | ||
"Filenames should use underscores instead of spaces, " | ||
"and should be all lowercase. " | ||
"Exclude the file extension. " | ||
"Give me a compact filename for the following prompt:" | ||
) | ||
response = bot(prompt) | ||
filename = response.message | ||
if not save_path: | ||
save_path = Path(f"{filename}.jpg") | ||
with open(save_path, "wb") as file: | ||
file.write(image_data) | ||
return save_path | ||
|
||
|
||
def is_running_in_jupyter() -> bool: | ||
"""Check if running in a Jupyter notebook. | ||
:return: True if running in a Jupyter notebook, otherwise False. | ||
""" | ||
try: | ||
get_ipython() | ||
return True | ||
except NameError: | ||
return False |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Test the ImageBot class.""" | ||
from llamabot import ImageBot, SimpleBot | ||
import requests | ||
|
||
|
||
def test_initialization_defaults(): | ||
"""Test the initialization of the ImageBot class with default parameters.""" | ||
bot = ImageBot() | ||
assert bot.model == "dall-e-3" | ||
assert bot.size == "1024x1024" | ||
assert bot.quality == "hd" | ||
assert bot.n == 1 | ||
|
||
|
||
def test_initialization_custom(): | ||
"""Test the initialization of the ImageBot class with custom parameters.""" | ||
bot = ImageBot(model="custom-model", size="800x800", quality="standard", n=2) | ||
assert bot.model == "custom-model" | ||
assert bot.size == "800x800" | ||
assert bot.quality == "standard" | ||
assert bot.n == 2 | ||
|
||
|
||
def test_call_in_jupyter(mocker): | ||
"""Test the call method when running in a Jupyter notebook. | ||
This test tests that the call method returns the URL of the generated image | ||
when running in a Jupyter notebook and displays the image. | ||
:param mocker: The pytest-mock fixture. | ||
""" | ||
mocker.patch("llamabot.bot.imagebot.is_running_in_jupyter", return_value=True) | ||
mock_display = mocker.patch("llamabot.bot.imagebot.display") | ||
|
||
bot = ImageBot() | ||
|
||
# Mock the client's generate method on the instance to return the desired URL | ||
mock_response = mocker.MagicMock() | ||
mock_response.data = [mocker.MagicMock(url="http://image.url")] | ||
bot.client = mocker.MagicMock() | ||
bot.client.images.generate.return_value = mock_response | ||
|
||
result = bot("test prompt") | ||
assert result == "http://image.url" | ||
mock_display.assert_called_once() | ||
|
||
|
||
def test_call_outside_jupyter(mocker, tmp_path): | ||
"""Test the call method when not running in a Jupyter notebook. | ||
This test tests that the call method returns the path to the generated image | ||
when not running in a Jupyter notebook. | ||
:param mocker: The pytest-mock fixture. | ||
:param tmp_path: The pytest tmp_path fixture. | ||
""" | ||
# Mock the is_running_in_jupyter method | ||
mocker.patch("llamabot.bot.imagebot.is_running_in_jupyter", return_value=False) | ||
|
||
# Instantiate ImageBot | ||
bot = ImageBot() | ||
|
||
# Mock the client's generate method on the instance to return the desired URL | ||
mock_response = mocker.MagicMock() | ||
mock_response.data = [mocker.MagicMock(url="http://image.url")] | ||
bot.client = mocker.MagicMock() | ||
bot.client.images.generate.return_value = mock_response | ||
|
||
# Mock requests.get to return a mock response with content | ||
mock_get_response = mocker.MagicMock() | ||
mock_get_response.content = b"image_data" | ||
mocker.patch("requests.get", return_value=mock_get_response) | ||
|
||
# Mock the SimpleBot's __call__ method | ||
mocker.patch.object( | ||
SimpleBot, "__call__", return_value=mocker.MagicMock(message="test_prompt") | ||
) | ||
|
||
# Call the method and perform the assertion | ||
result = bot("test prompt", tmp_path / "test_prompt.jpg") | ||
assert result == tmp_path / "test_prompt.jpg" | ||
requests.get.assert_called_with("http://image.url") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters