@@ -32,13 +32,13 @@ def weights_init(m):
32
32
m .bias .data .fill_ (0 )
33
33
34
34
class Trainer (object ):
35
- def __init__ (self , config , a_data_loader , b_data_loader , a1_data_loader , b1_data_loader ):
35
+ def __init__ (self , config , a_data_loader , a1_data_loader ):
36
36
self .config = config
37
37
38
38
self .a_data_loader = a_data_loader
39
- self .b_data_loader = b_data_loader
39
+ # self.b_data_loader = b_data_loader
40
40
self .a1_data_loader = a1_data_loader
41
- self .b1_data_loader = b1_data_loader
41
+ # self.b1_data_loader = b1_data_loader
42
42
43
43
self .num_gpu = config .num_gpu
44
44
self .dataset = config .dataset
@@ -90,7 +90,7 @@ def build_model(self):
90
90
self .D_B = DiscriminatorFC (2 , 1 , [config .fc_hidden_dim ] * config .d_num_layer )
91
91
else :
92
92
a_height , a_width , a_channel = self .a_data_loader .shape
93
- b_height , b_width , b_channel = self .b_data_loader .shape
93
+ b_height , b_width , b_channel = self .a1_data_loader .shape
94
94
95
95
if self .cnn_type == 0 :
96
96
#conv_dims, deconv_dims = [64, 128, 256, 512], [512, 256, 128, 64]
@@ -253,6 +253,10 @@ def train(self):
253
253
254
254
x_A1 , x_B1 = x_A1 ['image' ], x_A1 ['edges' ]
255
255
x_A2 , x_B2 = x_A2 ['image' ], x_A2 ['edges' ]
256
+ writer .add_image ('x_A1' , x_A1 [:16 ],step )
257
+ writer .add_image ('x_A2' , x_A2 [:16 ],step )
258
+ writer .add_image ('x_B1' , x_B1 [:16 ],step )
259
+ writer .add_image ('x_B2' , x_B2 [:16 ],step )
256
260
if x_A1 .size (0 ) != x_B1 .size (0 ) or x_A2 .size (0 ) != x_B2 .size (0 ) or x_A1 .size (0 ) != x_A2 .size (0 ):
257
261
print ("[!] Sampled dataset from A and B have different # of data. Try resampling..." )
258
262
continue
0 commit comments