 import ipdb
 from tensorboardX import SummaryWriter
 from datetime import datetime
-
+import torchvision
 tmp = datetime.now()

 writer = SummaryWriter('../runs/' + str(tmp))
@@ -194,13 +194,35 @@ def train(self):
             self.G.parameters(),
             lr=self.lr, betas=(self.beta1, self.beta2))

-        A_loader, B_loader = iter(self.a_data_loader), iter(self.b_data_loader)
-        valid_x_A, valid_x_B = torch.Tensor(np.load('../valid_x_A1.npy')), torch.Tensor(np.load('../valid_x_B1.npy'))
-        valid_x_A, valid_x_B = self._get_variable(valid_x_A), self._get_variable(valid_x_B)
+        # A_loader, B_loader = iter(self.a_data_loader), iter(self.b_data_loader)
+        A_loader = iter(self.a_data_loader)
+        # valid_x_A, valid_x_B = torch.Tensor(np.load('../valid_x_A2.npy')), torch.Tensor(np.load('../valid_x_B2.npy'))
         #self._get_variable(A_loader.next()), self._get_variable(B_loader.next())
-        A1_loader, B1_loader = iter(self.a1_data_loader), iter(self.b1_data_loader)
-        valid_x_A1, valid_x_B1 = torch.Tensor(np.load('../valid_x_A2.npy')), torch.Tensor(np.load('../valid_x_B2.npy'))
-        valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
+        # A1_loader, B1_loader = iter(self.a1_data_loader), iter(self.b1_data_loader)
+        A1_loader = iter(self.a1_data_loader)
+        # valid_x_A1, valid_x_B1 = torch.Tensor(np.load('../valid_x_A2_chair.npy')), torch.Tensor(np.load('../valid_x_B2_chair.npy'))
+        try:
+            valid_x_A, valid_x_B = torch.Tensor(np.load(self.config.dataset_A1 + '_A.npy')), torch.Tensor(np.load(self.config.dataset_A1 + '_B.npy'))
+            valid_x_A, valid_x_B = self._get_variable(valid_x_A), self._get_variable(valid_x_B)
+            valid_x_A1, valid_x_B1 = torch.Tensor(np.load(self.config.dataset_A2 + '_A.npy')), torch.Tensor(np.load(self.config.dataset_A2 + '_B.npy'))
+            valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
+        except:
+            print('Cannot load validation file. Creating new validation file')
+            x_A1 = A_loader.next()
+            x_A2 = A1_loader.next()
+            x_A1, x_B1 = x_A1['image'], x_A1['edges']
+            x_A2, x_B2 = x_A2['image'], x_A2['edges']
+            np.save(self.config.dataset_A1 + '_A.npy', np.array(x_A1))
+            np.save(self.config.dataset_A1 + '_B.npy', np.array(x_B1))
+            np.save(self.config.dataset_A2 + '_A.npy', np.array(x_A2))
+            np.save(self.config.dataset_A2 + '_B.npy', np.array(x_B2))
+
+            valid_x_A, valid_x_B = torch.Tensor(np.load(self.config.dataset_A1 + '_A.npy')), torch.Tensor(np.load(self.config.dataset_A1 + '_B.npy'))
+            valid_x_A, valid_x_B = self._get_variable(valid_x_A), self._get_variable(valid_x_B)
+            valid_x_A1, valid_x_B1 = torch.Tensor(np.load(self.config.dataset_A2 + '_A.npy')), torch.Tensor(np.load(self.config.dataset_A2 + '_B.npy'))
+            valid_x_A1, valid_x_B1 = self._get_variable(valid_x_A1), self._get_variable(valid_x_B1)
+
+
         #self._get_variable(A1_loader.next()), self._get_variable(B1_loader.next())
         #ipdb.set_trace()

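
The block above loads cached validation batches from .npy files and, when loading fails, draws one batch per dataset, saves it, and reloads it. A minimal standalone sketch of this load-or-create pattern, assuming the loader yields dicts with 'image' and 'edges' tensors; the helper name `load_or_create_valid_batch` and its `prefix` argument are illustrative, not from the commit:

    import numpy as np
    import torch

    def load_or_create_valid_batch(loader_iter, prefix):
        """Return a cached (image, edges) validation batch, writing the cache on first use."""
        try:
            x = torch.Tensor(np.load(prefix + '_A.npy'))
            y = torch.Tensor(np.load(prefix + '_B.npy'))
        except IOError:
            batch = next(loader_iter)                  # draw one batch from the DataLoader
            x, y = batch['image'], batch['edges']      # same dict keys the trainer unpacks
            np.save(prefix + '_A.npy', np.asarray(x))  # persist so later runs reuse the batch
            np.save(prefix + '_B.npy', np.asarray(y))
        return x, y
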
@@ -212,20 +234,23 @@ def train(self):
         for step in trange(self.start_step, self.max_step):
             try:
                 x_A1 = A_loader.next()
-                x_B1 = B_loader.next()
+                # x_B1 = B_loader.next()
             except StopIteration:
                 A_loader = iter(self.a_data_loader)
-                B_loader = iter(self.b_data_loader)
+                # B_loader = iter(self.b_data_loader)
                 x_A1 = A_loader.next()
-                x_B1 = B_loader.next()
+                # x_B1 = B_loader.next()
             try:
                 x_A2 = A1_loader.next()
-                x_B2 = B1_loader.next()
+                # x_B2 = B1_loader.next()
             except StopIteration:
                 A1_loader = iter(self.a1_data_loader)
-                B1_loader = iter(self.b1_data_loader)
+                # B1_loader = iter(self.b1_data_loader)
                 x_A2 = A1_loader.next()
-                x_B2 = B1_loader.next()
+                # x_B2 = B1_loader.next()
+
+            x_A1, x_B1 = x_A1['image'], x_A1['edges']
+            x_A2, x_B2 = x_A2['image'], x_A2['edges']
             if x_A1.size(0) != x_B1.size(0) or x_A2.size(0) != x_B2.size(0) or x_A1.size(0) != x_A2.size(0):
                 print("[!] Sampled dataset from A and B have different # of data. Try resampling...")
                 continue
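
The try/except StopIteration handling above restarts each loader when it is exhausted so training can run for more steps than one pass over the data. A minimal sketch of the same idea as a reusable generator; the name `cycle_batches` is illustrative, and batches are assumed to be dicts keyed by 'image' and 'edges':

    def cycle_batches(data_loader):
        """Yield batches forever, rebuilding the iterator whenever the loader is exhausted."""
        it = iter(data_loader)
        while True:
            try:
                yield next(it)
            except StopIteration:
                it = iter(data_loader)  # start a fresh pass over the dataset
                yield next(it)

    # illustrative usage inside the training loop:
    # a_batches = cycle_batches(self.a_data_loader)
    # batch = next(a_batches); x_A1, x_B1 = batch['image'], batch['edges']
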
@@ -422,17 +447,17 @@ def train(self):

                 self.generate_with_A(valid_x_A, valid_x_A1, self.model_dir, idx=step)
                 self.generate_with_B(valid_x_A1, valid_x_A, self.model_dir, idx=step)
-                writer.add_scalars('loss_G', {'l_g': l_g, 'l_gan_A': l_gan_A, 'l_const_A': l_const_A,
-                                              'l_f': l_f, 'l_const_AB': l_const_AB},
+                writer.add_scalars('loss_G', {'l_g': l_g, 'l_gan_A': l_gan_A,
+                                              'l_f': l_f},
                                    step)
+                writer.add_scalars('loss_F', {'l_const_A': l_const_A, 'l_const_AB': l_const_AB}, step)
                 #'l_const_B':l_const_B,'l_const_AB':l_const_AB,'l_const_BA':l_const_BA}, step)
                 writer.add_scalars('loss_D', {'l_d_A': l_d_A, 'l_d_B': l_d_B}, step)
-
-            if step % self.save_step == self.save_step - 1:
+            if step % self.save_step == 0:
                 print("[*] Save models to {}...".format(self.model_dir))

-                torch.save(self.G.state_dict(), '{}/G_AB_{}.pth'.format(self.model_dir, step))
-                torch.save(self.F.state_dict(), '{}/G_BA_{}.pth'.format(self.model_dir, step))
+                torch.save(self.G.state_dict(), '{}/G_{}.pth'.format(self.model_dir, step))
+                torch.save(self.F.state_dict(), '{}/F_{}.pth'.format(self.model_dir, step))

                 torch.save(self.D_S.state_dict(), '{}/D_A_{}.pth'.format(self.model_dir, step))
                 torch.save(self.D_H.state_dict(), '{}/D_B_{}.pth'.format(self.model_dir, step))
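
The save block above writes checkpoints named G_<step>.pth, F_<step>.pth, D_A_<step>.pth and D_B_<step>.pth every save_step steps. A minimal sketch of restoring them, assuming the same attribute and file names; the function name `load_checkpoints` and the `map_location` default are illustrative:

    import torch

    def load_checkpoints(G, F, D_S, D_H, model_dir, step, map_location='cpu'):
        """Restore generator/discriminator weights from the files written above."""
        G.load_state_dict(torch.load('{}/G_{}.pth'.format(model_dir, step), map_location=map_location))
        F.load_state_dict(torch.load('{}/F_{}.pth'.format(model_dir, step), map_location=map_location))
        D_S.load_state_dict(torch.load('{}/D_A_{}.pth'.format(model_dir, step), map_location=map_location))
        D_H.load_state_dict(torch.load('{}/D_B_{}.pth'.format(model_dir, step), map_location=map_location))
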
@@ -447,6 +472,13 @@ def generate_with_A(self, inputs, input_ref, path, idx=None, tf_board=True):
         #x_ABA_path = '{}/{}_x_ABA.png'.format(path, idx)

         vutils.save_image(x_ABA.data, x_AB_path)
+        if not os.path.isdir('{}/{}_A1'.format(path, idx)):
+            os.makedirs('{}/{}_A1'.format(path, idx))
+        for i in range(x_ABA.size(0)):
+            tmp = x_ABA[i].detach().cpu()
+            tmp = torchvision.transforms.ToPILImage()(tmp)
+            tmp.save('{}/{}_A1/{}.png'.format(path, idx, i))
+
         print("[*] Samples saved: {}".format(x_AB_path))
         if tf_board:
             writer.add_image('x_A1f', x_AB[:16], idx)
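
The added loop above exports each generated sample as its own PNG via torchvision's ToPILImage, alongside the existing grid image written by vutils.save_image. A minimal sketch of that per-image export, assuming a float tensor batch shaped (N, C, H, W) with values in [0, 1]; the helper name `save_batch_as_pngs` is illustrative:

    import os
    import torchvision

    def save_batch_as_pngs(batch, out_dir):
        """Write every image of an (N, C, H, W) tensor batch as <out_dir>/<index>.png."""
        if not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        to_pil = torchvision.transforms.ToPILImage()
        for i in range(batch.size(0)):
            img = to_pil(batch[i].detach().cpu())  # CHW tensor -> PIL.Image
            img.save(os.path.join(out_dir, '{}.png'.format(i)))
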
@@ -468,6 +500,13 @@ def generate_with_B(self, inputs, input_ref, path, idx=None, tf_board=True):
         #x_BAB_path = '{}/{}_x_BAB.png'.format(path, idx)

         vutils.save_image(x_BAB.data, x_BA_path)
+        if not os.path.isdir('{}/{}_A2'.format(path, idx)):
+            os.makedirs('{}/{}_A2'.format(path, idx))
+        for i in range(x_BAB.size(0)):
+            tmp = x_BAB[i].detach().cpu()
+            tmp = torchvision.transforms.ToPILImage()(tmp)
+            tmp.save('{}/{}_A2/{}.png'.format(path, idx, i))
+
         print("[*] Samples saved: {}".format(x_BA_path))
         if tf_board:
             writer.add_image('x_A2f', x_BA[:16], idx)