diff --git a/keras_preprocessing/image/directory_iterator.py b/keras_preprocessing/image/directory_iterator.py index 3a829b4f..11fef919 100644 --- a/keras_preprocessing/image/directory_iterator.py +++ b/keras_preprocessing/image/directory_iterator.py @@ -11,7 +11,8 @@ import numpy as np from .iterator import BatchFromFilesMixin, Iterator -from .utils import _list_valid_filenames_in_directory +from .utils import (_list_valid_filenames_in_directory, + _make_balance_config) class DirectoryIterator(BatchFromFilesMixin, Iterator): @@ -25,6 +26,7 @@ class DirectoryIterator(BatchFromFilesMixin, Iterator): via the `classes` argument. image_data_generator: Instance of `ImageDataGenerator` to use for random transformations and normalization. + balance: Boolean, will handle data imbalance if set to True. target_size: tuple of integers, dimensions to resize input images to. color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`. Color mode to read images. @@ -76,6 +78,7 @@ def __new__(cls, *args, **kwargs): def __init__(self, directory, image_data_generator, + balance=False, target_size=(256, 256), color_mode='rgb', classes=None, @@ -101,6 +104,14 @@ def __init__(self, subset, interpolation) self.directory = directory + + if balance and subset != 'validation': + validation_split = image_data_generator._validation_split + self._balance_config = _make_balance_config(directory, + validation_split) + else: + self._balance_config = None + self.classes = classes if class_mode not in self.allowed_class_modes: raise ValueError('Invalid class_mode: {}; expected one of: {}' @@ -129,7 +140,8 @@ def __init__(self, results.append( pool.apply_async(_list_valid_filenames_in_directory, (dirpath, self.white_list_formats, self.split, - self.class_indices, follow_links))) + self.class_indices, follow_links, + self._balance_config))) classes_list = [] for res in results: classes, filenames = res.get() diff --git a/keras_preprocessing/image/image_data_generator.py b/keras_preprocessing/image/image_data_generator.py index 77d88147..a978166e 100644 --- a/keras_preprocessing/image/image_data_generator.py +++ b/keras_preprocessing/image/image_data_generator.py @@ -441,6 +441,7 @@ def flow(self, def flow_from_directory(self, directory, + balance=False, target_size=(256, 256), color_mode='rgb', classes=None, @@ -465,6 +466,7 @@ def flow_from_directory(self, See [this script]( https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d) for more details. + balance: Boolean, handles data imbalance if True target_size: Tuple of integers `(height, width)`, default: `(256, 256)`. The dimensions to which all images found will be resized. @@ -531,6 +533,7 @@ class subdirectories (default: False). return DirectoryIterator( directory, self, + balance=balance, target_size=target_size, color_mode=color_mode, classes=classes, diff --git a/keras_preprocessing/image/utils.py b/keras_preprocessing/image/utils.py index bc3e6886..30f67426 100644 --- a/keras_preprocessing/image/utils.py +++ b/keras_preprocessing/image/utils.py @@ -7,6 +7,7 @@ import io import os import warnings +import random import numpy as np @@ -182,8 +183,24 @@ def _recursive_list(subpath): yield root, fname +def _settle_debt(list_valid_files, debt): + """Iterates over list_valid_files and resamples to settle debt. + + # Arguments: + list_valid_files: List of strings, list that contains valid filenames + debt: Integer, required number of samples to be resampled from + valid_file_names + + # Yields: + randomly chosen filename from list_valid_files + """ + for i in range(debt): + yield random.choice(list_valid_files) + + def _list_valid_filenames_in_directory(directory, white_list_formats, split, - class_indices, follow_links): + class_indices, follow_links, + balance_config=None): """Lists paths of files in `subdir` with extensions in `white_list_formats`. # Arguments @@ -198,6 +215,7 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, split, of images in each directory. class_indices: dictionary mapping a class name to its index. follow_links: boolean, follow symbolic links to subdirectories. + balance_config: dict, stores configurations for handling data imbalance. # Returns classes: a list of class indices @@ -225,9 +243,61 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, split, dirname, os.path.relpath(absolute_path, directory)) filenames.append(relative_path) + if balance_config: + filenames_copy = filenames.copy() + + debt = balance_config['majority'] - len(filenames_copy) + + for filename in _settle_debt(filenames_copy, debt): + classes.append(class_indices[dirname]) + filenames.append(filename) + return classes, filenames +def _generate_class_count(directory): + """Maintain sample count of each class in the directory. + + # Arguments + directory: string, absolute path to the directory + # Returns + class_count: dictionary, sample count for each class + """ + + class_count = {} + + for category in os.listdir(directory): + category_directory = os.path.join(directory, category) + class_count[category] = len(os.listdir(category_directory)) + + return class_count + + +def _make_balance_config(directory, validation_split): + """Scans the directory to make a config dictionary to handle data imbalance. + + # Arguments + directory: string, absolute path to the directory + validation_split: float, validation split + Default: None + # Returns + balance_config: dictionary, specs needed to handle data imbalance + 'majority': integer, number of samples in the majority class + """ + class_count = _generate_class_count(directory) + + # Get the sample count of the majority class + majority_class_count = class_count[max(class_count, key=class_count.get)] + if validation_split: + majority_class_count = int(majority_class_count*(1 - validation_split)) + 1 + + balance_config = { + 'majority': majority_class_count + } + + return balance_config + + def array_to_img(x, data_format='channels_last', scale=True, dtype='float32'): """Converts a 3D Numpy array to a PIL Image instance.