@@ -31,6 +31,45 @@ def __init__(self, input_channel, output_channel, conv_dims, deconv_dims, num_gp

        self.layer_module = nn.ModuleList(self.layers)

+    def main(self, x, y=None):
+        if y is not None:
+            out = torch.cat([x, y], dim=1)
+        else:
+            out = x
+        for layer in self.layer_module:
+            out = layer(out)
+        return out
+
+    def forward(self, x, y=None):
+        return self.main(x, y)
+
+class GeneratorCNN_g(nn.Module):
+    def __init__(self, input_channel, output_channel, conv_dims, deconv_dims, num_gpu):
+        super(GeneratorCNN_g, self).__init__()
+        self.num_gpu = num_gpu
+        self.layers = []
+
+        prev_dim = conv_dims[0]
+        self.layers.append(nn.Conv2d(input_channel, prev_dim, 4, 2, 1, bias=False))
+        self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+
+        for out_dim in conv_dims[1:]:
+            self.layers.append(nn.Conv2d(prev_dim, out_dim, 4, 2, 1, bias=False))
+            self.layers.append(nn.BatchNorm2d(out_dim))
+            self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+            prev_dim = out_dim
+
+        for out_dim in deconv_dims:
+            self.layers.append(nn.ConvTranspose2d(prev_dim, out_dim, 4, 2, 1, bias=False))
+            self.layers.append(nn.BatchNorm2d(out_dim))
+            self.layers.append(nn.ReLU(True))
+            prev_dim = out_dim
+
+        self.layers.append(nn.ConvTranspose2d(prev_dim, output_channel, 4, 2, 1, bias=False))
+        self.layers.append(nn.Sigmoid())  # nn.Tanh()
+
+        self.layer_module = nn.ModuleList(self.layers)
+
    def main(self, x, y):
        out = torch.cat([x, y], dim=1)
        for layer in self.layer_module:
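As a sanity check on the conditional generator added in this hunk, the sketch below builds GeneratorCNN_g and runs one batch through main(). The concrete sizes (6 input channels from concatenating two 3-channel tensors, conv_dims=[64, 128], deconv_dims=[64], 64x64 inputs) are illustrative assumptions, not values from this commit:

import torch

from models import GeneratorCNN_g  # hypothetical module path for this file

G = GeneratorCNN_g(input_channel=6, output_channel=3,
                   conv_dims=[64, 128], deconv_dims=[64], num_gpu=0)

x = torch.rand(8, 3, 64, 64)  # batch of source images
y = torch.rand(8, 3, 64, 64)  # same-sized condition tensor

# main() concatenates x and y along the channel axis (3 + 3 = 6 channels,
# matching input_channel), then runs the conv/deconv stack; the final
# Sigmoid keeps outputs in [0, 1].
out = G.main(x, y)
print(out.shape)  # torch.Size([8, 3, 64, 64])

Each 4/2/1 convolution halves the spatial size and each 4/2/1 transposed convolution doubles it, so the two strided convs take 64x64 down to 16x16 and the deconv plus the output layer bring it back to 64x64.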
@@ -61,14 +100,48 @@ def __init__(self, input_channel, output_channel, hidden_dims, num_gpu):

        self.layer_module = nn.ModuleList(self.layers)

-    def main(self, x):
-        out = x
+    def main(self, x, y=None):
+        if y is not None:
+            out = torch.cat([x, y], dim=1)
+        else:
+            out = x
        for layer in self.layer_module:
            out = layer(out)
        return out.view(out.size(0), -1)

-    def forward(self, x):
-        return self.main(x)
+    def forward(self, x, y=None):
+        return self.main(x, y)
+
+
+class DiscriminatorCNN_f(nn.Module):
+    def __init__(self, input_channel, output_channel, hidden_dims, num_gpu):
+        super(DiscriminatorCNN_f, self).__init__()
+        self.num_gpu = num_gpu
+        self.layers = []
+
+        prev_dim = hidden_dims[0]
+        self.layers.append(nn.Conv2d(input_channel, prev_dim, 4, 2, 1, bias=False))
+        self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+
+        for out_dim in hidden_dims[1:]:
+            self.layers.append(nn.Conv2d(prev_dim, out_dim, 4, 2, 1, bias=False))
+            self.layers.append(nn.BatchNorm2d(out_dim))
+            self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+            prev_dim = out_dim
+
+        self.layers.append(nn.Conv2d(prev_dim, output_channel, 4, 1, 0, bias=False))
+        self.layers.append(nn.Sigmoid())
+
+        self.layer_module = nn.ModuleList(self.layers)
+
+    def main(self, x, y):
+        out = torch.cat([x, y], dim=1)
+        for layer in self.layer_module:
+            out = layer(out)
+        return out.view(out.size(0), -1)
+
+    def forward(self, x, y):
+        return self.main(x, y)

class GeneratorFC(nn.Module):
    def __init__(self, input_size, output_size, hidden_dims):
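A matching sketch for the paired discriminator: DiscriminatorCNN_f concatenates its two inputs channel-wise and emits one probability per sample. The hidden_dims and the 64x64 input size below are assumptions chosen so the final 4x4, stride-1, no-padding convolution collapses the feature map to 1x1:

import torch

from models import DiscriminatorCNN_f  # hypothetical module path for this file

D = DiscriminatorCNN_f(input_channel=6, output_channel=1,
                       hidden_dims=[64, 128, 256, 512], num_gpu=0)

x = torch.rand(8, 3, 64, 64)
y = torch.rand(8, 3, 64, 64)

# The strided convs shrink 64 -> 32 -> 16 -> 8 -> 4; the last conv reduces
# 4x4 to 1x1, and view() flattens the Sigmoid output to one score per pair.
p = D(x, y)
print(p.shape)  # torch.Size([8, 1])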