Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

Add balance in flow_from_directory to handle data imbalance (using random oversampling) #310

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
16 changes: 14 additions & 2 deletions keras_preprocessing/image/directory_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@
import numpy as np

from .iterator import BatchFromFilesMixin, Iterator
from .utils import _list_valid_filenames_in_directory
from .utils import (_list_valid_filenames_in_directory,
_make_balance_config)


class DirectoryIterator(BatchFromFilesMixin, Iterator):
Expand All @@ -25,6 +26,7 @@ class DirectoryIterator(BatchFromFilesMixin, Iterator):
via the `classes` argument.
image_data_generator: Instance of `ImageDataGenerator`
to use for random transformations and normalization.
balance: Boolean, will handle data imbalance if set to True.
target_size: tuple of integers, dimensions to resize input images to.
color_mode: One of `"rgb"`, `"rgba"`, `"grayscale"`.
Color mode to read images.
Expand Down Expand Up @@ -76,6 +78,7 @@ def __new__(cls, *args, **kwargs):
def __init__(self,
directory,
image_data_generator,
balance=False,
target_size=(256, 256),
color_mode='rgb',
classes=None,
Expand All @@ -101,6 +104,14 @@ def __init__(self,
subset,
interpolation)
self.directory = directory

if balance and subset != 'validation':
validation_split = image_data_generator._validation_split
self._balance_config = _make_balance_config(directory,
validation_split)
else:
self._balance_config = None

self.classes = classes
if class_mode not in self.allowed_class_modes:
raise ValueError('Invalid class_mode: {}; expected one of: {}'
Expand Down Expand Up @@ -129,7 +140,8 @@ def __init__(self,
results.append(
pool.apply_async(_list_valid_filenames_in_directory,
(dirpath, self.white_list_formats, self.split,
self.class_indices, follow_links)))
self.class_indices, follow_links,
self._balance_config)))
classes_list = []
for res in results:
classes, filenames = res.get()
Expand Down
3 changes: 3 additions & 0 deletions keras_preprocessing/image/image_data_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,6 +441,7 @@ def flow(self,

def flow_from_directory(self,
directory,
balance=False,
target_size=(256, 256),
color_mode='rgb',
classes=None,
Expand All @@ -465,6 +466,7 @@ def flow_from_directory(self,
See [this script](
https://gist.github.com/fchollet/0830affa1f7f19fd47b06d4cf89ed44d)
for more details.
balance: Boolean, handles data imbalance if True
target_size: Tuple of integers `(height, width)`,
default: `(256, 256)`.
The dimensions to which all images found will be resized.
Expand Down Expand Up @@ -531,6 +533,7 @@ class subdirectories (default: False).
return DirectoryIterator(
directory,
self,
balance=balance,
target_size=target_size,
color_mode=color_mode,
classes=classes,
Expand Down
72 changes: 71 additions & 1 deletion keras_preprocessing/image/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io
import os
import warnings
import random

import numpy as np

Expand Down Expand Up @@ -182,8 +183,24 @@ def _recursive_list(subpath):
yield root, fname


def _settle_debt(list_valid_files, debt):
"""Iterates over list_valid_files and resamples to settle debt.

# Arguments:
list_valid_files: List of strings, list that contains valid filenames
debt: Integer, required number of samples to be resampled from
valid_file_names

# Yields:
randomly chosen filename from list_valid_files
"""
for i in range(debt):
yield random.choice(list_valid_files)


def _list_valid_filenames_in_directory(directory, white_list_formats, split,
class_indices, follow_links):
class_indices, follow_links,
balance_config=None):
"""Lists paths of files in `subdir` with extensions in `white_list_formats`.

# Arguments
Expand All @@ -198,6 +215,7 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, split,
of images in each directory.
class_indices: dictionary mapping a class name to its index.
follow_links: boolean, follow symbolic links to subdirectories.
balance_config: dict, stores configurations for handling data imbalance.

# Returns
classes: a list of class indices
Expand Down Expand Up @@ -225,9 +243,61 @@ def _list_valid_filenames_in_directory(directory, white_list_formats, split,
dirname, os.path.relpath(absolute_path, directory))
filenames.append(relative_path)

if balance_config:
filenames_copy = filenames.copy()

debt = balance_config['majority'] - len(filenames_copy)

for filename in _settle_debt(filenames_copy, debt):
classes.append(class_indices[dirname])
filenames.append(filename)

return classes, filenames


def _generate_class_count(directory):
"""Maintain sample count of each class in the directory.

# Arguments
directory: string, absolute path to the directory
# Returns
class_count: dictionary, sample count for each class
"""

class_count = {}

for category in os.listdir(directory):
category_directory = os.path.join(directory, category)
class_count[category] = len(os.listdir(category_directory))

return class_count


def _make_balance_config(directory, validation_split):
"""Scans the directory to make a config dictionary to handle data imbalance.

# Arguments
directory: string, absolute path to the directory
validation_split: float, validation split
Default: None
# Returns
balance_config: dictionary, specs needed to handle data imbalance
'majority': integer, number of samples in the majority class
"""
class_count = _generate_class_count(directory)

# Get the sample count of the majority class
majority_class_count = class_count[max(class_count, key=class_count.get)]
if validation_split:
majority_class_count = int(majority_class_count*(1 - validation_split)) + 1

balance_config = {
'majority': majority_class_count
}

return balance_config


def array_to_img(x, data_format='channels_last', scale=True, dtype='float32'):
"""Converts a 3D Numpy array to a PIL Image instance.

Expand Down