Commit 435b85b

Author: kmbae, committed by kmbae
1 parent 8267e10, commit 435b85b

14 files changed, +307 -97 lines

__pycache__/config.cpython-35.pyc (2.55 KB): binary file not shown
(file name not shown in this view) (3.09 KB): binary file not shown
__pycache__/models.cpython-35.pyc (6.88 KB): binary file not shown
__pycache__/trainer.cpython-35.pyc (12.9 KB): binary file not shown
__pycache__/utils.cpython-35.pyc (1.64 KB): binary file not shown

config.py  (+3 -3)

@@ -34,7 +34,7 @@ def add_argument_group(name):
 train_arg.add_argument('--is_train', type=str2bool, default=True)
 train_arg.add_argument('--optimizer', type=str, default='adam')
 train_arg.add_argument('--max_step', type=int, default=500000)
-train_arg.add_argument('--lr', type=float, default=0.0002)
+train_arg.add_argument('--lr', type=float, default=0.0001)
 train_arg.add_argument('--beta1', type=float, default=0.5)
 train_arg.add_argument('--beta2', type=float, default=0.999)
 train_arg.add_argument('--loss', type=str, default="log_prob",
@@ -44,8 +44,8 @@ def add_argument_group(name):
 # Misc
 misc_arg = add_argument_group('Misc')
 misc_arg.add_argument('--load_path', type=str, default='')
-misc_arg.add_argument('--log_step', type=int, default=50)
-misc_arg.add_argument('--save_step', type=int, default=500)
+misc_arg.add_argument('--log_step', type=int, default=100)
+misc_arg.add_argument('--save_step', type=int, default=1000)
 misc_arg.add_argument('--num_log_samples', type=int, default=3)
 misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
 misc_arg.add_argument('--log_dir', type=str, default='../logs')
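
Note: a minimal, self-contained sketch (not part of the commit) showing the changed defaults in effect. The parser setup below is a simplified stand-in for config.py, which defines more argument groups and a str2bool helper; only the three touched flags are reproduced here.

import argparse

parser = argparse.ArgumentParser()

def add_argument_group(name):
    # simplified stand-in for the helper named in the hunk header
    return parser.add_argument_group(name)

train_arg = add_argument_group('Training')
train_arg.add_argument('--lr', type=float, default=0.0001)      # was 0.0002

misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--log_step', type=int, default=100)      # was 50
misc_arg.add_argument('--save_step', type=int, default=1000)    # was 500

config, unparsed = parser.parse_known_args([])
print(config.lr, config.log_step, config.save_step)             # 0.0001 100 1000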

config.pyc (-9 Bytes): binary file not shown

data_loader.py  (+4 -1)

@@ -1,7 +1,7 @@
 import os
 import numpy as np
 from glob import glob
-from PIL import Image
+from PIL import Image, ImageFilter
 from tqdm import tqdm

 import torch
@@ -46,6 +46,7 @@ def pix2pix_split_images(root):
 class Dataset(torch.utils.data.Dataset):
     def __init__(self, root, scale_size, data_type, skip_pix2pix_processing=False):
         self.root = root
+        self.data_type = data_type
         if not os.path.exists(self.root):
             raise Exception("[!] {} not exists.".format(root))

@@ -66,6 +67,8 @@ def __init__(self, root, scale_size, data_type, skip_pix2pix_processing=False):

     def __getitem__(self, index):
         image = Image.open(self.paths[index]).convert('RGB')
+        #if self.data_type=='B':
+        #    image = image.filter(ImageFilter.MinFilter(3))
         return self.transform(image)

     def __len__(self):
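
Note: the two commented-out lines would pre-filter domain-B images with Pillow's 3x3 minimum filter before the usual transform. A standalone sketch of that filter (not part of the commit), with 'example.jpg' as a placeholder path:

from PIL import Image, ImageFilter

image = Image.open('example.jpg').convert('RGB')
# MinFilter(3) replaces each pixel with the darkest value in its 3x3
# neighbourhood, which thickens dark strokes in edge-style images.
filtered = image.filter(ImageFilter.MinFilter(3))
filtered.save('example_minfilter.jpg')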

data_loader.pyc (43 Bytes): binary file not shown

main.py  (+5 -1)

@@ -26,7 +26,11 @@ def main(config):
         data_path, batch_size, config.input_scale_size,
         config.num_worker, config.skip_pix2pix_processing)

-    trainer = Trainer(config, a_data_loader, b_data_loader)
+    a1_data_loader, b1_data_loader = get_loader(
+        '../data/edges2handbags', batch_size, config.input_scale_size,
+        config.num_worker, config.skip_pix2pix_processing)
+
+    trainer = Trainer(config, a_data_loader, b_data_loader, a1_data_loader, b1_data_loader)

     if config.is_train:
         save_config(config)
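
Note: main.py now builds a second loader pair (hard-coded to ../data/edges2handbags) and passes four loaders to Trainer, so trainer.py must accept them. trainer.py is not shown in this view; the constructor below is only an assumed sketch of a compatible signature, not the actual change.

class Trainer(object):
    def __init__(self, config, a_data_loader, b_data_loader,
                 a1_data_loader, b1_data_loader):
        self.config = config
        # first domain pair, selected via config/data_path as before
        self.a_data_loader = a_data_loader
        self.b_data_loader = b_data_loader
        # second domain pair, loaded from ../data/edges2handbags in main.py
        self.a1_data_loader = a1_data_loader
        self.b1_data_loader = b1_data_loader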

models.py  (+77 -4)

@@ -31,6 +31,45 @@ def __init__(self, input_channel, output_channel, conv_dims, deconv_dims, num_gpu):

         self.layer_module = nn.ModuleList(self.layers)

+    def main(self, x, y=None):
+        if not y==None:
+            out = torch.cat([x, y], dim=1)
+        else:
+            out = x
+        for layer in self.layer_module:
+            out = layer(out)
+        return out
+
+    def forward(self, x, y=None):
+        return self.main(x, y)
+
+class GeneratorCNN_g(nn.Module):
+    def __init__(self, input_channel, output_channel, conv_dims, deconv_dims, num_gpu):
+        super(GeneratorCNN_g, self).__init__()
+        self.num_gpu = num_gpu
+        self.layers = []
+
+        prev_dim = conv_dims[0]
+        self.layers.append(nn.Conv2d(input_channel, prev_dim, 4, 2, 1, bias=False))
+        self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+
+        for out_dim in conv_dims[1:]:
+            self.layers.append(nn.Conv2d(prev_dim, out_dim, 4, 2, 1, bias=False))
+            self.layers.append(nn.BatchNorm2d(out_dim))
+            self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+            prev_dim = out_dim
+
+        for out_dim in deconv_dims:
+            self.layers.append(nn.ConvTranspose2d(prev_dim, out_dim, 4, 2, 1, bias=False))
+            self.layers.append(nn.BatchNorm2d(out_dim))
+            self.layers.append(nn.ReLU(True))
+            prev_dim = out_dim
+
+        self.layers.append(nn.ConvTranspose2d(prev_dim, output_channel, 4, 2, 1, bias=False))
+        self.layers.append(nn.Sigmoid())#nn.Tanh())
+
+        self.layer_module = nn.ModuleList(self.layers)
+
     def main(self, x, y):
         out = torch.cat([x, y], dim=1)
         for layer in self.layer_module:
@@ -61,14 +100,48 @@ def __init__(self, input_channel, output_channel, hidden_dims, num_gpu):

         self.layer_module = nn.ModuleList(self.layers)

-    def main(self, x):
-        out = x
+    def main(self, x, y=None):
+        if not y==None:
+            out = torch.cat([x, y], dim=1)
+        else:
+            out = x
         for layer in self.layer_module:
             out = layer(out)
         return out.view(out.size(0), -1)

-    def forward(self, x):
-        return self.main(x)
+    def forward(self, x, y=None):
+        return self.main(x,y)
+
+
+class DiscriminatorCNN_f(nn.Module):
+    def __init__(self, input_channel, output_channel, hidden_dims, num_gpu):
+        super(DiscriminatorCNN_f, self).__init__()
+        self.num_gpu = num_gpu
+        self.layers = []
+
+        prev_dim = hidden_dims[0]
+        self.layers.append(nn.Conv2d(input_channel, prev_dim, 4, 2, 1, bias=False))
+        self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+
+        for out_dim in hidden_dims[1:]:
+            self.layers.append(nn.Conv2d(prev_dim, out_dim, 4, 2, 1, bias=False))
+            self.layers.append(nn.BatchNorm2d(out_dim))
+            self.layers.append(nn.LeakyReLU(0.2, inplace=True))
+            prev_dim = out_dim
+
+        self.layers.append(nn.Conv2d(prev_dim, output_channel, 4, 1, 0, bias=False))
+        self.layers.append(nn.Sigmoid())
+
+        self.layer_module = nn.ModuleList(self.layers)
+
+    def main(self, x, y):
+        out = torch.cat([x, y], dim=1)
+        for layer in self.layer_module:
+            out = layer(out)
+        return out.view(out.size(0), -1)
+
+    def forward(self, x, y):
+        return self.main(x,y)

 class GeneratorFC(nn.Module):
     def __init__(self, input_size, output_size, hidden_dims):
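
Note: a quick shape-check sketch (not part of the commit) for the two modules added above. The batch size, 64x64 input resolution, and layer widths are illustrative assumptions; GeneratorCNN_g's forward() lies outside the shown context, so main() is called directly.

import torch
from models import GeneratorCNN_g, DiscriminatorCNN_f

x = torch.rand(4, 3, 64, 64)   # assumed domain-A batch of 64x64 RGB images
y = torch.rand(4, 3, 64, 64)   # assumed domain-B batch

# Both new modules concatenate their two inputs along the channel axis,
# so input_channel must equal the combined channel count (3 + 3 = 6).
G_g = GeneratorCNN_g(6, 3, conv_dims=[64, 128, 256, 512],
                     deconv_dims=[256, 128, 64], num_gpu=1)
fake = G_g.main(x, y)
print(fake.size())             # torch.Size([4, 3, 64, 64]); Sigmoid keeps values in (0, 1)

# Four stride-2 convs reduce 64x64 to 4x4; the final 4x4 conv with stride 1
# and no padding collapses that to one score per sample.
D_f = DiscriminatorCNN_f(6, 1, hidden_dims=[64, 128, 256, 512], num_gpu=1)
score = D_f(x, fake)
print(score.size())            # torch.Size([4, 1])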

models.pyc (1.45 KB): binary file not shown
