
Commit 36e7613 (1 parent: e0c5306)
Author: kmbae
Commit message: Rebuttal version

File tree: 5 files changed, +144 -22 lines


config.py (+2)

@@ -28,6 +28,8 @@ def add_argument_group(name):
 data_arg.add_argument('--a_grayscale', type=str2bool, default=False)
 data_arg.add_argument('--b_grayscale', type=str2bool, default=False)
 data_arg.add_argument('--num_worker', type=int, default=12)
+data_arg.add_argument('--dataset_A1', type=str, default='../data/edges2handbags')
+data_arg.add_argument('--dataset_A2', type=str, default='../data/edges2shoes')
 
 # Training / test parameters
 train_arg = add_argument_group('Training')
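
Note: the two new flags replace dataset roots that were previously hard-coded in main.py and trainer.py. A minimal sketch of reading them back (illustrative, not part of the commit):

    from config import get_config

    # get_config() is the repo's own helper in config.py; `unparsed` collects
    # any command-line flags argparse did not recognize.
    config, unparsed = get_config()
    print(config.dataset_A1)  # defaults to '../data/edges2handbags'
    print(config.dataset_A2)  # defaults to '../data/edges2shoes'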

data_loader.py (+6, -1)

@@ -6,6 +6,8 @@
 
 import torch
 from torchvision import transforms
+#from skimage import feature
+#from skimage.color import rgb2gray
 
 PIX2PIX_DATASETS = [
     'facades', 'cityscapes', 'maps', 'edges2shoes', 'edges2handbags']

@@ -67,9 +69,12 @@ def __init__(self, root, scale_size, data_type, skip_pix2pix_processing=False):
 
     def __getitem__(self, index):
         image = Image.open(self.paths[index]).convert('RGB')
+        edges = Image.open(self.paths[index].replace('/A','/B1')).convert('RGB')
+        #edges = image.filter(ImageFilter.FIND_EDGES)
+
         #if self.data_type=='B':
         #    image = image.filter(ImageFilter.MinFilter(3))
-        return self.transform(image)
+        return {'image':self.transform(image), 'edges':self.transform(edges)}
 
     def __len__(self):
         return len(self.paths)
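
Note: __getitem__ now returns a dict pairing each image with an edge map read from a parallel directory ('/A' swapped for '/B1' in the path), so a single loader yields both tensors. A sketch of consuming the dict-style samples (EdgeDataset is a hypothetical stand-in for the dataset class patched above, whose real name lies outside this hunk):

    import torch

    # EdgeDataset is a placeholder name; the constructor mirrors the hunk header.
    dataset = EdgeDataset('../data/edges2handbags/train', scale_size=64, data_type='A')
    loader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
    batch = next(iter(loader))               # default collate batches each dict key
    x, e = batch['image'], batch['edges']    # both [B, 3, H, W] tensors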

main.py (+2, -2)

@@ -23,11 +23,11 @@ def main(config):
     batch_size = config.sample_per_image
 
     a_data_loader, b_data_loader = get_loader(
-        data_path, batch_size, config.input_scale_size,
+        config.dataset_A1, batch_size, config.input_scale_size,
         config.num_worker, config.skip_pix2pix_processing)
 
     a1_data_loader, b1_data_loader = get_loader(
-        '../data/edges2handbags', batch_size, config.input_scale_size,
+        config.dataset_A2, batch_size, config.input_scale_size,
         config.num_worker, config.skip_pix2pix_processing)
 
     trainer = Trainer(config, a_data_loader, b_data_loader, a1_data_loader, b1_data_loader)
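
Note: both get_loader calls now take their roots from the new config flags rather than the local data_path variable and a hard-coded '../data/edges2handbags', so switching domain pairs is a command-line matter, e.g.:

    python main.py --dataset_A1=../data/edges2handbags --dataset_A2=../data/edges2shoes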

test.py (+76, new file)

@@ -0,0 +1,76 @@
+from models import *
+import torch
+from torch import nn
+import glob
+from torchvision import transforms
+import tqdm
+from config import get_config
+from PIL import Image
+import argparse
+import os
+
+# Usage example
+# run test.py --load=../logs/edges2shoes_2018-09-07_17-05-19 --iter=10000 --con=../validation_image/handbag_picking --sty=../validation_image/valid_x_A1
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--load", type=str, help="Saved file dir")
+parser.add_argument("--iter", type=str, help="Number of iteration")
+parser.add_argument("--con", type=str, help="Dir of content images")
+parser.add_argument("--sty", type=str, help="Dir of style images")
+args = parser.parse_args()
+
+scale_size = 64
+transform = transforms.Compose([
+    transforms.Scale(scale_size),
+    transforms.ToTensor(),
+    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
+])
+ToPILImage = transforms.ToPILImage()
+
+if __name__=='__main__':
+    config, unparsed = get_config()
+    conv_dims, deconv_dims = [64, 128, 256, 512], [256, 128, 64]
+    a_channel = 3
+    b_channel = 3
+    num_gpu = torch.cuda.device_count()
+    G = GeneratorCNN_g(a_channel+b_channel, b_channel, conv_dims, deconv_dims, num_gpu)
+    F = GeneratorCNN(a_channel, a_channel, conv_dims, deconv_dims, num_gpu)
+
+    G = nn.DataParallel(G.cuda(), device_ids=range(torch.cuda.device_count()))
+    F = nn.DataParallel(F.cuda(), device_ids=range(torch.cuda.device_count()))
+    print('Loading model')
+    G.load_state_dict(torch.load(args.load + '/G_{}.pth'.format(args.iter)))
+    F.load_state_dict(torch.load(args.load + '/F_{}.pth'.format(args.iter)))
+    print('Model loaded')
+
+    if not os.path.exists('./results'):
+        os.mkdir('./results')
+
+    list_con = os.listdir(args.con)
+    list_sty = os.listdir(args.sty)
+    img_con = []
+    img_sty = []
+    for i in list_con:
+        img_con.append(transform(Image.open(args.con + '/' + i)))
+    for j in list_sty:
+        img_sty.append(transform(Image.open(args.sty + '/' + j)))
+
+    with torch.no_grad():
+        G.eval()
+        F.eval()
+        for i, con_tmp in tqdm.tqdm(enumerate(img_con)):
+            con_tmp = torch.unsqueeze(con_tmp.cuda(), 0)
+            if not os.path.exists('./results/{}'.format(list_con[i].split('.')[0])):
+                os.mkdir('./results/{}'.format(list_con[i].split('.')[0]))
+            for j, sty_tmp in enumerate(img_sty):
+                sty_tmp = torch.unsqueeze(sty_tmp.cuda(), 0)
+                img_out = G(F(con_tmp), sty_tmp)
+                img_out = ToPILImage(img_out.data[0].cpu())
+                img_out.save('./results/{}/{}_{}.jpg'.format(list_con[i].split('.')[0],
+                    list_con[i].split('.')[0], list_sty[j].split('.')[0]))
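
Two portability caveats on this new script, with a hedged sketch. transforms.Scale was deprecated in favor of transforms.Resize in later torchvision releases, and because G and F are saved from nn.DataParallel wrappers, every key in the checkpoints carries a 'module.' prefix; loading into an unwrapped model needs that prefix stripped. Sketch (an assumption, not part of the commit):

    # Load a DataParallel checkpoint into a bare (non-wrapped) module by
    # stripping the 'module.' prefix DataParallel adds to every key.
    state = torch.load(args.load + '/G_{}.pth'.format(args.iter))
    state = {k.replace('module.', '', 1): v for k, v in state.items()}
    G.load_state_dict(state)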

trainer.py (+58, -19)

@@ -17,7 +17,7 @@
 import ipdb
 from tensorboardX import SummaryWriter
 from datetime import datetime
-
+import torchvision
 tmp = datetime.now()
 
 writer = SummaryWriter('../runs/' + str(tmp))
@@ -194,13 +194,35 @@ def train(self):
             self.G.parameters(),
             lr=self.lr, betas=(self.beta1, self.beta2))
 
-        A_loader, B_loader = iter(self.a_data_loader), iter(self.b_data_loader)
-        valid_x_A, valid_x_B = torch.Tensor(np.load('../valid_x_A1.npy')), torch.Tensor(np.load('../valid_x_B1.npy'))
-        valid_x_A, valid_x_B = self._get_variable(valid_x_A), self._get_variable(valid_x_B)
+        #A_loader, B_loader = iter(self.a_data_loader), iter(self.b_data_loader)
+        A_loader = iter(self.a_data_loader)
+        #valid_x_A, valid_x_B = torch.Tensor(np.load('../valid_x_A2.npy')), torch.Tensor(np.load('../valid_x_B2.npy'))
         #self._get_variable(A_loader.next()), self._get_variable(B_loader.next())
-        A1_loader, B1_loader = iter(self.a1_data_loader), iter(self.b1_data_loader)
-        valid_x_A1, valid_x_B1 = torch.Tensor(np.load('../valid_x_A2.npy')), torch.Tensor(np.load('../valid_x_B2.npy'))
-        valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
+        #A1_loader, B1_loader = iter(self.a1_data_loader), iter(self.b1_data_loader)
+        A1_loader = iter(self.a1_data_loader)
+        #valid_x_A1, valid_x_B1 = torch.Tensor(np.load('../valid_x_A2_chair.npy')), torch.Tensor(np.load('../valid_x_B2_chair.npy'))
+        try:
+            valid_x_A, valid_x_B = torch.Tensor(np.load(self.config.dataset_A1+'_A.npy')), torch.Tensor(np.load(self.config.dataset_A1+'_B.npy'))
+            valid_x_A, valid_x_B = self._get_variable(valid_x_A), self._get_variable(valid_x_B)
+            valid_x_A1, valid_x_B1 = torch.Tensor(np.load(self.config.dataset_A2+'_A.npy')), torch.Tensor(np.load(self.config.dataset_A2+'_B.npy'))
+            valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
+        except:
+            print('Cannot load validation file. Creating new validation file')
+            x_A1 = A_loader.next()
+            x_A2 = A1_loader.next()
+            x_A1, x_B1 = x_A1['image'], x_A1['edges']
+            x_A2, x_B2 = x_A2['image'], x_A2['edges']
+            np.save(self.config.dataset_A1+'_A.npy', np.array(x_A1))
+            np.save(self.config.dataset_A1+'_B.npy', np.array(x_B1))
+            np.save(self.config.dataset_A2+'_A.npy', np.array(x_A2))
+            np.save(self.config.dataset_A2+'_B.npy', np.array(x_B2))
+
+            valid_x_A, valid_x_B = torch.Tensor(np.load(self.config.dataset_A1+'_A.npy')), torch.Tensor(np.load(self.config.dataset_A1+'_B.npy'))
+            valid_x_A, valid_x_B = self._get_variable(valid_x_A), self._get_variable(valid_x_B)
+            valid_x_A1, valid_x_B1 = torch.Tensor(np.load(self.config.dataset_A2+'_A.npy')), torch.Tensor(np.load(self.config.dataset_A2+'_B.npy'))
+            valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
+
         #self._get_variable(A1_loader.next()), self._get_variable(B1_loader.next())
         #ipdb.set_trace()
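
Note: this block caches one fixed validation batch per dataset as .npy files keyed by the dataset root, creating them on first run. A tighter equivalent (hypothetical refactor, not in the commit) that also narrows the bare except:, which as written would silently swallow unrelated errors:

    import numpy as np
    import torch

    def load_or_create_valid(loader, prefix):
        """Load cached validation tensors, or cache the loader's first batch."""
        try:
            x_A = torch.Tensor(np.load(prefix + '_A.npy'))
            x_B = torch.Tensor(np.load(prefix + '_B.npy'))
        except FileNotFoundError:   # narrower than a bare except:
            batch = next(iter(loader))
            x_A, x_B = batch['image'], batch['edges']
            np.save(prefix + '_A.npy', x_A.numpy())
            np.save(prefix + '_B.npy', x_B.numpy())
        return x_A, x_B

Usage would be valid_x_A, valid_x_B = load_or_create_valid(self.a_data_loader, self.config.dataset_A1), with the _get_variable wrapping applied afterwards.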

@@ -212,20 +234,23 @@ def train(self):
         for step in trange(self.start_step, self.max_step):
             try:
                 x_A1 = A_loader.next()
-                x_B1 = B_loader.next()
+                #x_B1 = B_loader.next()
             except StopIteration:
                 A_loader = iter(self.a_data_loader)
-                B_loader = iter(self.b_data_loader)
+                #B_loader = iter(self.b_data_loader)
                 x_A1 = A_loader.next()
-                x_B1 = B_loader.next()
+                #x_B1 = B_loader.next()
             try:
                 x_A2 = A1_loader.next()
-                x_B2 = B1_loader.next()
+                #x_B2 = B1_loader.next()
             except StopIteration:
                 A1_loader = iter(self.a1_data_loader)
-                B1_loader = iter(self.b1_data_loader)
+                #B1_loader = iter(self.b1_data_loader)
                 x_A2 = A1_loader.next()
-                x_B2 = B1_loader.next()
+                #x_B2 = B1_loader.next()
+
+            x_A1, x_B1 = x_A1['image'], x_A1['edges']
+            x_A2, x_B2 = x_A2['image'], x_A2['edges']
             if x_A1.size(0) != x_B1.size(0) or x_A2.size(0) != x_B2.size(0) or x_A1.size(0) != x_A2.size(0):
                 print("[!] Sampled dataset from A and B have different # of data. Try resampling...")
                 continue
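
Note: since image and edge map now come from the same sample dict, x_A1 and x_B1 (likewise x_A2 and x_B2) always agree in batch size; only the cross-dataset comparison can still fail. The guard could be trimmed accordingly (sketch, not in the commit):

    if x_A1.size(0) != x_A2.size(0):
        print("[!] The two datasets returned different batch sizes. Resampling...")
        continue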
@@ -422,17 +447,17 @@ def train(self):
 
             self.generate_with_A(valid_x_A, valid_x_A1, self.model_dir, idx=step)
             self.generate_with_B(valid_x_A1, valid_x_A, self.model_dir, idx=step)
-            writer.add_scalars('loss_G', {'l_g':l_g,'l_gan_A':l_gan_A,'l_const_A':l_const_A,
-                                'l_f':l_f, 'l_const_AB': l_const_AB},
+            writer.add_scalars('loss_G', {'l_g':l_g,'l_gan_A':l_gan_A,
+                                'l_f':l_f},
                                 step)
+            writer.add_scalars('loss_F', {'l_const_A':l_const_A, 'l_const_AB': l_const_AB}, step)
             #'l_const_B':l_const_B,'l_const_AB':l_const_AB,'l_const_BA':l_const_BA}, step)
             writer.add_scalars('loss_D', {'l_d_A':l_d_A,'l_d_B':l_d_B}, step)
-
-            if step % self.save_step == self.save_step - 1:
+            if step % self.save_step == 0:
                 print("[*] Save models to {}...".format(self.model_dir))
 
-                torch.save(self.G.state_dict(), '{}/G_AB_{}.pth'.format(self.model_dir, step))
-                torch.save(self.F.state_dict(), '{}/G_BA_{}.pth'.format(self.model_dir, step))
+                torch.save(self.G.state_dict(), '{}/G_{}.pth'.format(self.model_dir, step))
+                torch.save(self.F.state_dict(), '{}/F_{}.pth'.format(self.model_dir, step))
 
                 torch.save(self.D_S.state_dict(), '{}/D_A_{}.pth'.format(self.model_dir, step))
                 torch.save(self.D_H.state_dict(), '{}/D_B_{}.pth'.format(self.model_dir, step))
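
Note: the renamed checkpoints (G_*.pth / F_*.pth instead of G_AB_* / G_BA_*) match what the new test.py loads, and moving the consistency losses into a separate 'loss_F' chart keeps the TensorBoard panels legible. The schedule change also means a save fires at step 0 when start_step is 0. The loading side, as written in test.py:

    G.load_state_dict(torch.load(args.load + '/G_{}.pth'.format(args.iter)))
    F.load_state_dict(torch.load(args.load + '/F_{}.pth'.format(args.iter)))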
@@ -447,6 +472,13 @@ def generate_with_A(self, inputs, input_ref, path, idx=None, tf_board=True):
         #x_ABA_path = '{}/{}_x_ABA.png'.format(path, idx)
 
         vutils.save_image(x_ABA.data, x_AB_path)
+        if not os.path.isdir('{}/{}_A1'.format(path, idx)):
+            os.makedirs('{}/{}_A1'.format(path, idx))
+        for i in range(x_ABA.size(0)):
+            tmp = x_ABA[i].detach().cpu()
+            tmp = torchvision.transforms.ToPILImage()(tmp)
+            tmp.save('{}/{}_A1/{}.png'.format(path, idx, i))
+
         print("[*] Samples saved: {}".format(x_AB_path))
         if tf_board:
             writer.add_image('x_A1f', x_AB[:16], idx)
@@ -468,6 +500,13 @@ def generate_with_B(self, inputs, input_ref, path, idx=None, tf_board=True):
         #x_BAB_path = '{}/{}_x_BAB.png'.format(path, idx)
 
         vutils.save_image(x_BAB.data, x_BA_path)
+        if not os.path.isdir('{}/{}_A2'.format(path, idx)):
+            os.makedirs('{}/{}_A2'.format(path, idx))
+        for i in range(x_BAB.size(0)):
+            tmp = x_BAB[i].detach().cpu()
+            tmp = torchvision.transforms.ToPILImage()(tmp)
+            tmp.save('{}/{}_A2/{}.png'.format(path, idx, i))
+
         print("[*] Samples saved: {}".format(x_BA_path))
         if tf_board:
             writer.add_image('x_A2f', x_BA[:16], idx)
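
Note: generate_with_A and generate_with_B now dump every reconstruction as an individual PNG alongside the grid from vutils.save_image, which makes per-image inspection and scoring easier. The duplicated loop could be factored into a helper (hypothetical sketch, not in the commit):

    import os
    import torchvision

    def dump_batch(batch, out_dir):
        """Save each image of a [B, C, H, W] tensor batch as its own PNG."""
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        to_pil = torchvision.transforms.ToPILImage()
        for i in range(batch.size(0)):
            to_pil(batch[i].detach().cpu()).save(os.path.join(out_dir, '{}.png'.format(i)))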
