Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

To support io.Byteio #337

Closed
wants to merge 3 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 60 additions & 51 deletions keras_preprocessing/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import io
import os
import warnings
from pathlib import Path

import numpy as np

Expand Down Expand Up @@ -77,7 +78,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
"""Loads an image into PIL format.

# Arguments
path: Path to image file.
path: Path (string), pathlib.Path object, or io.BytesIO stream to image file.
grayscale: DEPRECATED use `color_mode="grayscale"`.
color_mode: The desired image format. One of "grayscale", "rgb", "rgba".
"grayscale" supports 8-bit images and 32-bit signed integer images.
Expand All @@ -101,6 +102,7 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
# Raises
ImportError: if PIL is not available.
ValueError: if interpolation method is not supported.
TypeError: type of 'path' should be path-like or io.Byteio.
"""
if grayscale is True:
warnings.warn('grayscale is deprecated. Please use '
Expand All @@ -109,56 +111,63 @@ def load_img(path, grayscale=False, color_mode='rgb', target_size=None,
if pil_image is None:
raise ImportError('Could not import PIL.Image. '
'The use of `load_img` requires PIL.')
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]

if keep_aspect_ratio:
width, height = img.size
target_width, target_height = width_height_tuple

crop_height = (width * target_height) // target_width
crop_width = (height * target_width) // target_height

# Set back to input height / width
# if crop_height / crop_width is not smaller.
crop_height = min(height, crop_height)
crop_width = min(width, crop_width)

crop_box_hstart = (height - crop_height) // 2
crop_box_wstart = (width - crop_width) // 2
crop_box_wend = crop_box_wstart + crop_width
crop_box_hend = crop_box_hstart + crop_height
crop_box = [crop_box_wstart, crop_box_hstart,
crop_box_wend, crop_box_hend]

img = img.resize(width_height_tuple, resample,
box=crop_box)
else:
img = img.resize(width_height_tuple, resample)
return img
if isinstance(path, io.BytesIO):
img = pil_image.open(path)
elif isinstance(path, (Path, bytes, str)):
if isinstance(path, Path):
path = str(path.resolve())
with open(path, 'rb') as f:
img = pil_image.open(io.BytesIO(f.read()))
else:
raise TypeError('path should be path-like or io.BytesIO'
', not {}'.format(type(path)))

if color_mode == 'grayscale':
# if image is not already an 8-bit, 16-bit or 32-bit grayscale image
# convert it to an 8-bit grayscale image.
if img.mode not in ('L', 'I;16', 'I'):
img = img.convert('L')
elif color_mode == 'rgba':
if img.mode != 'RGBA':
img = img.convert('RGBA')
elif color_mode == 'rgb':
if img.mode != 'RGB':
img = img.convert('RGB')
else:
raise ValueError('color_mode must be "grayscale", "rgb", or "rgba"')
if target_size is not None:
width_height_tuple = (target_size[1], target_size[0])
if img.size != width_height_tuple:
if interpolation not in _PIL_INTERPOLATION_METHODS:
raise ValueError(
'Invalid interpolation method {} specified. Supported '
'methods are {}'.format(
interpolation,
", ".join(_PIL_INTERPOLATION_METHODS.keys())))
resample = _PIL_INTERPOLATION_METHODS[interpolation]

if keep_aspect_ratio:
width, height = img.size
target_width, target_height = width_height_tuple

crop_height = (width * target_height) // target_width
crop_width = (height * target_width) // target_height

# Set back to input height / width
# if crop_height / crop_width is not smaller.
crop_height = min(height, crop_height)
crop_width = min(width, crop_width)

crop_box_hstart = (height - crop_height) // 2
crop_box_wstart = (width - crop_width) // 2
crop_box_wend = crop_box_wstart + crop_width
crop_box_hend = crop_box_hstart + crop_height
crop_box = [crop_box_wstart, crop_box_hstart,
crop_box_wend, crop_box_hend]
img = img.resize(width_height_tuple, resample, box=crop_box)
else:
img = img.resize(width_height_tuple, resample)
return img


def list_pictures(directory, ext=('jpg', 'jpeg', 'bmp', 'png', 'ppm', 'tif',
Expand Down