-
Notifications
You must be signed in to change notification settings - Fork 5
/
dataloader.py
43 lines (34 loc) · 1.51 KB
/
dataloader.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from torchvision import transforms
from torch.utils.data import DataLoader
import numpy as np
import torch
import torchvision
import matplotlib.pyplot as plt
def load_transformed_dataset(
    img_size: int = 64,
    batch_size: int = 128,
    data_root: str = "./stanford_cars/car_data/car_data",
) -> DataLoader:
    """Build a shuffled DataLoader over the combined train and test image folders.

    Every image is resized to ``img_size x img_size``, randomly flipped
    horizontally, converted to a tensor and rescaled from [0, 1] to [-1, 1].

    Args:
        img_size: Side length, in pixels, of the square output images.
        batch_size: Number of samples per batch.
        data_root: Directory containing the ``train`` and ``test``
            sub-directories, each laid out the way
            ``torchvision.datasets.ImageFolder`` expects. The default is
            relative to the current working directory.

    Returns:
        A DataLoader over the concatenation of both splits, shuffled, with
        the final incomplete batch dropped.
    """
    import os  # Local import so the function does not depend on file-level changes.

    data_transform = transforms.Compose(
        [
            transforms.Resize((img_size, img_size)),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),  # Scales data into [0, 1]
            transforms.Lambda(lambda t: (t * 2) - 1),  # Scale between [-1, 1]
        ]
    )

    # Both splits share the same transform and are merged into a single
    # dataset; the train/test distinction is not kept.
    train = torchvision.datasets.ImageFolder(
        root=os.path.join(data_root, "train"),
        transform=data_transform,
    )
    test = torchvision.datasets.ImageFolder(
        root=os.path.join(data_root, "test"),
        transform=data_transform,
    )
    dataset = torch.utils.data.ConcatDataset([train, test])

    # drop_last=True keeps every batch at exactly `batch_size` samples.
    return DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
def show_tensor_image(image):
    """Draw an image tensor normalized to [-1, 1] on the current matplotlib axes.

    The tensor is mapped back to [0, 255], reordered from channel-first to
    channel-last, converted to a PIL image and passed to ``plt.imshow``.
    ``plt.show()`` is not called here.

    Args:
        image: A ``(C, H, W)`` tensor, or a 4-D batch of such tensors, in
            which case only the first element of the batch is drawn. Values
            are expected in [-1, 1]; anything outside that range is clamped.
    """
    # Take first image of batch
    if len(image.shape) == 4:
        image = image[0, :, :, :]

    # Reverse the data transformations
    reverse_transforms = transforms.Compose(
        [
            # .numpy() below only works on CPU tensors that do not require
            # grad, so detach and move to host memory first.
            transforms.Lambda(lambda t: t.detach().cpu()),
            transforms.Lambda(lambda t: (t + 1) / 2),
            # Out-of-range values (e.g. raw model samples) do not convert
            # meaningfully in the uint8 cast below; clamp them to [0, 1].
            transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
            transforms.Lambda(lambda t: t.permute(1, 2, 0)),  # CHW to HWC
            transforms.Lambda(lambda t: t * 255.0),
            # The cast truncates the fractional part (no rounding).
            transforms.Lambda(lambda t: t.numpy().astype(np.uint8)),
            transforms.ToPILImage(),
        ]
    )
    plt.imshow(reverse_transforms(image))