@@ -96,6 +96,7 @@ def __init__(self, path_content, path_style, random_crop=True):
         else:
             self.transform = transforms.Compose(
                 [
+                    transforms.CenterCrop((256, 256)),
                     transforms.ToTensor(),
                     transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                          std=[0.229, 0.224, 0.225])
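
For reference, a minimal sketch of how the eval-time pipeline added above could be exercised on its own, assuming a PIL RGB input; the image path is an illustrative assumption, not taken from the repository.

from PIL import Image
from torchvision import transforms

eval_transform = transforms.Compose([
    transforms.CenterCrop((256, 256)),          # deterministic crop for evaluation
    transforms.ToTensor(),                      # HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # ImageNet statistics
])

img = Image.open("example.jpg").convert("RGB")  # hypothetical input image
tensor = eval_transform(img)                    # shape: (3, 256, 256)
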
@@ -221,23 +222,28 @@ def AdaINLayer(self, x, y):

         output : the result of AdaIN operation
         """
-        style_mean, style_std = self.calc_mean_std(y)
-        content_mean, content_std = self.calc_mean_std(x)
+        style_mu, style_sigma = self.calc_mean_std(y)
+        content_mu, content_sigma = self.calc_mean_std(x)

-        normalized_feat = (x - content_mean.expand(size)) / content_std.expand([Bx, Cx, Hx, Wx])
-        output = normalized_feat * style_std.expand(size) + style_mean.expand([Bx, Cx, Hx, Wx])
+        normalized_feat = (x - content_mu.expand([Bx, Cx, Hx, Wx])) / content_sigma.expand([Bx, Cx, Hx, Wx])
+        output = normalized_feat * style_sigma.expand([Bx, Cx, Hx, Wx]) + style_mu.expand([Bx, Cx, Hx, Wx])

         return output

-    def encode_with_intermediate(self, input):
-        results = [input]
-        for i in range(4):
-            func = getattr(self, 'enc_{:d}'.format(i + 1))
-            results.append(func(results[-1]))
+    def encode_with_intermediate(self, x):
+        results = [x]
+        h = self.enc_1(x)
+        results.append(h)
+        h = self.enc_2(h)
+        results.append(h)
+        h = self.enc_3(h)
+        results.append(h)
+        h = self.enc_4(h)
+        results.append(h)

         return results[1:]

-    def encode(self, input):
-        h = self.enc_1(h)
+    def encode(self, x):
+        h = self.enc_1(x)
         h = self.enc_2(h)
         h = self.enc_3(h)
         h = self.enc_4(h)
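
For reference, a self-contained sketch of the AdaIN operation the hunk above rewrites, assuming 4-D (N, C, H, W) feature tensors; the function and variable names here are illustrative and not taken from the patched module.

import torch

def adain(content, style, eps=1e-5):
    # per-channel mean/std over the spatial dimensions, mirroring calc_mean_std
    n, c = content.shape[:2]
    c_mean = content.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    c_std = (content.view(n, c, -1).var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    s_mean = style.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    s_std = (style.view(n, c, -1).var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    # normalize the content features, then re-scale with the style statistics
    return (content - c_mean) / c_std * s_std + s_mean
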
@@ -253,18 +259,17 @@ def calc_mean_std(self, feat, eps=1e-5):
         feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
         return feat_mean, feat_std

-    def calc_content_loss(self, input, target):
+    def calc_loss(self, input, target, loss_type):
         assert (input.size() == target.size())
         assert (target.requires_grad is False)
-        return self.mse_loss(input, target)
-
-    def calc_style_loss(self, input, target):
-        assert (input.size() == target.size())
-        assert (target.requires_grad is False)
-        input_mean, input_std = self.calc_mean_std(input)
-        target_mean, target_std = self.calc_mean_std(target)
-        return self.mse_loss(input_mean, target_mean) + \
-               self.mse_loss(input_std, target_std)
+        assert loss_type == 'content' or loss_type == 'style'
+        if loss_type == 'content':
+            return self.mse_loss(input, target)
+        else:
+            input_mean, input_std = self.calc_mean_std(input)
+            target_mean, target_std = self.calc_mean_std(target)
+            return self.mse_loss(input_mean, target_mean) + \
+                   self.mse_loss(input_std, target_std)

     def forward(self, x, y):
         B, C, H, W = x.shape
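
A standalone sketch of the merged helper introduced above, written as a free function over plain tensors rather than a method; F.mse_loss stands in for the module's self.mse_loss, and the statistics mirror calc_mean_std under the same (N, C, H, W) assumption.

import torch
import torch.nn.functional as F

def calc_loss(pred, target, loss_type, eps=1e-5):
    assert pred.size() == target.size()
    assert loss_type in ('content', 'style')
    if loss_type == 'content':
        # content loss: plain MSE between feature maps
        return F.mse_loss(pred, target)
    # style loss: MSE between per-channel means and standard deviations
    n, c = pred.shape[:2]
    p_mean = pred.view(n, c, -1).mean(dim=2)
    p_std = (pred.view(n, c, -1).var(dim=2) + eps).sqrt()
    t_mean = target.view(n, c, -1).mean(dim=2)
    t_std = (target.view(n, c, -1).var(dim=2) + eps).sqrt()
    return F.mse_loss(p_mean, t_mean) + F.mse_loss(p_std, t_std)
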
@@ -281,7 +286,6 @@ def forward(self, x, y):
         t = self.AdaINLayer(content_feat, style_feats[-1])
         if not self.training:
             img_result = []
-            #for alpha in np.arange(0,1.1,0.25):
             alpha = 0.
             a = alpha * t + (1 - alpha) * content_feat
             img_result.append(self.decoder(a))
@@ -303,10 +307,10 @@ def forward(self, x, y):
         img_result = self.decoder(t)
         g_t_feats = self.encode_with_intermediate(img_result)

-        loss_c = self.calc_content_loss(g_t_feats[-1], t)
-        loss_s = self.calc_style_loss(g_t_feats[0], style_feats[0])
+        loss_c = self.calc_loss(g_t_feats[-1], t, 'content')
+        loss_s = self.calc_loss(g_t_feats[0], style_feats[0], 'style')
         for i in range(1, 4):
-            loss_s += self.calc_style_loss(g_t_feats[i], style_feats[i])
+            loss_s += self.calc_loss(g_t_feats[i], style_feats[i], 'style')

         loss = loss_c + self.w_style * loss_s

@@ -380,8 +384,6 @@ def train():

         optimizer.zero_grad()

-        #ipdb.set_trace()
-
         loss, img_result = net(img_con, img_sty)

         loss = torch.mean(loss)
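
The torch.mean(loss) call suggests the loss arrives as a non-scalar tensor (e.g., one value per device under nn.DataParallel). A hedged sketch of one optimization step around the lines above; net, optimizer, and the data loader are assumptions, not taken from train().

for img_con, img_sty in loader:
    optimizer.zero_grad()
    loss, img_result = net(img_con, img_sty)  # forward pass returns loss and stylized image
    loss = torch.mean(loss)                   # reduce to a scalar before backward
    loss.backward()
    optimizer.step()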