Skip to content

Commit a8e7d53

Browse files
author
Dol+I
committed
hello
1 parent 39cbe25 commit a8e7d53

File tree

7 files changed

+701
-0
lines changed

7 files changed

+701
-0
lines changed

config.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#-*- coding: utf-8 -*-
2+
import argparse
3+
4+
def str2bool(v):
    """Interpret a command-line flag string as a boolean ('true'/'1' => True)."""
    normalized = v.lower()
    return normalized == 'true' or normalized == '1'
6+
7+
# Shared parser state: every group created below is also tracked in arg_lists.
arg_lists = []
parser = argparse.ArgumentParser()


def add_argument_group(name):
    """Create a named argument group on the module parser and record it."""
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group
14+
15+
# Network architecture options.
network_arg = add_argument_group('Network')
network_arg.add_argument('--input_scale_size', type=int, default=64,
                         help='input image will be resized with the given value as width and height')
network_arg.add_argument('--g_num_layer', type=int, default=3)
network_arg.add_argument('--d_num_layer', type=int, default=5)
network_arg.add_argument('--cnn_type', type=int, default=0)
network_arg.add_argument('--fc_hidden_dim', type=int, default=128, help='only for toy dataset')

# Dataset / loading options.
dataset_arg = add_argument_group('Data')
dataset_arg.add_argument('--dataset', type=str, default='edges2shoes')
dataset_arg.add_argument('--batch_size', type=int, default=200)
dataset_arg.add_argument('--a_grayscale', type=str2bool, default=False)
dataset_arg.add_argument('--b_grayscale', type=str2bool, default=False)
dataset_arg.add_argument('--num_worker', type=int, default=12)

# Optimization / training-loop options (also used at test time).
training_arg = add_argument_group('Training')
training_arg.add_argument('--is_train', type=str2bool, default=True)
training_arg.add_argument('--optimizer', type=str, default='adam')
training_arg.add_argument('--max_step', type=int, default=500000)
training_arg.add_argument('--lr', type=float, default=0.0002)
training_arg.add_argument('--beta1', type=float, default=0.5)
training_arg.add_argument('--beta2', type=float, default=0.999)
training_arg.add_argument('--loss', type=str, default="log_prob",
                          choices=["log_prob"], help="least square loss doesn't work well")
training_arg.add_argument('--weight_decay', type=float, default=0.0001)

# Logging, checkpointing, and environment options.
misc_group = add_argument_group('Misc')
misc_group.add_argument('--load_path', type=str, default='')
misc_group.add_argument('--log_step', type=int, default=50)
misc_group.add_argument('--save_step', type=int, default=500)
misc_group.add_argument('--num_log_samples', type=int, default=3)
misc_group.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
misc_group.add_argument('--log_dir', type=str, default='../logs')
misc_group.add_argument('--data_dir', type=str, default='../data')
misc_group.add_argument('--num_gpu', type=int, default=8)
misc_group.add_argument('--test_data_path', type=str, default=None,
                        help='directory with images which will be used in test sample generation')
misc_group.add_argument('--sample_per_image', type=int, default=64,
                        help='# of sample per image during test sample generation')
misc_group.add_argument('--random_seed', type=int, default=123)
misc_group.add_argument('--skip_pix2pix_processing', type=str2bool, default=False,
                        help='just for fast debugging in poor cpu machine')
61+
62+
def get_config():
    """Parse known CLI arguments; return (config namespace, unparsed leftovers)."""
    parsed, leftover = parser.parse_known_args()
    return parsed, leftover

data_loader.py

+90
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import os
2+
import numpy as np
3+
from glob import glob
4+
from PIL import Image
5+
from tqdm import tqdm
6+
7+
import torch
8+
from torchvision import transforms
9+
10+
# Datasets distributed in the pix2pix side-by-side A|B image format.
PIX2PIX_DATASETS = [
    'facades', 'cityscapes', 'maps', 'edges2shoes', 'edges2handbags']


def makedirs(path):
    """Create *path* (including parents); no-op if it already exists.

    Uses exist_ok instead of the original check-then-create pattern,
    which was racy when several loader workers start simultaneously.
    """
    os.makedirs(path, exist_ok=True)
16+
17+
def pix2pix_split_images(root):
    """Split side-by-side A|B images in <root>/train into <root>/A and <root>/B.

    Each pix2pix training image is two pictures concatenated horizontally;
    the left half is written to A/, the right half to B/, keeping the
    original filename. Images whose halves already exist are skipped, so
    the function is safely resumable.
    """
    paths = glob(os.path.join(root, "train/*"))

    a_path = os.path.join(root, "A")
    b_path = os.path.join(root, "B")

    makedirs(a_path)
    makedirs(b_path)

    for path in tqdm(paths, desc="pix2pix processing"):
        filename = os.path.basename(path)

        a_image_path = os.path.join(a_path, filename)
        b_image_path = os.path.join(b_path, filename)

        if os.path.exists(a_image_path) and os.path.exists(b_image_path):
            continue

        image = Image.open(path).convert('RGB')
        data = np.array(image)

        height, width, channel = data.shape
        # Bug fix: `width/2` is float division in Python 3 and raises
        # TypeError when used as a slice index; use integer division.
        half = width // 2

        a_image = Image.fromarray(data[:, :half].astype(np.uint8))
        b_image = Image.fromarray(data[:, half:].astype(np.uint8))

        a_image.save(a_image_path)
        b_image.save(b_image_path)
