-
Notifications
You must be signed in to change notification settings - Fork 541
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
11 changed files
with
2,577 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,273 @@ | ||
import utils, torch, time, os, pickle | ||
import numpy as np | ||
import torch.nn as nn | ||
import torch.optim as optim | ||
from torch.autograd import Variable | ||
|
||
class generator(nn.Module): | ||
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) | ||
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S | ||
def __init__(self, dataset = 'mnist'): | ||
super(generator, self).__init__() | ||
if dataset == 'mnist' or 'fashion-mnist': | ||
self.input_height = 28 | ||
self.input_width = 28 | ||
self.input_dim = 62 + 10 | ||
self.output_dim = 1 | ||
|
||
self.fc = nn.Sequential( | ||
nn.Linear(self.input_dim, 1024), | ||
nn.BatchNorm1d(1024), | ||
nn.ReLU(), | ||
nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)), | ||
nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)), | ||
nn.ReLU(), | ||
) | ||
self.deconv = nn.Sequential( | ||
nn.ConvTranspose2d(128, 64, 4, 2, 1), | ||
nn.BatchNorm2d(64), | ||
nn.ReLU(), | ||
nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), | ||
nn.Sigmoid(), | ||
) | ||
utils.initialize_weights(self) | ||
|
||
def forward(self, input, label): | ||
x = torch.cat([input, label], 1) | ||
x = self.fc(x) | ||
x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4)) | ||
x = self.deconv(x) | ||
|
||
return x | ||
|
||
class discriminator(nn.Module): | ||
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) | ||
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S | ||
def __init__(self, dataset='mnist'): | ||
super(discriminator, self).__init__() | ||
if dataset == 'mnist' or 'fashion-mnist': | ||
self.input_height = 28 | ||
self.input_width = 28 | ||
self.input_dim = 1 | ||
self.output_dim = 1 | ||
self.class_num = 10 | ||
|
||
self.conv = nn.Sequential( | ||
nn.Conv2d(self.input_dim, 64, 4, 2, 1), | ||
nn.LeakyReLU(0.2), | ||
nn.Conv2d(64, 128, 4, 2, 1), | ||
nn.BatchNorm2d(128), | ||
nn.LeakyReLU(0.2), | ||
) | ||
self.fc1 = nn.Sequential( | ||
nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024), | ||
nn.BatchNorm1d(1024), | ||
nn.LeakyReLU(0.2), | ||
) | ||
self.dc = nn.Sequential( | ||
nn.Linear(1024, self.output_dim), | ||
nn.Sigmoid(), | ||
) | ||
self.cl = nn.Sequential( | ||
nn.Linear(1024, self.class_num), | ||
) | ||
utils.initialize_weights(self) | ||
|
||
def forward(self, input): | ||
x = self.conv(input) | ||
x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4)) | ||
x = self.fc1(x) | ||
d = self.dc(x) | ||
c = self.cl(x) | ||
|
||
return d, c | ||
|
||
class ACGAN(object): | ||
def __init__(self, args): | ||
# parameters | ||
self.epoch = args.epoch | ||
self.sample_num = 100 | ||
self.batch_size = args.batch_size | ||
self.save_dir = args.save_dir | ||
self.result_dir = args.result_dir | ||
self.dataset = args.dataset | ||
self.log_dir = args.log_dir | ||
self.gpu_mode = args.gpu_mode | ||
self.model_name = args.gan_type | ||
|
||
# networks init | ||
self.G = generator(self.dataset) | ||
self.D = discriminator(self.dataset) | ||
self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) | ||
self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) | ||
|
||
if self.gpu_mode: | ||
self.G.cuda() | ||
self.D.cuda() | ||
self.BCE_loss = nn.BCELoss().cuda() | ||
self.CE_loss = nn.CrossEntropyLoss().cuda() | ||
else: | ||
self.BCE_loss = nn.BCELoss() | ||
self.CE_loss = nn.CrossEntropyLoss() | ||
|
||
print('---------- Networks architecture -------------') | ||
utils.print_network(self.G) | ||
utils.print_network(self.D) | ||
print('-----------------------------------------------') | ||
|
||
# load mnist | ||
self.data_X, self.data_Y = utils.load_mnist(args.dataset) | ||
self.z_dim = 62 | ||
self.y_dim = 10 | ||
|
||
# fixed noise & condition | ||
self.sample_z_ = torch.zeros((self.sample_num, self.z_dim)) | ||
for i in range(10): | ||
self.sample_z_[i*self.y_dim] = torch.rand(1, self.z_dim) | ||
for j in range(1, self.y_dim): | ||
self.sample_z_[i*self.y_dim + j] = self.sample_z_[i*self.y_dim] | ||
|
||
temp = torch.zeros((10, 1)) | ||
for i in range(self.y_dim): | ||
temp[i, 0] = i | ||
|
||
temp_y = torch.zeros((self.sample_num, 1)) | ||
for i in range(10): | ||
temp_y[i*self.y_dim: (i+1)*self.y_dim] = temp | ||
|
||
self.sample_y_ = torch.zeros((self.sample_num, self.y_dim)) | ||
self.sample_y_.scatter_(1, temp_y.type(torch.LongTensor), 1) | ||
if self.gpu_mode: | ||
self.sample_z_, self.sample_y_ = Variable(self.sample_z_.cuda(), volatile=True), Variable(self.sample_y_.cuda(), volatile=True) | ||
else: | ||
self.sample_z_, self.sample_y_ = Variable(self.sample_z_, volatile=True), Variable(self.sample_y_, volatile=True) | ||
|
||
def train(self): | ||
self.train_hist = {} | ||
self.train_hist['D_loss'] = [] | ||
self.train_hist['G_loss'] = [] | ||
self.train_hist['per_epoch_time'] = [] | ||
self.train_hist['total_time'] = [] | ||
|
||
if self.gpu_mode: | ||
self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda()) | ||
else: | ||
self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1)) | ||
|
||
self.D.train() | ||
print('training start!!') | ||
start_time = time.time() | ||
for epoch in range(self.epoch): | ||
self.G.train() | ||
epoch_start_time = time.time() | ||
for iter in range(len(self.data_X) // self.batch_size): | ||
x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] | ||
z_ = torch.rand((self.batch_size, self.z_dim)) | ||
y_vec_ = self.data_Y[iter*self.batch_size:(iter+1)*self.batch_size] | ||
|
||
if self.gpu_mode: | ||
x_, z_, y_vec_ = Variable(x_.cuda()), Variable(z_.cuda()), Variable(y_vec_.cuda()) | ||
else: | ||
x_, z_, y_vec_ = Variable(x_), Variable(z_), Variable(y_vec_) | ||
|
||
# update D network | ||
self.D_optimizer.zero_grad() | ||
|
||
D_real, C_real = self.D(x_) | ||
D_real_loss = self.BCE_loss(D_real, self.y_real_) | ||
C_real_loss = self.CE_loss(C_real, torch.max(y_vec_, 1)[1]) | ||
|
||
G_ = self.G(z_, y_vec_) | ||
D_fake, C_fake = self.D(G_) | ||
D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) | ||
C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1]) | ||
|
||
D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss | ||
self.train_hist['D_loss'].append(D_loss.data[0]) | ||
|
||
D_loss.backward() | ||
self.D_optimizer.step() | ||
|
||
# update G network | ||
self.G_optimizer.zero_grad() | ||
|
||
G_ = self.G(z_, y_vec_) | ||
D_fake, C_fake = self.D(G_) | ||
|
||
G_loss = self.BCE_loss(D_fake, self.y_real_) | ||
C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1]) | ||
|
||
G_loss += C_fake_loss | ||
self.train_hist['G_loss'].append(G_loss.data[0]) | ||
|
||
G_loss.backward() | ||
self.G_optimizer.step() | ||
|
||
if ((iter + 1) % 100) == 0: | ||
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % | ||
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) | ||
|
||
self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) | ||
self.visualize_results((epoch+1)) | ||
|
||
self.train_hist['total_time'].append(time.time() - start_time) | ||
print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), | ||
self.epoch, self.train_hist['total_time'][0])) | ||
print("Training finish!... save training results") | ||
|
||
self.save() | ||
utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, | ||
self.epoch) | ||
utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) | ||
|
||
def visualize_results(self, epoch, fix=True): | ||
self.G.eval() | ||
|
||
if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): | ||
os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) | ||
|
||
image_frame_dim = int(np.floor(np.sqrt(self.sample_num))) | ||
|
||
if fix: | ||
""" fixed noise """ | ||
samples = self.G(self.sample_z_, self.sample_y_) | ||
else: | ||
""" random noise """ | ||
temp = torch.LongTensor(self.batch_size, 1).random_() % 10 | ||
sample_y_ = torch.FloatTensor(self.batch_size, 10) | ||
sample_y_.zero_() | ||
sample_y_.scatter_(1, temp, 1) | ||
if self.gpu_mode: | ||
sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True), \ | ||
Variable(sample_y_.cuda(), volatile=True) | ||
else: | ||
sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True), \ | ||
Variable(sample_y_, volatile=True) | ||
|
||
samples = self.G(sample_z_, sample_y_) | ||
|
||
if self.gpu_mode: | ||
samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) | ||
else: | ||
samples = samples.data.numpy().transpose(0, 2, 3, 1) | ||
|
||
utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], | ||
self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') | ||
|
||
def save(self): | ||
save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) | ||
|
||
if not os.path.exists(save_dir): | ||
os.makedirs(save_dir) | ||
|
||
torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) | ||
torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) | ||
|
||
with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: | ||
pickle.dump(self.train_hist, f) | ||
|
||
def load(self): | ||
save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) | ||
|
||
self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) | ||
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) |
Oops, something went wrong.