Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Commit

Permalink
To support io.Bytesio (#339)
Browse files Browse the repository at this point in the history
* To support io.Bytesio

* fix doc error

* fix format

* add unit test

* fix assert condition

* fix np ndarray compare in an error way

* fix type error, cast the LocalPath to pathlib.Path
  • Loading branch information
TreeKat71 authored Feb 4, 2021
1 parent 2310b4e commit 6701f27
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 51 deletions.
113 changes: 62 additions & 51 deletions keras_preprocessing/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import os
import warnings
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -77,7 +78,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
"""Loads an image into PIL format.
# Arguments
path: Path to image file.
path: Path (string), pathlib.Path object, or io.BytesIO stream to image file.
grayscale: DEPRECATED use `color_mode="grayscale"`.
color_mode: The desired image format. One of "grayscale", "rgb", "rgba".
"grayscale" supports 8-bit images and 32-bit signed integer images.
Expand All @@ -101,6 +102,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
# Raises
ImportError: if PIL is not available.
ValueError: if interpolation method is not supported.
TypeError: type of 'path' should be path-like or io.Byteio.
"""
if grayscale is True:
warnings.warn('grayscale is deprecated. Please use '
Expand All @@ -109,56 +111,65 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
'The use of `load_img` requires PIL.')
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]

if keep_aspect_ratio:
width, height = img.size
target_width, target_height = width_height_tuple

crop_height = (width * target_height) // target_width
crop_width = (height * target_width) // target_height

# Set back to input height / width
# if crop_height / crop_width is not smaller.
crop_height = min(height, crop_height)
crop_width = min(width, crop_width)

crop_box_hstart = (height - crop_height) // 2
crop_box_wstart = (width - crop_width) // 2
crop_box_wend = crop_box_wstart + crop_width
crop_box_hend = crop_box_hstart + crop_height
crop_box = [crop_box_wstart, crop_box_hstart,
crop_box_wend, crop_box_hend]

img = img.resize(width_height_tuple, resample,
box=crop_box)
else:
img = img.resize(width_height_tuple, resample)
return img
if isinstance(path, io.BytesIO):
img = pil_image.open(path)
elif isinstance(path, (Path, bytes, str)):
if isinstance(path, Path):
path = str(path.resolve())
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
else:
raise TypeError('path should be path-like or io.BytesIO'
', not {}'.format(type(path)))

if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]

if keep_aspect_ratio:
width, height = img.size
target_width, target_height = width_height_tuple

crop_height = (width * target_height) // target_width
crop_width = (height * target_width) // target_height

# Set back to input height / width
# if crop_height / crop_width is not smaller.
crop_height = min(height, crop_height)
crop_width = min(width, crop_width)

crop_box_hstart = (height - crop_height) // 2
crop_box_wstart = (width - crop_width) // 2
crop_box_wend = crop_box_wstart + crop_width
crop_box_hend = crop_box_hstart + crop_height
crop_box = [
crop_box_wstart, crop_box_hstart, crop_box_wend,
crop_box_hend
]
img = img.resize(width_height_tuple, resample, box=crop_box)
else:
img = img.resize(width_height_tuple, resample)
return img


def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif',
Expand Down
24 changes: 24 additions & 0 deletions tests/image/utils_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import io
import resource
from pathlib import Path

import numpy as np
import PIL
Expand Down Expand Up @@ -193,6 +195,28 @@ def test_load_img(tmpdir):
loaded_im_array = utils.img_to_array(loaded_im, dtype='int32')
assert loaded_im_array.shape == (25, 25, 1)

# Test different path type
with open(filename_grayscale_32bit, 'rb') as f:
_path = io.BytesIO(f.read()) # io.Bytesio
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

_path = filename_grayscale_32bit # str
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

_path = filename_grayscale_32bit.encode() # bytes
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

_path = Path(tmpdir / 'grayscale_32bit_utils.tiff') # Path
loaded_im = utils.load_img(_path, color_mode='grayscale')
loaded_im_array = utils.img_to_array(loaded_im, dtype=np.int32)
assert np.all(loaded_im_array == original_grayscale_32bit_array)

# Check that exception is raised if interpolation not supported.

loaded_im = utils.load_img(filename_rgb, interpolation="unsupported")
Expand Down

0 comments on commit 6701f27

Please sign in to comment.