45+
46+
class Dataset(torch.utils.data.Dataset):
    """Folder-of-images dataset: yields normalized RGB tensors from <root>/<data_type>/*."""

    def __init__(self, root, scale_size, data_type, skip_pix2pix_processing=False):
        self.root = root
        if not os.path.exists(self.root):
            raise Exception("[!] {} not exists.".format(root))

        # pix2pix archives ship as concatenated A|B pairs; split them once up front.
        self.name = os.path.basename(root)
        if self.name in PIX2PIX_DATASETS and not skip_pix2pix_processing:
            pix2pix_split_images(self.root)

        pattern = os.path.join(self.root, '{}/*'.format(data_type))
        self.paths = glob(pattern)
        if len(self.paths) == 0:
            raise Exception("No images are found in {}".format(self.root))
        # (width, height, 3) taken from the first image; assumes all images match.
        self.shape = list(Image.open(self.paths[0]).size) + [3]

        # NOTE(review): transforms.Scale is the pre-0.2 torchvision spelling of
        # Resize; on modern torchvision this name no longer exists — confirm
        # the pinned torchvision version before upgrading.
        self.transform = transforms.Compose([
            transforms.Scale(scale_size),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])

    def __getitem__(self, index):
        # Load lazily per item; convert to RGB so grayscale files get 3 channels.
        return self.transform(Image.open(self.paths[index]).convert('RGB'))

    def __len__(self):
        return len(self.paths)
73+
74+
def get_loader(root, batch_size, scale_size, num_workers=2,
               skip_pix2pix_processing=False, shuffle=True):
    """Build paired DataLoaders for the A and B image domains under *root*.

    Args:
        root: dataset directory containing A/ and B/ subfolders.
        batch_size: samples per batch for both loaders.
        scale_size: target image size passed to the Dataset transform.
        num_workers: number of loader worker processes.
        skip_pix2pix_processing: forwarded to Dataset to skip the A|B split.
        shuffle: whether batches are drawn in random order.

    Returns:
        (a_loader, b_loader); each loader carries its dataset's image shape
        as a `.shape` attribute for downstream model construction.
    """
    a_data_set = Dataset(root, scale_size, "A", skip_pix2pix_processing)
    b_data_set = Dataset(root, scale_size, "B", skip_pix2pix_processing)

    # Bug fix: the `shuffle` parameter was accepted but ignored — both
    # loaders hard-coded shuffle=True. Default True preserves old behavior.
    a_data_loader = torch.utils.data.DataLoader(dataset=a_data_set,
                                                batch_size=batch_size,
                                                shuffle=shuffle,
                                                num_workers=num_workers)
    b_data_loader = torch.utils.data.DataLoader(dataset=b_data_set,
                                                batch_size=batch_size,
                                                shuffle=shuffle,
                                                num_workers=num_workers)

    a_data_loader.shape = a_data_set.shape
    b_data_loader.shape = b_data_set.shape

    return a_data_loader, b_data_loader

main.py

+41
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import torch
2+
3+
from trainer import Trainer
4+
from config import get_config
5+
from data_loader import get_loader
6+
from utils import prepare_dirs_and_logger, save_config
7+
8+
def main(config):
    """Entry point: seed RNGs, build the A/B data loaders, then train or test."""
    prepare_dirs_and_logger(config)

    # Seed both CPU and (when GPUs are requested) CUDA RNGs for reproducibility.
    torch.manual_seed(config.random_seed)
    if config.num_gpu > 0:
        torch.cuda.manual_seed(config.random_seed)

    # NOTE(review): `config.data_path` is not declared in config.py; it is
    # presumably attached by prepare_dirs_and_logger — confirm.
    if config.is_train:
        data_path = config.data_path
        batch_size = config.batch_size
    else:
        # At test time, prefer an explicit test directory when one is given.
        if config.test_data_path is None:
            data_path = config.data_path
        else:
            data_path = config.test_data_path
        batch_size = config.sample_per_image

    a_data_loader, b_data_loader = get_loader(
        data_path, batch_size, config.input_scale_size,
        config.num_worker, config.skip_pix2pix_processing)

    trainer = Trainer(config, a_data_loader, b_data_loader)

    if not config.is_train:
        if not config.load_path:
            raise Exception("[!] You should specify `load_path` to load a pretrained model")
        trainer.test()
    else:
        save_config(config)
        trainer.train()


