-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathMNIST.py
116 lines (99 loc) · 3.48 KB
/
MNIST.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
from typing import *
import pytorch_lightning as pl
from torch.utils.data import DataLoader, Subset, Dataset
from torchvision.datasets import MNIST, FashionMNIST, EMNIST, KMNIST
from sklearn.model_selection import train_test_split
import numpy as np
class MnistDataModuleBase(pl.LightningDataModule):
    """LightningDataModule for MNIST-style torchvision datasets.

    The official training split is divided into stratified train/validation
    subsets (each with its own transform); the official test split is used
    for testing.

    Args:
        DATASET: Dataset class or factory, called as
            ``DATASET(root, train=..., transform=..., download=...)``; the
            instance it returns must expose a ``targets`` attribute.
        root_dir: Directory the dataset is read from / downloaded to.
        train_transforms: Transform applied to training samples.
        val_transforms: Transform applied to validation samples.
        test_transforms: Transform applied to test samples.
        batch_size: Batch size used by every dataloader.
        num_workers: Number of DataLoader worker processes.
        val_size: Fraction of the official training split held out for
            validation (default 0.2, the previously hard-coded value).
        split_seed: ``random_state`` for the train/val split. A fixed seed
            keeps the split identical across runs and across processes that
            each call ``setup``; pass ``None`` for a fresh random split.
    """

    def __init__(
        self,
        DATASET: Callable[..., Dataset],
        root_dir: str,
        train_transforms: Callable,
        val_transforms: Callable,
        test_transforms: Callable,
        batch_size: int,
        num_workers: int,
        val_size: float = 0.2,
        split_seed: Optional[int] = 42,
    ):
        super().__init__()
        self.save_hyperparameters(
            {
                "root_dir": root_dir,
                "batch_size": batch_size,
                "num_workers": num_workers,
                "val_size": val_size,
                "split_seed": split_seed,
            },
        )
        self.Dataset = DATASET
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.test_transforms = test_transforms

    def prepare_data(self) -> None:
        """Download the train and test splits into ``root_dir``."""
        self.Dataset(self.hparams.root_dir, train=True, download=True)
        self.Dataset(self.hparams.root_dir, train=False, download=True)

    def setup(self, stage: Optional[str] = None) -> None:
        """Build the datasets required by ``stage``.

        Args:
            stage: ``"fit"`` or ``"validate"`` builds the train/val subsets,
                ``"test"`` builds the test set, ``None`` builds everything.
        """
        # "validate" must also build val_ds, otherwise trainer.validate()
        # fails in val_dataloader.
        if stage in ("fit", "validate", None):
            self._setup_train_val()
        if stage in ("test", None):
            # Use the configured dataset, not MNIST unconditionally.
            self.test_ds = self.Dataset(
                self.hparams.root_dir,
                train=False,
                transform=self.test_transforms,
                download=False,
            )

    def _setup_train_val(self) -> None:
        """Split the official training set into stratified train/val subsets."""
        root = self.hparams.root_dir
        # Only the labels are needed to compute the split.
        targets = np.asarray(self.Dataset(root, train=True, download=False).targets)
        train_idx, val_idx = train_test_split(
            np.arange(len(targets)),
            test_size=self.hparams.val_size,
            shuffle=True,
            stratify=targets,
            random_state=self.hparams.split_seed,
        )
        # Separate dataset instances so train and val can use different transforms.
        self.train_ds = Subset(
            self.Dataset(root, train=True, transform=self.train_transforms, download=False),
            train_idx,
        )
        self.val_ds = Subset(
            self.Dataset(root, train=True, transform=self.val_transforms, download=False),
            val_idx,
        )

    def _loader(self, dataset: Dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader with the configured batch size/workers."""
        return DataLoader(
            dataset,
            batch_size=self.hparams.batch_size,
            shuffle=shuffle,
            num_workers=self.hparams.num_workers,
        )

    def train_dataloader(self) -> DataLoader:
        """Shuffled loader over the training subset."""
        return self._loader(self.train_ds, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        """Unshuffled loader over the validation subset."""
        return self._loader(self.val_ds, shuffle=False)

    def test_dataloader(self) -> DataLoader:
        """Unshuffled loader over the test split."""
        return self._loader(self.test_ds, shuffle=False)
def MnistDataModule(**kwargs):
    """Create an ``MnistDataModuleBase`` backed by torchvision's ``MNIST``.

    Every keyword argument is forwarded unchanged to ``MnistDataModuleBase``.
    """
    dataset_cls = MNIST
    module = MnistDataModuleBase(dataset_cls, **kwargs)
    return module
def FashionMnistDataModule(**kwargs):
    """Create an ``MnistDataModuleBase`` backed by torchvision's ``FashionMNIST``.

    Every keyword argument is forwarded unchanged to ``MnistDataModuleBase``.
    """
    dataset_cls = FashionMNIST
    module = MnistDataModuleBase(dataset_cls, **kwargs)
    return module
def EmnistDataModule(*, split: str = "byclass", **kwargs):
    """Create an ``MnistDataModuleBase`` backed by torchvision's ``EMNIST``.

    Args:
        split: EMNIST split to load; defaults to ``"byclass"``, the split this
            factory always used previously.
        **kwargs: Forwarded unchanged to ``MnistDataModuleBase``.
    """
    from functools import partial

    # A functools.partial (unlike a lambda) is picklable, so the data module
    # can still be serialized (e.g. for spawn-based multiprocessing). It also
    # avoids the lambda's inner **kwargs shadowing this function's kwargs.
    dataset_factory = partial(EMNIST, split=split)
    return MnistDataModuleBase(dataset_factory, **kwargs)
def KMnistDataModule(**kwargs):
    """Create an ``MnistDataModuleBase`` backed by torchvision's ``KMNIST``.

    Every keyword argument is forwarded unchanged to ``MnistDataModuleBase``.
    """
    dataset_cls = KMNIST
    module = MnistDataModuleBase(dataset_cls, **kwargs)
    return module