From 93aafee86458878ed99c0fbb977b3dc3e8e32a94 Mon Sep 17 00:00:00 2001 From: Hyeonwoo Kang Date: Fri, 15 Sep 2017 17:58:38 +0900 Subject: [PATCH] Add files via upload --- DRAGAN.py | 60 +++++++++++++++++++++++++++++++++++++++++++------------ GAN.py | 43 +++++++++++++++++++++++++++++++-------- LSGAN.py | 41 ++++++++++++++++++++++++++++++------- 3 files changed, 116 insertions(+), 28 deletions(-) diff --git a/DRAGAN.py b/DRAGAN.py index 161d941..6aea79a 100644 --- a/DRAGAN.py +++ b/DRAGAN.py @@ -3,18 +3,24 @@ import torch.nn as nn import torch.optim as optim from torch.autograd import Variable, grad -from torch.nn.init import xavier_normal +from torch.utils.data import DataLoader +from torchvision import datasets, transforms class generator(nn.Module): # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(generator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 62 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 self.fc = nn.Sequential( nn.Linear(self.input_dim, 1024), @@ -45,11 +51,16 @@ class discriminator(nn.Module): # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S def __init__(self, dataset = 'mnist'): super(discriminator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 1 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 1 self.conv = nn.Sequential( nn.Conv2d(self.input_dim, 64, 4, 2, 1), @@ -106,8 +117,21 @@ def __init__(self, args): utils.print_network(self.D) print('-----------------------------------------------') - # load mnist - self.data_X, self.data_Y = utils.load_mnist(args.dataset) + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) self.z_dim = 62 # fixed noise @@ -134,8 +158,10 @@ def train(self): for epoch in range(self.epoch): epoch_start_time = time.time() self.G.train() - for iter in range(len(self.data_X) // self.batch_size): - x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break + z_ = torch.rand((self.batch_size, self.z_dim)) if self.gpu_mode: @@ -155,15 +181,23 @@ def train(self): """ DRAGAN Loss (Gradient penalty) """ # This is borrowed from https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py - alpha = torch.rand(x_.size()).cuda() - x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size()).cuda()), + if self.gpu_mode: + alpha = torch.rand(x_.size()).cuda() + x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size()).cuda()), requires_grad=True) + else: + alpha = torch.rand(x_.size()) + x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size())), + requires_grad=True) pred_hat = self.D(x_hat) - gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(), + if self.gpu_mode: + gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(), create_graph=True, retain_graph=True, only_inputs=True)[0] + else: + gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()), + create_graph=True, retain_graph=True, only_inputs=True)[0] - # gradients_penalty = self.lambda_ * (gradients.norm(2) - 1) ** 2 # DRAGAN does not work well when using norm(2). - gradient_penalty = self.lambda_ * ((torch.sqrt(torch.sum(torch.sum(torch.sum(gradients**2, 1), 1), 1)) - 1) ** 2).mean() + gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean() D_loss = D_real_loss + D_fake_loss + gradient_penalty self.train_hist['D_loss'].append(D_loss.data[0]) @@ -184,7 +218,7 @@ def train(self): if ((iter + 1) % 100) == 0: print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % - ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0])) self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) self.visualize_results((epoch+1)) diff --git a/GAN.py b/GAN.py index 2fe40e1..998ddd0 100644 --- a/GAN.py +++ b/GAN.py @@ -3,17 +3,24 @@ import torch.nn as nn import torch.optim as optim from torch.autograd import Variable +from torch.utils.data import DataLoader +from torchvision import datasets, transforms class generator(nn.Module): # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(generator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 62 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 self.fc = nn.Sequential( nn.Linear(self.input_dim, 1024), @@ -44,11 +51,16 @@ class discriminator(nn.Module): # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S def __init__(self, dataset = 'mnist'): super(discriminator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 1 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 1 self.conv = nn.Sequential( nn.Conv2d(self.input_dim, 64, 4, 2, 1), @@ -77,7 +89,7 @@ class GAN(object): def __init__(self, args): # parameters self.epoch = args.epoch - self.sample_num = 64 + self.sample_num = 16 self.batch_size = args.batch_size self.save_dir = args.save_dir self.result_dir = args.result_dir @@ -104,8 +116,21 @@ def __init__(self, args): utils.print_network(self.D) print('-----------------------------------------------') - # load mnist - self.data_X, self.data_Y = utils.load_mnist(args.dataset) + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) self.z_dim = 62 # fixed noise @@ -132,8 +157,10 @@ def train(self): for epoch in range(self.epoch): self.G.train() epoch_start_time = time.time() - for iter in range(len(self.data_X) // self.batch_size): - x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break + z_ = torch.rand((self.batch_size, self.z_dim)) if self.gpu_mode: @@ -170,7 +197,7 @@ def train(self): if ((iter + 1) % 100) == 0: print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % - ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0])) self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) self.visualize_results((epoch+1)) diff --git a/LSGAN.py b/LSGAN.py index 4e420b3..30a0663 100644 --- a/LSGAN.py +++ b/LSGAN.py @@ -3,17 +3,24 @@ import torch.nn as nn import torch.optim as optim from torch.autograd import Variable +from torch.utils.data import DataLoader +from torchvision import datasets, transforms class generator(nn.Module): # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S def __init__(self, dataset = 'mnist'): super(generator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 62 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 62 + self.output_dim = 3 self.fc = nn.Sequential( nn.Linear(self.input_dim, 1024), @@ -44,11 +51,16 @@ class discriminator(nn.Module): # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S def __init__(self, dataset = 'mnist'): super(discriminator, self).__init__() - if dataset == 'mnist' or 'fashion-mnist': + if dataset == 'mnist' or dataset == 'fashion-mnist': self.input_height = 28 self.input_width = 28 self.input_dim = 1 self.output_dim = 1 + elif dataset == 'celebA': + self.input_height = 64 + self.input_width = 64 + self.input_dim = 3 + self.output_dim = 1 self.conv = nn.Sequential( nn.Conv2d(self.input_dim, 64, 4, 2, 1), @@ -104,8 +116,21 @@ def __init__(self, args): utils.print_network(self.D) print('-----------------------------------------------') - # load mnist - self.data_X, self.data_Y = utils.load_mnist(args.dataset) + # load dataset + if self.dataset == 'mnist': + self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True, + transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'fashion-mnist': + self.data_loader = DataLoader( + datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose( + [transforms.ToTensor()])), + batch_size=self.batch_size, shuffle=True) + elif self.dataset == 'celebA': + self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose( + [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size, + shuffle=True) self.z_dim = 62 # fixed noise @@ -132,8 +157,10 @@ def train(self): for epoch in range(self.epoch): self.G.train() epoch_start_time = time.time() - for iter in range(len(self.data_X) // self.batch_size): - x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + for iter, (x_, _) in enumerate(self.data_loader): + if iter == self.data_loader.dataset.__len__() // self.batch_size: + break + z_ = torch.rand((self.batch_size, self.z_dim)) if self.gpu_mode: @@ -170,7 +197,7 @@ def train(self): if ((iter + 1) % 100) == 0: print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % - ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) + ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0])) self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) self.visualize_results((epoch+1))