From 12ea10fd5f8258ecfa9d752f7e5f01b0446eef96 Mon Sep 17 00:00:00 2001 From: Hyeonwoo Kang Date: Thu, 14 Sep 2017 15:26:20 +0900 Subject: [PATCH] update WGAN, WGAN_GP, BEGAN and EBGAN --- BEGAN.py | 43 +++++++-- EBGAN.py | 49 +++++++--- WGAN.py | 104 ++++++++++++-------- WGAN_GP.py | 279 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 415 insertions(+), 60 deletions(-) create mode 100644 WGAN_GP.py diff --git a/BEGAN.py b/BEGAN.py index 500e2d1..35f5669 100644 --- a/BEGAN.py +++ b/BEGAN.py @@ -3,17 +3,24 @@ import torch.nn as nn import torch.optim as optim from torch.autograd import Variable +from torch.utils.data import DataLoader +from torchvision import datasets, transforms class generator(nn.Module): # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(generator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 62 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 self.fc = nn.Sequential( nn.Linear(self.input_dim, 1024), @@ -44,11 +51,16 @@ class discriminator(nn.Module): # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(discriminator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 1 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 3 self.conv = nn.Sequential( nn.Conv2d(self.input_dim, 64, 4, 2, 1), @@ -64,7 +76,7 @@ def __init__(self, dataset = 'mnist'): ) self.deconv = nn.Sequential( nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), - nn.Sigmoid(), + #nn.Sigmoid(), ) utils.initialize_weights(self) @@ -113,8 +125,21 @@ def __init__(self, args): utils.print_network(self.D) print('-----------------------------------------------') - # load mnist - self.data_X, self.data_Y = utils.load_mnist(args.dataset) + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) self.z_dim = 62 # fixed noise @@ -141,8 +166,10 @@ def train(self): for epoch in range(self.epoch): self.G.train() epoch_start_time = time.time() - for iter in range(len(self.data_X) // self.batch_size): - x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break + z_ = torch.rand((self.batch_size, self.z_dim)) if self.gpu_mode: @@ -192,7 +219,7 @@ def train(self): if ((iter + 1) % 100) == 0: print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" % - ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0], self.M, self.k)) + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0], self.M, self.k)) self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) self.visualize_results((epoch+1)) diff --git a/EBGAN.py b/EBGAN.py index 26e146b..be44825 100644 --- a/EBGAN.py +++ b/EBGAN.py @@ -3,17 +3,24 @@ import torch.nn as nn import torch.optim as optim from torch.autograd import Variable +from torch.utils.data import DataLoader +from torchvision import datasets, transforms class generator(nn.Module): # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(generator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 62 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 self.fc = nn.Sequential( nn.Linear(self.input_dim, 1024), @@ -44,11 +51,16 @@ class discriminator(nn.Module): # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(discriminator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 1 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 3 self.conv = nn.Sequential( nn.Conv2d(self.input_dim, 64, 4, 2, 1), @@ -64,7 +76,7 @@ def __init__(self, dataset = 'mnist'): ) self.deconv = nn.Sequential( nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), - # nn.Sigmoid(), # EBGAN does not work well when using Sigmoid(). + #nn.Sigmoid(), # EBGAN does not work well when using Sigmoid(). ) utils.initialize_weights(self) @@ -114,8 +126,21 @@ def __init__(self, args): utils.print_network(self.D) print('-----------------------------------------------') - # load mnist - self.data_X, self.data_Y = utils.load_mnist(args.dataset) + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) self.z_dim = 62 # fixed noise @@ -142,8 +167,10 @@ def train(self): for epoch in range(self.epoch): self.G.train() epoch_start_time = time.time() - for iter in range(len(self.data_X) // self.batch_size): - x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break + z_ = torch.rand((self.batch_size, self.z_dim)) if self.gpu_mode: @@ -159,8 +186,7 @@ def train(self): G_ = self.G(z_) D_fake, D_fake_code = self.D(G_) - G_ = Variable(G_.data, requires_grad=False) - D_fake_err = self.MSE_loss(D_fake, G_) + D_fake_err = self.MSE_loss(D_fake, G_.detach()) if list(self.margin-D_fake_err.data)[0] > 0: D_loss = D_real_err + (self.margin - D_fake_err) else: @@ -175,8 +201,7 @@ def train(self): G_ = self.G(z_) D_fake, D_fake_code = self.D(G_) - G_ = Variable(G_.data, requires_grad=False) - D_fake_err = self.MSE_loss(D_fake, G_) + D_fake_err = self.MSE_loss(D_fake, G_.detach()) G_loss = D_fake_err + self.pt_loss_weight * self.pullaway_loss(D_fake_code) self.train_hist['G_loss'].append(G_loss.data[0]) @@ -185,7 +210,7 @@ def train(self): if ((iter + 1) % 100) == 0: print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % - ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0])) self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) self.visualize_results((epoch+1)) diff --git a/WGAN.py b/WGAN.py index ed07eba..89263a5 100644 --- a/WGAN.py +++ b/WGAN.py @@ -1,19 +1,26 @@ -import utils, torch, time, os, pickle, random +import utils, torch, time, os, pickle import numpy as np import torch.nn as nn import torch.optim as optim from torch.autograd import Variable +from torch.utils.data import DataLoader +from torchvision import datasets, transforms class generator(nn.Module): # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(generator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 62 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 self.fc = nn.Sequential( nn.Linear(self.input_dim, 1024), @@ -44,11 +51,16 @@ class discriminator(nn.Module): # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S def __init__(self, dataset = 'mnist'): super(discriminator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 1 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 1 self.conv = nn.Sequential( nn.Conv2d(self.input_dim, 64, 4, 2, 1), @@ -103,8 +115,21 @@ def __init__(self, args): utils.print_network(self.D) print('-----------------------------------------------') - # load mnist - self.data_X, self.data_Y = utils.load_mnist(args.dataset) + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) self.z_dim = 62 # fixed noise @@ -131,54 +156,53 @@ def train(self): for epoch in range(self.epoch): self.G.train() epoch_start_time = time.time() - for iter in range(len(self.data_X) // self.batch_size): - # update D network - D_losses = [] - for _ in range(self.n_critic): - inds = random.sample(range(0, len(self.data_X)), self.batch_size) - x_ = self.data_X[inds, :, :, :] - z_ = torch.rand((self.batch_size, self.z_dim)) + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break - if self.gpu_mode: - x_, z_ = Variable(x_.cuda()), Variable(z_.cuda()) - else: - x_, z_ = Variable(x_), Variable(z_) + z_ = torch.rand((self.batch_size, self.z_dim)) - self.D_optimizer.zero_grad() + if self.gpu_mode: + x_, z_ = Variable(x_.cuda()), Variable(z_.cuda()) + else: + x_, z_ = Variable(x_), Variable(z_) - D_real = self.D(x_) - D_real_loss = -torch.mean(D_real) + # update D network + self.D_optimizer.zero_grad() - G_ = self.G(z_) - D_fake = self.D(G_) - D_fake_loss = torch.mean(D_fake) + D_real = self.D(x_) + D_real_loss = -torch.mean(D_real) - D_loss = D_real_loss + D_fake_loss - D_losses.append(D_loss.data[0]) + G_ = self.G(z_) + D_fake = self.D(G_) + D_fake_loss = torch.mean(D_fake) - D_loss.backward() - self.D_optimizer.step() + D_loss = D_real_loss + D_fake_loss - # clipping D - for p in self.D.parameters(): - p.data.clamp_(-self.c, self.c) + D_loss.backward() + self.D_optimizer.step() - self.train_hist['D_loss'].append(np.mean(D_losses)) + # clipping D + for p in self.D.parameters(): + p.data.clamp_(-self.c, self.c) - # update G network - self.G_optimizer.zero_grad() + if ((iter+1) % self.n_critic) == 0: + # update G network + self.G_optimizer.zero_grad() - G_ = self.G(z_) - D_fake = self.D(G_) - G_loss = -torch.mean(D_fake) - self.train_hist['G_loss'].append(G_loss.data[0]) + G_ = self.G(z_) + D_fake = self.D(G_) + G_loss = -torch.mean(D_fake) + self.train_hist['G_loss'].append(G_loss.data[0]) + + G_loss.backward() + self.G_optimizer.step() - G_loss.backward() - self.G_optimizer.step() + self.train_hist['D_loss'].append(D_loss.data[0]) if ((iter + 1) % 100) == 0: print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % - ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0])) self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) self.visualize_results((epoch+1)) @@ -238,4 +262,4 @@ def load(self): save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) - self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) + self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) \ No newline at end of file diff --git a/WGAN_GP.py b/WGAN_GP.py new file mode 100644 index 0000000..f1b6ed7 --- /dev/null +++ b/WGAN_GP.py @@ -0,0 +1,279 @@ +import utils, torch, time, os, pickle +import numpy as np +import torch.nn as nn +import torch.optim as optim +from torch.autograd import Variable, grad +from torch.utils.data import DataLoader +from torchvision import datasets, transforms + +class generator(nn.Module): + # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) + # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S + def __init__(self, dataset = 'mnist'): + super(generator, self).__init__() + if dataset == 'mnist' or dataset == 'fashion-mnist': + self.input_height = 28 + self.input_width = 28 + self.input_dim = 62 + self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 + + self.fc = nn.Sequential( + nn.Linear(self.input_dim, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(), + nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)), + nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)), + nn.ReLU(), + ) + self.deconv = nn.Sequential( + nn.ConvTranspose2d(128, 64, 4, 2, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), + nn.Sigmoid(), + ) + utils.initialize_weights(self) + + def forward(self, input): + x = self.fc(input) + x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4)) + x = self.deconv(x) + + return x + +class discriminator(nn.Module): + # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) + # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S + def __init__(self, dataset = 'mnist'): + super(discriminator, self).__init__() + if dataset == 'mnist' or dataset == 'fashion-mnist': + self.input_height = 28 + self.input_width = 28 + self.input_dim = 1 + self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 1 + + self.conv = nn.Sequential( + nn.Conv2d(self.input_dim, 64, 4, 2, 1), + nn.LeakyReLU(0.2), + nn.Conv2d(64, 128, 4, 2, 1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + ) + self.fc = nn.Sequential( + nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024), + nn.BatchNorm1d(1024), + nn.LeakyReLU(0.2), + nn.Linear(1024, self.output_dim), + nn.Sigmoid(), + ) + utils.initialize_weights(self) + + def forward(self, input): + x = self.conv(input) + x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4)) + x = self.fc(x) + + return x + +class WGAN_GP(object): + def __init__(self, args): + # parameters + self.epoch = args.epoch + self.sample_num = 64 + self.batch_size = args.batch_size + self.save_dir = args.save_dir + self.result_dir = args.result_dir + self.dataset = args.dataset + self.log_dir = args.log_dir + self.gpu_mode = args.gpu_mode + self.model_name = args.gan_type + self.lambda_ = 0.25 + self.n_critic = 5 # the number of iterations of the critic per generator iteration + + # networks init + self.G = generator(self.dataset) + self.D = discriminator(self.dataset) + self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) + self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) + + if self.gpu_mode: + self.G.cuda() + self.D.cuda() + + print('---------- Networks architecture -------------') + utils.print_network(self.G) + utils.print_network(self.D) + print('-----------------------------------------------') + + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) + self.z_dim = 62 + + # fixed noise + if self.gpu_mode: + self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True) + else: + self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True) + + def train(self): + self.train_hist = {} + self.train_hist['D_loss'] = [] + self.train_hist['G_loss'] = [] + self.train_hist['per_epoch_time'] = [] + self.train_hist['total_time'] = [] + + if self.gpu_mode: + self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda()) + else: + self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1)) + + self.D.train() + print('training start!!') + start_time = time.time() + for epoch in range(self.epoch): + self.G.train() + epoch_start_time = time.time() + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break + + z_ = torch.rand((self.batch_size, self.z_dim)) + + if self.gpu_mode: + x_, z_ = Variable(x_.cuda()), Variable(z_.cuda()) + else: + x_, z_ = Variable(x_), Variable(z_) + + # update D network + self.D_optimizer.zero_grad() + + D_real = self.D(x_) + D_real_loss = -torch.mean(D_real) + + G_ = self.G(z_) + D_fake = self.D(G_) + D_fake_loss = torch.mean(D_fake) + + # gradient penalty + if self.gpu_mode: + alpha = torch.rand(x_.size()).cuda() + else: + alpha = torch.rand(x_.size()) + + x_hat = Variable(alpha * x_.data + (1 - alpha) * G_.data, requires_grad=True) + + pred_hat = self.D(x_hat) + if self.gpu_mode: + gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(), + create_graph=True, retain_graph=True, only_inputs=True)[0] + else: + gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() + + D_loss = D_real_loss + D_fake_loss + gradient_penalty + + D_loss.backward() + self.D_optimizer.step() + + if ((iter+1) % self.n_critic) == 0: + # update G network + self.G_optimizer.zero_grad() + + G_ = self.G(z_) + D_fake = self.D(G_) + G_loss = -torch.mean(D_fake) + self.train_hist['G_loss'].append(G_loss.data[0]) + + G_loss.backward() + self.G_optimizer.step() + + self.train_hist['D_loss'].append(D_loss.data[0]) + + if ((iter + 1) % 100) == 0: + print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0])) + + self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) + self.visualize_results((epoch+1)) + + self.train_hist['total_time'].append(time.time() - start_time) + print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), + self.epoch, self.train_hist['total_time'][0])) + print("Training finish!... save training results") + + self.save() + utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, + self.epoch) + utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) + + def visualize_results(self, epoch, fix=True): + self.G.eval() + + if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): + os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) + + tot_num_samples = min(self.sample_num, self.batch_size) + image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) + + if fix: + """ fixed noise """ + samples = self.G(self.sample_z_) + else: + """ random noise """ + if self.gpu_mode: + sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True) + else: + sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True) + + samples = self.G(sample_z_) + + if self.gpu_mode: + samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) + else: + samples = samples.data.numpy().transpose(0, 2, 3, 1) + + utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], + self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') + + def save(self): + save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) + torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) + + with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: + pickle.dump(self.train_hist, f) + + def load(self): + save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) + + self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) + self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) \ No newline at end of file