-
Notifications
You must be signed in to change notification settings - Fork 4
/
dataset.py
59 lines (46 loc) · 1.68 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from __future__ import print_function
from __future__ import absolute_import
from __future__ import division
# 3rd party:
from tensorflow.keras import datasets
from dataclasses import dataclass
import numpy as np
import cv2
@dataclass
class Dataset:
    """Named bundle of train / validation / test image arrays.

    The ``get_*`` classmethods load a Keras dataset, discard the labels,
    hold out the first tenth of the training images as the validation
    split and divide pixel values by 255 before casting to ``float32``.
    """

    name: str          # short identifier, e.g. 'cifar10' or 'mnist64'
    train: np.ndarray  # training images
    val: np.ndarray    # validation images held out from the training pool
    test: np.ndarray   # test images

    def __str__(self):
        """Return a multi-line summary with one ``o field|value`` row per field.

        Strings are printed verbatim.  Values exposing both ``shape`` and
        ``dtype`` (NumPy arrays) are summarised as a ``(shape, dtype)`` pair.
        Anything else (e.g. ``None``) is printed as-is so that the summary
        never raises on a partially populated instance.
        """
        rows = ['Dataset:\n']
        for key, value in vars(self).items():
            is_array_like = (not isinstance(value, str)
                             and hasattr(value, 'shape')
                             and hasattr(value, 'dtype'))
            shown = (value.shape, value.dtype) if is_array_like else value
            rows.append(f'o {key:10}|{shown}\n')
        return ''.join(rows)

    @staticmethod
    def _to_unit_float32(images):
        """Divide *images* by 255 and cast the result to ``float32``."""
        return np.float32(images / 255.)

    @classmethod
    def _from_pool(cls, name, pool, test):
        """Build an instance, holding out the first tenth of *pool* as validation.

        Args:
            name: Value stored in the ``name`` field.
            pool: Array of training images; its first ``len // 10`` entries
                become ``val`` and the remainder becomes ``train``.
            test: Array of test images.

        Returns:
            A new ``cls`` instance whose arrays are scaled by 1/255 to float32.
        """
        num_val = pool.shape[0] // 10
        # Use ``cls`` (not ``Dataset``) so subclasses get instances of their own type.
        return cls(name,
                   cls._to_unit_float32(pool[num_val:]),
                   cls._to_unit_float32(pool[:num_val]),
                   cls._to_unit_float32(test))

    @classmethod
    def get_cifar10(cls):
        """Load CIFAR-10 through ``datasets.cifar10.load_data()``.

        Labels are discarded.  The training images are split in the order
        returned by the loader (no shuffling is applied here).

        Returns:
            An instance named ``'cifar10'``.
        """
        (x_train, _), (x_test, _) = datasets.cifar10.load_data()
        return cls._from_pool('cifar10', x_train, x_test)

    @classmethod
    def get_mnist64(cls):
        """Load MNIST, resize every image to 64x64 and add a trailing channel axis.

        Labels are discarded.  The resized training images are shuffled in
        place with ``np.random.shuffle`` before the validation hold-out, so
        the split depends on NumPy's global random state (call
        ``np.random.seed`` beforehand for a reproducible split).

        Returns:
            An instance named ``'mnist64'``.
        """
        (x_train, _), (x_test, _) = datasets.mnist.load_data()
        train_64 = np.array([cv2.resize(image, (64, 64)) for image in x_train])
        np.random.shuffle(train_64)
        test_64 = np.array([cv2.resize(image, (64, 64)) for image in x_test])
        # ``None`` in the last index position appends the channel axis.
        return cls._from_pool('mnist64',
                              train_64[:, :, :, None],
                              test_64[:, :, :, None])
if __name__ == '__main__':
    # Manual smoke check: build the 64x64 MNIST splits and print their summary.
    print(Dataset.get_mnist64())