forked from sanghoon/pytorch_imagenet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
show_imagenet.py
76 lines (64 loc) · 2.44 KB
/
show_imagenet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
#!/usr/bin/python
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
import numpy as np
import os
import argparse
# Per-channel normalization statistics, one entry per image channel.
# They are passed to transforms.Normalize for both loaders and inverted
# again by imshow when displaying a batch.
MEAN_COEF = [
    0.485,
    0.456,
    0.406,
]  # subtracted per channel
DIV_COEF = [
    0.229,
    0.224,
    0.225,
]  # divided out per channel (standard deviation)
# functions to show an image
def imshow(img, mean=MEAN_COEF, std=DIV_COEF):
    """Display a normalized channel-first image tensor with matplotlib.

    Reverses ``transforms.Normalize`` (``x * std + mean`` per channel), clips
    the result to [0, 1], and passes it to ``plt.imshow`` in (H, W, C) order.
    Call ``plt.show()`` afterwards to render the figure.

    Args:
        img: image tensor laid out as (C, H, W), e.g. the output of
            ``torchvision.utils.make_grid``. C must match ``len(mean)``.
        mean: per-channel mean that was subtracted during normalization.
        std: per-channel standard deviation that was divided out.
    """
    mean = np.asarray(mean).reshape(-1, 1, 1)
    std = np.asarray(std).reshape(-1, 1, 1)
    # .cpu() is a no-op for CPU tensors and lets device tensors be shown too.
    npimg = img.cpu().numpy() * std + mean  # Un-normalize
    # Floating-point round-off (or input not produced by the matching
    # Normalize) can land slightly outside [0, 1]. Clip explicitly so
    # matplotlib does not warn and clip on its own.
    npimg = np.clip(npimg, 0.0, 1.0)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
def _make_loader(root, transform, batch_size, workers, shuffle):
    """Return a DataLoader over an ImageFolder at `root` with `transform` applied."""
    return torch.utils.data.DataLoader(
        datasets.ImageFolder(root, transform),
        batch_size=batch_size, shuffle=shuffle,
        num_workers=workers, pin_memory=True)


def _show_first_batch(loader):
    """Fetch one batch from `loader` and display it as an image grid."""
    # next(iter(...)) works on every PyTorch version. The DataLoader
    # iterator's `.next()` alias used previously was removed in newer releases.
    images, _labels = next(iter(loader))
    imshow(torchvision.utils.make_grid(images))
    plt.show()


def main():
    """Parse CLI args, build train/val loaders and show one batch of each."""
    parser = argparse.ArgumentParser(description='ImageNet Loader Test')
    parser.add_argument('data', metavar='DIR',
                        help='path to dataset')
    parser.add_argument('-j', '--workers', default=4, type=int, metavar='N',
                        help='number of data loading workers (default: 4)')
    parser.add_argument('-b', '--batch-size', default=64, type=int,
                        metavar='N', help='mini-batch size (default: 64)')
    args = parser.parse_args()

    # Data loading code: expects <DIR>/train and <DIR>/val subdirectories.
    traindir = os.path.join(args.data, 'train')
    valdir = os.path.join(args.data, 'val')
    normalize = transforms.Normalize(mean=MEAN_COEF, std=DIV_COEF)

    # RandomResizedCrop / Resize are the current names of the deprecated
    # RandomSizedCrop / Scale aliases (absent from newer torchvision),
    # with the same behaviour.
    train_loader = _make_loader(
        traindir,
        transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ]),
        args.batch_size, args.workers, shuffle=True)

    val_loader = _make_loader(
        valdir,
        transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]),
        args.batch_size, args.workers, shuffle=False)

    # Sample and show training images, then validation images.
    _show_first_batch(train_loader)
    _show_first_batch(val_loader)


if __name__ == '__main__':
    main()