-
Notifications
You must be signed in to change notification settings - Fork 0
/
chalearn_main.py
94 lines (83 loc) · 2.9 KB
/
chalearn_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
from torch import optim
from torch.utils.data import DataLoader
from torch.utils.data import sampler
from torchvision import transforms
from utils.age_detection_utils import check_result
from datasets.chalearn_training_dataset import ChaLearnTrainingDataset
from loss_funcs.soft_argmax import SoftArgmaxLoss
from trainer import Trainer
# Images per mini-batch; the train, validation and test loaders all use it.
BATCH_SIZE: int = 400
# Number of DataLoader worker processes used to load and transform images.
DATA_LOADER_NUM_WORKERS: int = 10
# Checkpoint that main() loads as the starting (pretrained) model; the same
# path is also handed to Trainer.
MODEL_PATH: str = 'models/model_imdb_wiki_norm_0001_epoch20.pt'
def main():
    """Fine-tune the pretrained model on ChaLearn, then evaluate it on the test split.

    Loads the model stored at MODEL_PATH and builds the ChaLearn data loaders.
    It then trains with Adam (lr=1e-3) and SoftArgmaxLoss for 5 epochs through
    Trainer and finally runs Trainer.test().
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f'Using device {device}')

    # Load the pretrained RESNET-18 model. map_location remaps tensors that
    # were saved on a GPU, so the checkpoint also loads on a CPU-only machine.
    # NOTE(review): this unpickles a full model object. Newer torch releases
    # default torch.load to weights_only=True, which may refuse it; pass
    # weights_only=False there (only for checkpoints you trust).
    model = torch.load(MODEL_PATH, map_location=device).to(device=device)
    loss_func = SoftArgmaxLoss().to(device=device)

    # dtype depends on the loss function. It must also follow the selected
    # device, because torch.cuda.FloatTensor cannot be used without CUDA.
    dtype = torch.cuda.FloatTensor if device.type == 'cuda' else torch.FloatTensor
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    loader_train, loader_val, loader_test = _split_data()
    # NOTE(review): MODEL_PATH is also handed to Trainer, presumably as the
    # checkpoint output path. If Trainer saves there, it overwrites the
    # pretrained weights loaded above; confirm, and use a separate path if so.
    model_trainer = Trainer(
        model, loss_func, dtype, optimizer, device,
        loader_train, loader_val, loader_test, check_result,
        MODEL_PATH, num_epochs=5, print_every=100,
    )
    model_trainer.train()
    model_trainer.test()
# Per-channel mean / std for transforms.Normalize, shared by every split.
# Presumably computed on the ChaLearn training images; TODO confirm the source.
_CHALEARN_MEAN = [0.5797703, 0.43427974, 0.38307136]
_CHALEARN_STD = [0.25409877, 0.22383073, 0.21819368]


def _make_loader(dataset, shuffle):
    """Wrap *dataset* in a DataLoader using the module-wide batch size and worker count.

    Args:
        dataset: the dataset to iterate.
        shuffle: if True, samples are drawn in random order (training);
            otherwise they are read sequentially, so evaluation passes are
            reproducible.

    Returns:
        The configured DataLoader.
    """
    if shuffle:
        order = sampler.RandomSampler(dataset)
    else:
        order = sampler.SequentialSampler(dataset)
    return DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        num_workers=DATA_LOADER_NUM_WORKERS,
        sampler=order,
    )


def _split_data():
    """Build the ChaLearn train / validation / test DataLoaders.

    Training images get random horizontal flips. All splits are resized to
    224x224, converted to tensors and normalized with the same statistics.
    Only the training loader shuffles.

    Returns:
        A (loader_train, loader_val, loader_test) tuple.
    """
    normalize = transforms.Normalize(_CHALEARN_MEAN, _CHALEARN_STD)
    train_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        # NOTE(review): with no arguments ColorJitter uses brightness=contrast=
        # saturation=hue=0 (torchvision defaults), so it leaves images
        # unchanged. Pass explicit ranges if colour augmentation was intended.
        transforms.ColorJitter(),
        transforms.ToTensor(),
        normalize,
    ])
    val_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = ChaLearnTrainingDataset(
        ['ChaLearn/images/train_1', 'ChaLearn/images/train_2'],
        'ChaLearn/gt/train_gt.csv',
        train_transform,
    )
    val_dataset = ChaLearnTrainingDataset(
        ['ChaLearn/images/valid'],
        'ChaLearn/gt/valid_gt.csv',
        val_transform,
    )
    test_dataset = ChaLearnTrainingDataset(
        ['ChaLearn/images/test_1', 'ChaLearn/images/test_2'],
        'ChaLearn/gt/test_gt.csv',
        val_transform,
    )

    # Evaluation order does not need randomness; sequential reads keep
    # validation/test runs deterministic.
    loader_train = _make_loader(train_dataset, shuffle=True)
    loader_val = _make_loader(val_dataset, shuffle=False)
    loader_test = _make_loader(test_dataset, shuffle=False)
    return loader_train, loader_val, loader_test
# Run the training pipeline only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()