Skip to content

Commit 1a03c88

Browse files
author
kmbae
committed
Evaluation code
1 parent b6edeb7 commit 1a03c88

File tree

1 file changed

+124
-0
lines changed

1 file changed

+124
-0
lines changed

eval.py

+124
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
"""
2+
2018 Spring EE898
3+
Advanced Topics in Deep Learning
4+
for Robotics and Computer Vision
5+
6+
Programming Assignment 2
7+
Neural Style Transfer
8+
9+
Author : Jinsun Park ([email protected])
10+
11+
References
12+
[1] Gatys et al., "Image Style Transfer using Convolutional
13+
Neural Networks", CVPR 2016.
14+
[2] Huang and Belongie, "Arbitrary Style Transfer in Real-Time
15+
with Adaptive Instance Normalization", ICCV 2017.
16+
"""
17+
18+
from __future__ import absolute_import
19+
from __future__ import division
20+
from __future__ import print_function
21+
22+
import numpy as np
23+
import gc
24+
import visdom
25+
import os
26+
import time
27+
import numpy as np
28+
from os import listdir
29+
from PIL import Image
30+
from datetime import datetime
31+
import ipdb
32+
import torch
33+
import torch.nn as nn
34+
import torch.optim as optim
35+
from torch.nn import functional as F
36+
from torchvision import utils, transforms, models
37+
from torch.autograd import Variable
38+
from torch.utils.data import Dataset, DataLoader
39+
from train import *
40+
41+
42+
# Some utilities
43+
44+
45+
46+
"""
47+
Task 2. Complete training code.
48+
49+
Following skeleton code assumes that you have multiple GPUs
50+
You can freely change any of parameters
51+
"""
52+
def test():
53+
gc.disable()
54+
55+
# Parameters
56+
path_snapshot = 'snapshots'
57+
path_content = 'dataset/test/content'
58+
path_style = 'dataset/test/style'
59+
60+
if not os.path.exists(path_snapshot):
61+
os.makedirs(path_snapshot)
62+
63+
batch_size = 1
64+
weight_decay = 1.0e-5
65+
num_epoch = 600
66+
lr_init = 0.0001#0.001
67+
lr_decay_step = num_epoch/2
68+
momentum = 0.9
69+
#device_ids = [0, 1, 2]
70+
w_style = 10
71+
alpha = 1
72+
disp_step = 1
73+
74+
# Data loader
75+
dm = DataManager(path_content, path_style, random_crop=True)
76+
dl = DataLoader(dm, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=False)
77+
78+
num_train = dm.num
79+
num_batch = np.ceil(num_train / batch_size)
80+
loss_train_avg = np.zeros(num_epoch)
81+
82+
net = StyleTransferNet(w_style, alpha)
83+
net = nn.DataParallel(net.cuda(), device_ids=range(torch.cuda.device_count()))
84+
85+
# Load model
86+
state_dict = torch.load('snapshots/epoch_000501.pth')
87+
net.load_state_dict(state_dict)
88+
89+
# Start training
90+
net.eval()
91+
running_loss_train = 0
92+
93+
for i, data in enumerate(dl, 0):
94+
img_con = data['content']
95+
img_sty = data['style']
96+
97+
img_con = Variable(img_con, requires_grad=False).cuda()
98+
img_sty = Variable(img_sty, requires_grad=False).cuda()
99+
100+
img_result = net(img_con, img_sty)
101+
img_result.insert(0, img_con)
102+
img_result.append(img_sty)
103+
img_cat = torch.cat(img_result, dim=3)
104+
img_cat = torch.unbind(img_cat, dim=0)
105+
img_cat = torch.cat(img_cat, dim=1)
106+
img_cat = dm.restore(img_cat.data.cpu())
107+
output_img = torch.clamp(img_cat, 0, 1)
108+
109+
tt=transforms.ToPILImage()(output_img)
110+
tt.save('test_out/{}.png'.format(i))
111+
112+
if (i+1)%disp_step==0:
113+
print('Testing {}/{} images'.format(i,len(dl)))
114+
115+
116+
gc_collected = gc.collect()
117+
gc.disable()
118+
119+
print('Testing finished.')
120+
121+
122+
123+
if __name__ == '__main__':
124+
test()

0 commit comments

Comments
 (0)