Skip to content

Commit

Permalink
Avoid In-Memory Dataset Copy for Avalanche (#463)
Browse files Browse the repository at this point in the history
  • Loading branch information
wistuba authored Nov 13, 2023
1 parent d7b23af commit 645c219
Show file tree
Hide file tree
Showing 4 changed files with 66 additions and 27 deletions.
4 changes: 2 additions & 2 deletions src/renate/updaters/avalanche/model_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def _load_benchmark_if_exists(
train_dataset_collate_fn: Optional[Callable] = None,
val_dataset_collate_fn: Optional[Callable] = None,
) -> AvalancheBenchmarkWrapper:
train_dataset = to_avalanche_dataset(train_dataset, train_dataset_collate_fn)
avalanche_train_dataset = to_avalanche_dataset(train_dataset, train_dataset_collate_fn)

avalanche_state = None
if self._input_state_folder is not None:
Expand All @@ -232,7 +232,7 @@ def _load_benchmark_if_exists(
val_memory_dataset = to_avalanche_dataset(train_dataset, val_dataset_collate_fn)

benchmark = AvalancheBenchmarkWrapper(
train_dataset=train_dataset,
train_dataset=avalanche_train_dataset,
val_dataset=val_memory_dataset,
train_transform=self._train_transform,
train_target_transform=self._train_target_transform,
Expand Down
58 changes: 44 additions & 14 deletions src/renate/utils/avalanche.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@

from renate.data.datasets import _TransformedDataset
from renate.memory import DataBuffer
from renate.types import NestedTensors


class AvalancheDataset(Dataset):
"""A Dataset consumable by Avalanche updaters."""
class BaseAvalancheDataset(Dataset):
"""Base class for all datasets consumable by Avalanche updaters."""

def __init__(
self, inputs: NestedTensors, targets: List[int], collate_fn: Optional[Callable] = None
self,
targets: List[int],
collate_fn: Optional[Callable] = None,
):
self._inputs = inputs
self._targets = targets
self.targets = torch.tensor(targets, dtype=torch.long)
if collate_fn is not None:
Expand All @@ -29,25 +29,55 @@ def __init__(
def __len__(self) -> int:
return len(self._targets)

def __getitem__(self, idx) -> Tuple[Tensor, Tensor]:
return self._inputs[idx], self._targets[idx]

class AvalancheDataset(BaseAvalancheDataset):
"""A wrapper around a Dataset consumable by Avalanche updaters."""

def __init__(
self,
dataset: Union[Dataset, DataBuffer],
targets: List[int],
collate_fn: Optional[Callable] = None,
):
super().__init__(targets, collate_fn)
self._dataset = dataset

def __getitem__(self, idx) -> Tuple[Tensor, int]:
return self._dataset[idx][0], self._targets[idx]


class AvalancheDatasetForBuffer(BaseAvalancheDataset):
"""A wrapper around a DataBuffer consumable by Avalanche updaters."""

def __init__(
self, buffer: DataBuffer, targets: List[int], collate_fn: Optional[Callable] = None
):
super().__init__(targets, collate_fn)
self._indices = buffer._indices
self._datasets = buffer._datasets

def __getitem__(self, idx) -> Tuple[Tensor, int]:
i, j = self._indices[idx]
return self._datasets[i][j][0], self._targets[idx]


def to_avalanche_dataset(
dataset: Union[Dataset, DataBuffer], collate_fn: Optional[Callable] = None
) -> AvalancheDataset:
) -> BaseAvalancheDataset:
"""Converts a DataBuffer or Dataset into an Avalanche-compatible Dataset."""
x_data, y_data = [], []
y_data = []
is_buffer = isinstance(dataset, DataBuffer)
for i in range(len(dataset)):
if isinstance(dataset, DataBuffer):
(x, y), _ = dataset[i]
if is_buffer:
(_, y), _ = dataset[i]
else:
x, y = dataset[i]
x_data.append(x)
_, y = dataset[i]
if not isinstance(y, int):
y = y.item()
y_data.append(y)
return AvalancheDataset(x_data, y_data, collate_fn)
if is_buffer:
return AvalancheDatasetForBuffer(dataset, y_data, collate_fn)
return AvalancheDataset(dataset, y_data, collate_fn)


class AvalancheBenchmarkWrapper:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
"dataset": "mnist.json",
"backend": "local",
"job_name": "rotation-mlp-avalanche-ewc",
"expected_accuracy_linux": [[0.7580999732017517, 0.9627000093460083], [0.7551000118255615, 0.9664999842643738]],
"expected_accuracy_darwin": [[0.7497000098228455, 0.9664999842643738]]
"expected_accuracy_linux": [[0.7448999881744385, 0.9642000198364258], [0.7483000159263611, 0.9634000062942505]],
"expected_accuracy_darwin": [[0.7275999784469604, 0.9585000276565552]]
}
27 changes: 18 additions & 9 deletions test/renate/utils/test_avalanche.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import pickle

import pytest
import torch
from avalanche.training.plugins import EWCPlugin, ReplayPlugin
from torch import Tensor
from torch.utils.data import Subset, TensorDataset

from renate.memory import ReservoirBuffer
from renate.utils.avalanche import (
AvalancheBenchmarkWrapper,
_plugin_index,
Expand All @@ -16,24 +19,30 @@
)


def test_to_avalanche_dataset():
@pytest.mark.parametrize("use_buffer", [True, False])
def test_to_avalanche_dataset(use_buffer):
expected_x = 6
expected_y = 1
tensor_dataset = TensorDataset(
torch.tensor([5, expected_x, 7]), torch.tensor([0, expected_y, 2])
)
dataset = to_avalanche_dataset(Subset(tensor_dataset, [1]))
assert dataset._inputs[0].item() == expected_x
if use_buffer:
buffer = ReservoirBuffer(10)
buffer.update(tensor_dataset)
dataset = to_avalanche_dataset(buffer)
else:
dataset = to_avalanche_dataset(Subset(tensor_dataset, [0, 1, 2]))
assert type(dataset._targets) == list
assert len(dataset._targets) == 1
assert dataset._targets[0] == expected_y
assert len(dataset._targets) == 3
assert dataset._targets[1] == expected_y
assert type(dataset._targets[0]) == int
assert dataset.targets.item() == dataset._targets[0]
assert dataset.targets == expected_y
assert dataset.targets[0].item() == dataset._targets[0]
assert dataset.targets[1] == expected_y
assert type(dataset.targets) == Tensor
x, y = dataset[0]
x, y = dataset[1]
assert x == expected_x and y == expected_y
assert len(dataset) == 1
assert len(dataset) == 3
pickle.dumps(dataset)


def test_avalanche_benchmark_wrapper_correctly_tracks_and_saves_state():
Expand Down

0 comments on commit 645c219

Please sign in to comment.