Skip to content

Commit 7f0f1b3

Browse files
Author: kmbae
Commit message: "backup"
1 parent 1a03c88 — commit 7f0f1b3

File tree

2 files changed

+31
-29
lines changed

2 files changed

+31
-29
lines changed

eval.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def test():
7272
disp_step = 1
7373

7474
# Data loader
75-
dm = DataManager(path_content, path_style, random_crop=True)
76-
dl = DataLoader(dm, batch_size=batch_size, shuffle=True, num_workers=4, drop_last=False)
75+
dm = DataManager(path_content, path_style, random_crop=False)
76+
dl = DataLoader(dm, batch_size=batch_size, shuffle=False, num_workers=4, drop_last=False)
7777

7878
num_train = dm.num
7979
num_batch = np.ceil(num_train / batch_size)

train.py

+29-27
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def __init__(self, path_content, path_style, random_crop=True):
9696
else:
9797
self.transform = transforms.Compose(
9898
[
99+
transforms.CenterCrop((256,256)),
99100
transforms.ToTensor(),
100101
transforms.Normalize(mean=[0.485, 0.456, 0.406],
101102
std=[0.229, 0.224, 0.225])
@@ -221,23 +222,28 @@ def AdaINLayer(self, x, y):
221222
222223
output : the result of AdaIN operation
223224
"""
224-
style_mean, style_std = self.calc_mean_std(y)
225-
content_mean, content_std = self.calc_mean_std(x)
225+
style_mu, style_sigma = self.calc_mean_std(y)
226+
content_mu, content_sigma = self.calc_mean_std(x)
226227

227-
normalized_feat = (x - content_mean.expand(size)) / content_std.expand([Bx, Cx, Hx, Wx])
228-
output = normalized_feat * style_std.expand(size) + style_mean.expand([Bx, Cx, Hx, Wx])
228+
normalized_feat = (x - content_mu.expand([Bx, Cx, Hx, Wx])) / content_sigma.expand([Bx, Cx, Hx, Wx])
229+
output = normalized_feat * style_sigma.expand([Bx, Cx, Hx, Wx]) + style_mu.expand([Bx, Cx, Hx, Wx])
229230

230231
return output
231232

232-
def encode_with_intermediate(self, input):
233-
results = [input]
234-
for i in range(4):
235-
func = getattr(self, 'enc_{:d}'.format(i + 1))
236-
results.append(func(results[-1]))
233+
def encode_with_intermediate(self, x):
234+
results = [x]
235+
h = self.enc_1(x)
236+
results.append(h)
237+
h = self.enc_2(h)
238+
results.append(h)
239+
h = self.enc_3(h)
240+
results.append(h)
241+
h = self.enc_4(h)
242+
results.append(h)
237243
return results[1:]
238244

239-
def encode(self, input):
240-
h = self.enc_1(h)
245+
def encode(self, x):
246+
h = self.enc_1(x)
241247
h = self.enc_2(h)
242248
h = self.enc_3(h)
243249
h = self.enc_4(h)
@@ -253,18 +259,17 @@ def calc_mean_std(self, feat, eps=1e-5):
253259
feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
254260
return feat_mean, feat_std
255261

256-
def calc_content_loss(self, input, target):
262+
def calc_loss(self, input, target, loss_type):
257263
assert (input.size() == target.size())
258264
assert (target.requires_grad is False)
259-
return self.mse_loss(input, target)
260-
261-
def calc_style_loss(self, input, target):
262-
assert (input.size() == target.size())
263-
assert (target.requires_grad is False)
264-
input_mean, input_std = self.calc_mean_std(input)
265-
target_mean, target_std = self.calc_mean_std(target)
266-
return self.mse_loss(input_mean, target_mean) + \
267-
self.mse_loss(input_std, target_std)
265+
assert loss_type=='content' or loss_type=='style'
266+
if loss_type=='content':
267+
return self.mse_loss(input, target)
268+
else:
269+
input_mean, input_std = self.calc_mean_std(input)
270+
target_mean, target_std = self.calc_mean_std(target)
271+
return self.mse_loss(input_mean, target_mean) + \
272+
self.mse_loss(input_std, target_std)
268273

269274
def forward(self, x, y):
270275
B, C, H, W = x.shape
@@ -281,7 +286,6 @@ def forward(self, x, y):
281286
t = self.AdaINLayer(content_feat, style_feats[-1])
282287
if not self.training:
283288
img_result = []
284-
#for alpha in np.arange(0,1.1,0.25):
285289
alpha = 0.
286290
a = alpha * t + (1 - alpha) * content_feat
287291
img_result.append(self.decoder(a))
@@ -303,10 +307,10 @@ def forward(self, x, y):
303307
img_result = self.decoder(t)
304308
g_t_feats = self.encode_with_intermediate(img_result)
305309

306-
loss_c = self.calc_content_loss(g_t_feats[-1], t)
307-
loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
310+
loss_c = self.calc_loss(g_t_feats[-1], t, 'content')
311+
loss_s = self.calc_loss(g_t_feats[0], style_feats[0], 'style')
308312
for i in range(1, 4):
309-
loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
313+
loss_s += self.calc_loss(g_t_feats[i], style_feats[i], 'style')
310314

311315
loss = loss_c + self.w_style*loss_s
312316

@@ -380,8 +384,6 @@ def train():
380384

381385
optimizer.zero_grad()
382386

383-
#ipdb.set_trace()
384-
385387
loss, img_result = net(img_con, img_sty)
386388

387389
loss = torch.mean(loss)

0 commit comments

Comments (0)