if __name__ == "__main__":
    config, unparsed = get_config()
    main(config)

models.py

+113
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
import torch
2+
from torch import nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
from torch.utils.data import TensorDataset, DataLoader
6+
7+
class GeneratorCNN(nn.Module):
    """Encoder-decoder generator: strided convs down, transposed convs up, Tanh output."""

    def __init__(self, input_channel, output_channel, conv_dims, deconv_dims, num_gpu):
        super(GeneratorCNN, self).__init__()
        self.num_gpu = num_gpu  # kept for interface compatibility; unused here
        self.layers = []

        # Stem: first downsampling conv without batch norm.
        prev_dim = conv_dims[0]
        self.layers.append(nn.Conv2d(input_channel, prev_dim, 4, 2, 1, bias=False))
        self.layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Remaining encoder stages: conv -> BN -> LeakyReLU, each halving H/W.
        for out_dim in conv_dims[1:]:
            self.layers.extend([
                nn.Conv2d(prev_dim, out_dim, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            prev_dim = out_dim

        # Decoder stages: deconv -> BN -> ReLU, each doubling H/W.
        for out_dim in deconv_dims:
            self.layers.extend([
                nn.ConvTranspose2d(prev_dim, out_dim, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU(True),
            ])
            prev_dim = out_dim

        # Output head squashed into [-1, 1].
        self.layers.append(nn.ConvTranspose2d(prev_dim, output_channel, 4, 2, 1, bias=False))
        self.layers.append(nn.Tanh())

        self.layer_module = nn.ModuleList(self.layers)

    def main(self, x):
        """Apply every registered layer in sequence."""
        h = x
        for layer in self.layer_module:
            h = layer(h)
        return h

    def forward(self, x):
        return self.main(x)
42+
43+
class DiscriminatorCNN(nn.Module):
    """Convolutional discriminator: strided conv stack, Sigmoid scores flattened per sample."""

    def __init__(self, input_channel, output_channel, hidden_dims, num_gpu):
        super(DiscriminatorCNN, self).__init__()
        self.num_gpu = num_gpu  # kept for interface compatibility; unused here
        self.layers = []

        # First stage: conv + LeakyReLU, no batch norm.
        prev_dim = hidden_dims[0]
        self.layers.append(nn.Conv2d(input_channel, prev_dim, 4, 2, 1, bias=False))
        self.layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Deeper stages: conv -> BN -> LeakyReLU, each halving H/W.
        for out_dim in hidden_dims[1:]:
            self.layers.extend([
                nn.Conv2d(prev_dim, out_dim, 4, 2, 1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            prev_dim = out_dim

        # Head: 4x4 valid conv to the score channels, squashed to (0, 1).
        self.layers.append(nn.Conv2d(prev_dim, output_channel, 4, 1, 0, bias=False))
        self.layers.append(nn.Sigmoid())

        self.layer_module = nn.ModuleList(self.layers)

    def main(self, x):
        """Apply every layer, then flatten all score positions per sample."""
        h = x
        for layer in self.layer_module:
            h = layer(h)
        return h.view(h.size(0), -1)

    def forward(self, x):
        return self.main(x)
72+
73+
class GeneratorFC(nn.Module):
    """MLP generator: Linear+ReLU hidden stack with an unbounded linear output."""

    def __init__(self, input_size, output_size, hidden_dims):
        super(GeneratorFC, self).__init__()
        self.layers = []

        width = input_size
        for next_width in hidden_dims:
            self.layers.extend([nn.Linear(width, next_width), nn.ReLU(True)])
            width = next_width
        # No output activation: the caller decides how to interpret raw values.
        self.layers.append(nn.Linear(width, output_size))

        self.layer_module = nn.ModuleList(self.layers)

    def forward(self, x):
        h = x
        for layer in self.layer_module:
            h = layer(h)
        return h
92+
93+
class DiscriminatorFC(nn.Module):
    """MLP discriminator: Linear+ReLU stack, linear head, Sigmoid, reshaped to (-1, 1)."""

    def __init__(self, input_size, output_size, hidden_dims):
        super(DiscriminatorFC, self).__init__()
        self.layers = []

        prev_dim = input_size
        # Idiom fix: the original used enumerate() but never read the index.
        for hidden_dim in hidden_dims:
            self.layers.append(nn.Linear(prev_dim, hidden_dim))
            self.layers.append(nn.ReLU(True))
            prev_dim = hidden_dim

        # Sigmoid head: each score lands in (0, 1).
        self.layers.append(nn.Linear(prev_dim, output_size))
        self.layers.append(nn.Sigmoid())

        self.layer_module = nn.ModuleList(self.layers)

    def forward(self, x):
        out = x
        for layer in self.layer_module:
            out = layer(out)
        # One probability per row, regardless of output_size.
        return out.view(-1, 1)

test

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
aaa

0 commit comments

Comments
 (0)