Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

To support io.Bytesio #339

Merged
merged 7 commits into from
Feb 4, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 62 additions & 51 deletions keras_preprocessing/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import os
import warnings
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -77,7 +78,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
"""Loads an image into PIL format.

# Arguments
path: Path to image file.
path: Path (string), pathlib.Path object, or io.BytesIO stream to image file.
grayscale: DEPRECATED use `color_mode="grayscale"`.
color_mode: The desired image format. One of "grayscale", "rgb", "rgba".
"grayscale" supports 8-bit images and 32-bit signed integer images.
Expand All @@ -101,6 +102,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
# Raises
ImportError: if PIL is not available.
ValueError: if interpolation method is not supported.
TypeError: type of 'path' should be path-like or io.Byteio.
"""
if grayscale is True:
warnings.warn('grayscale is deprecated. Please use '
Expand All @@ -109,56 +111,65 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
'The use of `load_img` requires PIL.')
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]

if keep_aspect_ratio:
width, height = img.size
target_width, target_height = width_height_tuple

crop_height = (width * target_height) // target_width
crop_width = (height * target_width) // target_height

# Set back to input height / width
# if crop_height / crop_width is not smaller.
crop_height = min(height, crop_height)
crop_width = min(width, crop_width)

crop_box_hstart = (height - crop_height) // 2
crop_box_wstart = (width - crop_width) // 2
crop_box_wend = crop_box_wstart + crop_width
crop_box_hend = crop_box_hstart + crop_height
crop_box = [crop_box_wstart, crop_box_hstart,
crop_box_wend, crop_box_hend]

img = img.resize(width_height_tuple, resample,
box=crop_box)
else:
img = img.resize(width_height_tuple, resample)
return img
if isinstance(path, io.BytesIO):
img = pil_image.open(path)
elif isinstance(path, (Path, bytes, str)):
if isinstance(path, Path):
path = str(path.resolve())
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
else:
raise TypeError('path should be path-like or io.BytesIO'
', not {}'.format(type(path)))

if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]

if keep_aspect_ratio:
width, height = img.size
target_width, target_height = width_height_tuple

crop_height = (width * target_height) // target_width
crop_width = (height * target_width) // target_height

# Set back to input height / width
# if crop_height / crop_width is not smaller.
crop_height = min(height, crop_height)
crop_width = min(width, crop_width)

crop_box_hstart = (height - crop_height) // 2
crop_box_wstart = (width - crop_width) // 2
crop_box_wend = crop_box_wstart + crop_width
crop_box_hend = crop_box_hstart + crop_height
crop_box = [
crop_box_wstart, crop_box_hstart, crop_box_wend,
crop_box_hend
]
img = img.resize(width_height_tuple, resample, box=crop_box)
else:
img = img.resize(width_height_tuple, resample)
return img


def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif',
Expand Down
24 changes: 24 additions & 0 deletions tests/image/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import resource
from pathlib import Path

import numpy as np
import PIL
Expand Down Expand Up @@ -193,6 +195,28 @@ def test_load_img(tmpdir):
loaded_im_array = utils.img_to_array(loaded_im, dtype='int32')
assert loaded_im_array.shape == (25, 25, 1)

# Test different path type
with open(filename_grayscale_32bit, 'rb') as f:
_path = io.BytesIO(f.read()) # io.Bytesio
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

_path = filename_grayscale_32bit # str
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

_path = filename_grayscale_32bit.encode() # bytes
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

_path = Path(tmpdir / 'grayscale_32bit_utils.tiff') # Path
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

# Check that exception is raised if interpolation not supported.

loaded_im = utils.load_img(filename_rgb, interpolation="unsupported")
Expand Down