@@ -209,8 +209,7 @@ def train(self):
209
209
valid_x_A1 , valid_x_B1 = torch .Tensor (np .load (self .config .dataset_A2 + '_A.npy' )), torch .Tensor (np .load (self .config .dataset_A2 + '_B.npy' ))
210
210
valid_x_A1 , valid_x_B1 = self ._get_variable (valid_x_A1 ), self ._get_variable (valid_x_B1 )
211
211
except :
212
- print ('Cannot load validation file. Validation data not created' )
213
- assert 1
212
+ raise Exception ('Cannot load validation file. Validation data not created' )
214
213
'''
215
214
x_A1 = A_loader.next()
216
215
x_A2 = A1_loader.next()
@@ -255,10 +254,6 @@ def train(self):
255
254
256
255
x_A1 , x_B1 = x_A1 ['image' ], x_A1 ['edges' ]
257
256
x_A2 , x_B2 = x_A2 ['image' ], x_A2 ['edges' ]
258
- #writer.add_image('x_A1', x_A1[:16],step)
259
- #writer.add_image('x_A2', x_A2[:16],step)
260
- #writer.add_image('x_B1', x_B1[:16],step)
261
- #writer.add_image('x_B2', x_B2[:16],step)
262
257
if x_A1 .size (0 ) != x_B1 .size (0 ) or x_A2 .size (0 ) != x_B2 .size (0 ) or x_A1 .size (0 ) != x_A2 .size (0 ):
263
258
print ("[!] Sampled dataset from A and B have different # of data. Try resampling..." )
264
259
continue
@@ -269,7 +264,23 @@ def train(self):
269
264
batch_size = x_A1 .size (0 )
270
265
real_tensor .data .resize_ (batch_size ).fill_ (real_label )
271
266
fake_tensor .data .resize_ (batch_size ).fill_ (fake_label )
272
-
267
+ """
268
+ ####### Debugging
269
+ ipdb.set_trace()
270
+ x_A1.data.resize_(x_A1.shape).fill_(real_label)
271
+ x_B1.data.resize_(x_B1.shape).fill_(real_label)
272
+ x_A2.data.resize_(x_A2.shape).fill_(real_label)
273
+ x_B2.data.resize_(x_B2.shape).fill_(real_label)
274
+
275
+ self.G.load_state_dict(torch.load('G_0.pth'))
276
+ self.F.load_state_dict(torch.load('F_0.pth'))
277
+
278
+ self.D_S.load_state_dict(torch.load('D_A2_0.pth'))
279
+ self.D_H.load_state_dict(torch.load('D_A1_0.pth'))
280
+
281
+ self.D_B1.load_state_dict(torch.load('D_B1_0.pth'))
282
+ self.D_B2.load_state_dict(torch.load('D_B2_0.pth'))
283
+ """
273
284
## Update Db network
274
285
#self.D_B1.zero_grad()
275
286
#self.D_B2.zero_grad()
@@ -478,11 +489,11 @@ def train(self):
478
489
step)
479
490
writer.add_scalars('loss_F', {'l_const_A':l_const_A, 'l_const_AB': l_const_AB}, step)
480
491
#'l_const_B':l_const_B,'l_const_AB':l_const_AB,'l_const_BA':l_const_BA}, step)
481
- writer.add_scalars('loss_D', {'l_d_A':l_d_A,'l_d_B':l_d_B}, step)
482
- writer.add_image('x_A1', x_A1, step)
483
- writer.add_image('x_B1', x_B1, step)
484
- writer.add_image('x_A2', x_A2, step)
485
- writer.add_image('x_B2', x_B2, step)'''
492
+ writer.add_scalars('loss_D', {'l_d_A':l_d_A,'l_d_B':l_d_B}, step)'''
493
+ # writer.add_image('x_A1', x_A1[:16] , step)
494
+ # writer.add_image('x_B1', x_B1[:16] , step)
495
+ # writer.add_image('x_A2', x_A2[:16] , step)
496
+ # writer.add_image('x_B2', x_B2[:16] , step)
486
497
# Discriminator loss
487
498
writer .add_scalar ('L_d_A' , l_dA , step )
488
499
writer .add_scalar ('L_d_B' , l_dB , step )
@@ -498,8 +509,11 @@ def train(self):
498
509
torch .save (self .G .state_dict (), '{}/G_{}.pth' .format (self .model_dir , step ))
499
510
torch .save (self .F .state_dict (), '{}/F_{}.pth' .format (self .model_dir , step ))
500
511
501
- torch .save (self .D_S .state_dict (), '{}/D_A_{}.pth' .format (self .model_dir , step ))
502
- torch .save (self .D_H .state_dict (), '{}/D_B_{}.pth' .format (self .model_dir , step ))
512
+ torch .save (self .D_H .state_dict (), '{}/D_A1_{}.pth' .format (self .model_dir , step ))
513
+ torch .save (self .D_S .state_dict (), '{}/D_A2_{}.pth' .format (self .model_dir , step ))
514
+
515
+ torch .save (self .D_B1 .state_dict (), '{}/D_B1_{}.pth' .format (self .model_dir , step ))
516
+ torch .save (self .D_B2 .state_dict (), '{}/D_B2_{}.pth' .format (self .model_dir , step ))
503
517
504
518
def generate_with_A (self , inputs , input_ref , path , idx = None , tf_board = True ):
505
519
x_AB = self .F (inputs )
@@ -524,7 +538,7 @@ def generate_with_A(self, inputs, input_ref, path, idx=None, tf_board=True):
524
538
writer .add_image ('x_A1valid' , inputs [:16 ], idx )
525
539
writer .add_image ('x_A1G' , x_ABA [:16 ], idx )
526
540
#writer.add_image('x_B1_2f', x_ABAf[:16], idx)
527
- writer .add_image ('x_A1_rec ' , x_ABAB [:16 ], idx )
541
+ writer .add_image ('x_A1rec ' , x_ABAB [:16 ], idx )
528
542
#writer.add_image('x_ABA', x_ABA, idx)
529
543
#vutils.save_image(x_ABA.data, x_ABA_path)
530
544
#print("[*] Samples saved: {}".format(x_ABA_path))
@@ -552,7 +566,7 @@ def generate_with_B(self, inputs, input_ref, path, idx=None, tf_board=True):
552
566
writer .add_image ('x_A2valid' , inputs [:16 ], idx )
553
567
writer .add_image ('x_A2G' , x_BAB [:16 ], idx )
554
568
#writer.add_image('x_B1_1f', x_ABAf[:16], idx)
555
- writer .add_image ('x_A2_rec ' , x_ABAB [:16 ], idx )
569
+ writer .add_image ('x_A2rec ' , x_ABAB [:16 ], idx )
556
570
#writer.add_image('x_BAB', x_BAB, idx)
557
571
#vutils.save_image(x_BAB.data, x_BAB_path)
558
572
#print("[*] Samples saved: {}".format(x_BAB_path))
@@ -573,9 +587,9 @@ def generate_infinitely(self, inputs, path, input_type, count=10, nrow=2, idx=No
573
587
574
588
def test (self ):
575
589
batch_size = self .config .sample_per_image
576
- x_A1 , x_B1 = torch .Tensor (np .load ('.. /valid_x_A1.npy' )), torch .Tensor (np .load ('. ./valid_x_B1.npy' ))
590
+ x_A1 , x_B1 = torch .Tensor (np .load ('./valid_x_A1.npy' )), torch .Tensor (np .load ('./valid_x_B1.npy' ))
577
591
x_A1 , x_B1 = self ._get_variable (x_A1 ), self ._get_variable (x_B1 )
578
- x_A2 , x_B2 = torch .Tensor (np .load ('.. /valid_x_A2.npy' )), torch .Tensor (np .load ('. ./valid_x_B2.npy' ))
592
+ x_A2 , x_B2 = torch .Tensor (np .load ('./valid_x_A2.npy' )), torch .Tensor (np .load ('./valid_x_B2.npy' ))
579
593
x_A2 , x_B2 = self ._get_variable (x_A2 ), self ._get_variable (x_B2 )
580
594
581
595
test_dir = os .path .join (self .model_dir , 'test' )
0 commit comments