Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
znxlwm authored Sep 15, 2017
1 parent d1fdd74 commit 93aafee
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 28 deletions.
60 changes: 47 additions & 13 deletions DRAGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,24 @@
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable, grad
from torch.nn.init import xavier_normal
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class generator(nn.Module):
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(generator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 62
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 62
self.output_dim = 3

self.fc = nn.Sequential(
nn.Linear(self.input_dim, 1024),
Expand Down Expand Up @@ -45,11 +51,16 @@ class discriminator(nn.Module):
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
def __init__(self, dataset = 'mnist'):
super(discriminator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 1
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 3
self.output_dim = 1

self.conv = nn.Sequential(
nn.Conv2d(self.input_dim, 64, 4, 2, 1),
Expand Down Expand Up @@ -106,8 +117,21 @@ def __init__(self, args):
utils.print_network(self.D)
print('-----------------------------------------------')

# load mnist
self.data_X, self.data_Y = utils.load_mnist(args.dataset)
# load dataset
if self.dataset == 'mnist':
self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'fashion-mnist':
self.data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'celebA':
self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose(
[transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size,
shuffle=True)
self.z_dim = 62

# fixed noise
Expand All @@ -134,8 +158,10 @@ def train(self):
for epoch in range(self.epoch):
epoch_start_time = time.time()
self.G.train()
for iter in range(len(self.data_X) // self.batch_size):
x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
for iter, (x_, _) in enumerate(self.data_loader):
if iter == self.data_loader.dataset.__len__() // self.batch_size:
break

z_ = torch.rand((self.batch_size, self.z_dim))

if self.gpu_mode:
Expand All @@ -155,15 +181,23 @@ def train(self):

""" DRAGAN Loss (Gradient penalty) """
# This is borrowed from https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py
alpha = torch.rand(x_.size()).cuda()
x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size()).cuda()),
if self.gpu_mode:
alpha = torch.rand(x_.size()).cuda()
x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size()).cuda()),
requires_grad=True)
else:
alpha = torch.rand(x_.size())
x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size())),
requires_grad=True)
pred_hat = self.D(x_hat)
gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(),
if self.gpu_mode:
gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(),
create_graph=True, retain_graph=True, only_inputs=True)[0]
else:
gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()),
create_graph=True, retain_graph=True, only_inputs=True)[0]

# gradients_penalty = self.lambda_ * (gradients.norm(2) - 1) ** 2 # DRAGAN does not work well when using norm(2).
gradient_penalty = self.lambda_ * ((torch.sqrt(torch.sum(torch.sum(torch.sum(gradients**2, 1), 1), 1)) - 1) ** 2).mean()
gradient_penalty = self.lambda_ * ((gradients.view(gradients.size()[0], -1).norm(2, 1) - 1) ** 2).mean()

D_loss = D_real_loss + D_fake_loss + gradient_penalty
self.train_hist['D_loss'].append(D_loss.data[0])
Expand All @@ -184,7 +218,7 @@ def train(self):

if ((iter + 1) % 100) == 0:
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))
((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
self.visualize_results((epoch+1))
Expand Down
43 changes: 35 additions & 8 deletions GAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class generator(nn.Module):
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(generator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 62
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 62
self.output_dim = 3

self.fc = nn.Sequential(
nn.Linear(self.input_dim, 1024),
Expand Down Expand Up @@ -44,11 +51,16 @@ class discriminator(nn.Module):
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
def __init__(self, dataset = 'mnist'):
super(discriminator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 1
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 3
self.output_dim = 1

self.conv = nn.Sequential(
nn.Conv2d(self.input_dim, 64, 4, 2, 1),
Expand Down Expand Up @@ -77,7 +89,7 @@ class GAN(object):
def __init__(self, args):
# parameters
self.epoch = args.epoch
self.sample_num = 64
self.sample_num = 16
self.batch_size = args.batch_size
self.save_dir = args.save_dir
self.result_dir = args.result_dir
Expand All @@ -104,8 +116,21 @@ def __init__(self, args):
utils.print_network(self.D)
print('-----------------------------------------------')

# load mnist
self.data_X, self.data_Y = utils.load_mnist(args.dataset)
# load dataset
if self.dataset == 'mnist':
self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'fashion-mnist':
self.data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'celebA':
self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose(
[transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size,
shuffle=True)
self.z_dim = 62

# fixed noise
Expand All @@ -132,8 +157,10 @@ def train(self):
for epoch in range(self.epoch):
self.G.train()
epoch_start_time = time.time()
for iter in range(len(self.data_X) // self.batch_size):
x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
for iter, (x_, _) in enumerate(self.data_loader):
if iter == self.data_loader.dataset.__len__() // self.batch_size:
break

z_ = torch.rand((self.batch_size, self.z_dim))

if self.gpu_mode:
Expand Down Expand Up @@ -170,7 +197,7 @@ def train(self):

if ((iter + 1) % 100) == 0:
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))
((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
self.visualize_results((epoch+1))
Expand Down
41 changes: 34 additions & 7 deletions LSGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class generator(nn.Module):
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(generator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 62
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 62
self.output_dim = 3

self.fc = nn.Sequential(
nn.Linear(self.input_dim, 1024),
Expand Down Expand Up @@ -44,11 +51,16 @@ class discriminator(nn.Module):
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
def __init__(self, dataset = 'mnist'):
super(discriminator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 1
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 3
self.output_dim = 1

self.conv = nn.Sequential(
nn.Conv2d(self.input_dim, 64, 4, 2, 1),
Expand Down Expand Up @@ -104,8 +116,21 @@ def __init__(self, args):
utils.print_network(self.D)
print('-----------------------------------------------')

# load mnist
self.data_X, self.data_Y = utils.load_mnist(args.dataset)
# load dataset
if self.dataset == 'mnist':
self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'fashion-mnist':
self.data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'celebA':
self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose(
[transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size,
shuffle=True)
self.z_dim = 62

# fixed noise
Expand All @@ -132,8 +157,10 @@ def train(self):
for epoch in range(self.epoch):
self.G.train()
epoch_start_time = time.time()
for iter in range(len(self.data_X) // self.batch_size):
x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
for iter, (x_, _) in enumerate(self.data_loader):
if iter == self.data_loader.dataset.__len__() // self.batch_size:
break

z_ = torch.rand((self.batch_size, self.z_dim))

if self.gpu_mode:
Expand Down Expand Up @@ -170,7 +197,7 @@ def train(self):

if ((iter + 1) % 100) == 0:
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))
((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
self.visualize_results((epoch+1))
Expand Down

0 comments on commit 93aafee

Please sign in to comment.