diff --git a/continuum/tasks/__init__.py b/continuum/tasks/__init__.py
index 73bef609..a9c558e0 100644
--- a/continuum/tasks/__init__.py
+++ b/continuum/tasks/__init__.py
@@ -2,6 +2,6 @@
 # flake8: noqa
 from continuum.tasks.task_set import TaskSet
 from continuum.tasks.base import TaskType
-from continuum.tasks.utils import split_train_val, concat
+from continuum.tasks.utils import split_train_val, concat, get_balanced_sampler
 
 __all__ = ["TaskSet", "TaskType"]
diff --git a/continuum/tasks/utils.py b/continuum/tasks/utils.py
index 00cbee68..d291560c 100644
--- a/continuum/tasks/utils.py
+++ b/continuum/tasks/utils.py
@@ -1,9 +1,39 @@
 from typing import Tuple, List
-
 import torch
 import numpy as np
-from continuum.tasks.base import BaseTaskSet
+from continuum.tasks.base import BaseTaskSet, TaskType
 from continuum.tasks.task_set import TaskSet
 
 
+def get_balanced_sampler(taskset, log=False):
+    """Create a sampler that will balance the dataset.
+
+    Give the returned sampler to the DataLoader through its `sampler` argument.
+
+    :param taskset: A PyTorch dataset that implements the TaskSet interface.
+    :param log: Use log-scaled weights. If enabled, some imbalance will remain,
+                but the oversampling/downsampling won't be as aggressive.
+    :return: A PyTorch sampler.
+    """
+    if taskset.data_type in (TaskType.SEGMENTATION, TaskType.OBJ_DETECTION, TaskType.TEXT):
+        raise NotImplementedError(
+            "Samplers are not yet available for the "
+            f"{taskset.data_type} type."
+        )
+
+    y = taskset.get_raw_samples()[1]
+    nb_per_class = np.bincount(y)
+    # Inverse-frequency weights: rare classes get proportionally larger weights.
+    weights_per_class = 1 / nb_per_class
+    if log:
+        # Log-scale then renormalize to soften the rebalancing.
+        weights_per_class = np.log(weights_per_class)
+        weights_per_class = 1 - (weights_per_class / np.sum(weights_per_class))
+
+    # Each sample inherits the weight of its class.
+    weights = weights_per_class[y]
+
+    return torch.utils.data.sampler.WeightedRandomSampler(weights, len(taskset))
+
+
 def split_train_val(dataset: BaseTaskSet, val_split: float = 0.1) -> Tuple[BaseTaskSet, BaseTaskSet]:
diff --git a/tests/test_taskset.py b/tests/test_taskset.py
index ffbcaf99..6ecf9725 100644
--- a/tests/test_taskset.py
+++ b/tests/test_taskset.py
@@ -1,8 +1,32 @@
 import numpy as np
 import pytest
-from continuum.datasets import InMemoryDataset
-from continuum.tasks import TaskSet, concat, split_train_val
 from torch.utils.data import DataLoader
+import torch
+
+from continuum.datasets import InMemoryDataset
+from continuum.tasks import TaskSet, concat, split_train_val, get_balanced_sampler
+
+
+@pytest.mark.parametrize("log", [False, True])
+def test_sampler_function(log):
+    np.random.seed(1)
+    torch.manual_seed(1)
+
+    x = np.random.rand(100, 2, 2, 3)
+    y = np.ones((100,), dtype=np.int64)
+    y[0] = 0  # A single class-0 sample among 99 class-1 samples.
+    t = np.ones((100,))
+
+    taskset = TaskSet(x, y, t, None)
+    sampler = get_balanced_sampler(taskset, log=log)
+
+    loader = DataLoader(taskset, sampler=sampler, batch_size=1)
+    nb_0 = 0
+    for _, batch_y, _ in loader:
+        if 0 in batch_y:
+            nb_0 += 1
+    # The balanced sampler should draw the lone class-0 sample repeatedly.
+    assert nb_0 > 1
 
 
 @pytest.mark.parametrize("nb_others", [1, 2])
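For context, a minimal usage sketch of the new helper, mirroring the data setup from the test above; the shapes and labels are illustrative and not part of the patch:

```python
import numpy as np
from torch.utils.data import DataLoader

from continuum.tasks import TaskSet, get_balanced_sampler

# Illustrative imbalanced task: one sample of class 0, 99 of class 1.
x = np.random.rand(100, 2, 2, 3)
y = np.ones((100,), dtype=np.int64)
y[0] = 0
t = np.ones((100,))

taskset = TaskSet(x, y, t, None)
sampler = get_balanced_sampler(taskset)

# The sampler goes in through DataLoader's `sampler` argument;
# note that `sampler` and `shuffle=True` are mutually exclusive.
loader = DataLoader(taskset, sampler=sampler, batch_size=32)
```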
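The `log` branch can also be checked numerically. With a hypothetical three-class histogram (not taken from the patch), plain inverse-frequency weights equalize the expected class draws exactly, while the log variant only flattens them:

```python
import numpy as np

# Hypothetical class histogram: 10, 100, and 1000 samples.
nb_per_class = np.array([10, 100, 1000])

# Inverse-frequency weights fully balance expected class draws:
# weight * count is identical for every class.
linear = 1 / nb_per_class              # [0.1, 0.01, 0.001]
print(linear * nb_per_class)           # [1. 1. 1.] -> perfectly balanced

# The log variant, as in get_balanced_sampler(..., log=True):
log_w = np.log(1 / nb_per_class)
log_w = 1 - log_w / log_w.sum()        # ~[0.833, 0.667, 0.5]
print(log_w * nb_per_class)            # ~[8.3, 66.7, 500.] -> still imbalanced,
                                       # but far less skewed than the raw 10:100:1000
```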