Skip to content

Commit

Permalink
update WGAN, WGAN_GP, BEGAN and EBGAN
Browse files · Browse the repository at this point in the history
  • Loading branch information
znxlwm authored Sep 14, 2017
1 parent 4cc7528 commit 12ea10f
Show file tree
Hide file tree
Showing 4 changed files with 415 additions and 60 deletions.
43 changes: 35 additions & 8 deletions BEGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class generator(nn.Module):
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(generator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 62
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 62
self.output_dim = 3

self.fc = nn.Sequential(
nn.Linear(self.input_dim, 1024),
Expand Down Expand Up @@ -44,11 +51,16 @@ class discriminator(nn.Module):
# Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(discriminator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 1
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 3
self.output_dim = 3

self.conv = nn.Sequential(
nn.Conv2d(self.input_dim, 64, 4, 2, 1),
Expand All @@ -64,7 +76,7 @@ def __init__(self, dataset = 'mnist'):
)
self.deconv = nn.Sequential(
nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
nn.Sigmoid(),
#nn.Sigmoid(),
)
utils.initialize_weights(self)

Expand Down Expand Up @@ -113,8 +125,21 @@ def __init__(self, args):
utils.print_network(self.D)
print('-----------------------------------------------')

# load mnist
self.data_X, self.data_Y = utils.load_mnist(args.dataset)
# load dataset
if self.dataset == 'mnist':
self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'fashion-mnist':
self.data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'celebA':
self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose(
[transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size,
shuffle=True)
self.z_dim = 62

# fixed noise
Expand All @@ -141,8 +166,10 @@ def train(self):
for epoch in range(self.epoch):
self.G.train()
epoch_start_time = time.time()
for iter in range(len(self.data_X) // self.batch_size):
x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
for iter, (x_, _) in enumerate(self.data_loader):
if iter == self.data_loader.dataset.__len__() // self.batch_size:
break

z_ = torch.rand((self.batch_size, self.z_dim))

if self.gpu_mode:
Expand Down Expand Up @@ -192,7 +219,7 @@ def train(self):

if ((iter + 1) % 100) == 0:
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" %
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0], self.M, self.k))
((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0], self.M, self.k))

self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
self.visualize_results((epoch+1))
Expand Down
49 changes: 37 additions & 12 deletions EBGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,24 @@
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class generator(nn.Module):
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(generator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 62
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 62
self.output_dim = 3

self.fc = nn.Sequential(
nn.Linear(self.input_dim, 1024),
Expand Down Expand Up @@ -44,11 +51,16 @@ class discriminator(nn.Module):
# Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(discriminator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 1
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 3
self.output_dim = 3

self.conv = nn.Sequential(
nn.Conv2d(self.input_dim, 64, 4, 2, 1),
Expand All @@ -64,7 +76,7 @@ def __init__(self, dataset = 'mnist'):
)
self.deconv = nn.Sequential(
nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
# nn.Sigmoid(), # EBGAN does not work well when using Sigmoid().
#nn.Sigmoid(), # EBGAN does not work well when using Sigmoid().
)
utils.initialize_weights(self)

Expand Down Expand Up @@ -114,8 +126,21 @@ def __init__(self, args):
utils.print_network(self.D)
print('-----------------------------------------------')

# load mnist
self.data_X, self.data_Y = utils.load_mnist(args.dataset)
# load dataset
if self.dataset == 'mnist':
self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'fashion-mnist':
self.data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'celebA':
self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose(
[transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size,
shuffle=True)
self.z_dim = 62

# fixed noise
Expand All @@ -142,8 +167,10 @@ def train(self):
for epoch in range(self.epoch):
self.G.train()
epoch_start_time = time.time()
for iter in range(len(self.data_X) // self.batch_size):
x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
for iter, (x_, _) in enumerate(self.data_loader):
if iter == self.data_loader.dataset.__len__() // self.batch_size:
break

z_ = torch.rand((self.batch_size, self.z_dim))

if self.gpu_mode:
Expand All @@ -159,8 +186,7 @@ def train(self):

G_ = self.G(z_)
D_fake, D_fake_code = self.D(G_)
G_ = Variable(G_.data, requires_grad=False)
D_fake_err = self.MSE_loss(D_fake, G_)
D_fake_err = self.MSE_loss(D_fake, G_.detach())
if list(self.margin-D_fake_err.data)[0] > 0:
D_loss = D_real_err + (self.margin - D_fake_err)
else:
Expand All @@ -175,8 +201,7 @@ def train(self):

G_ = self.G(z_)
D_fake, D_fake_code = self.D(G_)
G_ = Variable(G_.data, requires_grad=False)
D_fake_err = self.MSE_loss(D_fake, G_)
D_fake_err = self.MSE_loss(D_fake, G_.detach())
G_loss = D_fake_err + self.pt_loss_weight * self.pullaway_loss(D_fake_code)
self.train_hist['G_loss'].append(G_loss.data[0])

Expand All @@ -185,7 +210,7 @@ def train(self):

if ((iter + 1) % 100) == 0:
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))
((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
self.visualize_results((epoch+1))
Expand Down
104 changes: 64 additions & 40 deletions WGAN.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
import utils, torch, time, os, pickle, random
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

class generator(nn.Module):
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
def __init__(self, dataset = 'mnist'):
super(generator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 62
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 62
self.output_dim = 3

self.fc = nn.Sequential(
nn.Linear(self.input_dim, 1024),
Expand Down Expand Up @@ -44,11 +51,16 @@ class discriminator(nn.Module):
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
def __init__(self, dataset = 'mnist'):
super(discriminator, self).__init__()
if dataset == 'mnist' or 'fashion-mnist':
if dataset == 'mnist' or dataset == 'fashion-mnist':
self.input_height = 28
self.input_width = 28
self.input_dim = 1
self.output_dim = 1
elif dataset == 'celebA':
self.input_height = 64
self.input_width = 64
self.input_dim = 3
self.output_dim = 1

self.conv = nn.Sequential(
nn.Conv2d(self.input_dim, 64, 4, 2, 1),
Expand Down Expand Up @@ -103,8 +115,21 @@ def __init__(self, args):
utils.print_network(self.D)
print('-----------------------------------------------')

# load mnist
self.data_X, self.data_Y = utils.load_mnist(args.dataset)
# load dataset
if self.dataset == 'mnist':
self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'fashion-mnist':
self.data_loader = DataLoader(
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
[transforms.ToTensor()])),
batch_size=self.batch_size, shuffle=True)
elif self.dataset == 'celebA':
self.data_loader = utils.load_celebA('data/celebA', transform=transforms.Compose(
[transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=self.batch_size,
shuffle=True)
self.z_dim = 62

# fixed noise
Expand All @@ -131,54 +156,53 @@ def train(self):
for epoch in range(self.epoch):
self.G.train()
epoch_start_time = time.time()
for iter in range(len(self.data_X) // self.batch_size):
# update D network
D_losses = []
for _ in range(self.n_critic):
inds = random.sample(range(0, len(self.data_X)), self.batch_size)
x_ = self.data_X[inds, :, :, :]
z_ = torch.rand((self.batch_size, self.z_dim))
for iter, (x_, _) in enumerate(self.data_loader):
if iter == self.data_loader.dataset.__len__() // self.batch_size:
break

if self.gpu_mode:
x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
else:
x_, z_ = Variable(x_), Variable(z_)
z_ = torch.rand((self.batch_size, self.z_dim))

self.D_optimizer.zero_grad()
if self.gpu_mode:
x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
else:
x_, z_ = Variable(x_), Variable(z_)

D_real = self.D(x_)
D_real_loss = -torch.mean(D_real)
# update D network
self.D_optimizer.zero_grad()

G_ = self.G(z_)
D_fake = self.D(G_)
D_fake_loss = torch.mean(D_fake)
D_real = self.D(x_)
D_real_loss = -torch.mean(D_real)

D_loss = D_real_loss + D_fake_loss
D_losses.append(D_loss.data[0])
G_ = self.G(z_)
D_fake = self.D(G_)
D_fake_loss = torch.mean(D_fake)

D_loss.backward()
self.D_optimizer.step()
D_loss = D_real_loss + D_fake_loss

# clipping D
for p in self.D.parameters():
p.data.clamp_(-self.c, self.c)
D_loss.backward()
self.D_optimizer.step()

self.train_hist['D_loss'].append(np.mean(D_losses))
# clipping D
for p in self.D.parameters():
p.data.clamp_(-self.c, self.c)

# update G network
self.G_optimizer.zero_grad()
if ((iter+1) % self.n_critic) == 0:
# update G network
self.G_optimizer.zero_grad()

G_ = self.G(z_)
D_fake = self.D(G_)
G_loss = -torch.mean(D_fake)
self.train_hist['G_loss'].append(G_loss.data[0])
G_ = self.G(z_)
D_fake = self.D(G_)
G_loss = -torch.mean(D_fake)
self.train_hist['G_loss'].append(G_loss.data[0])

G_loss.backward()
self.G_optimizer.step()

G_loss.backward()
self.G_optimizer.step()
self.train_hist['D_loss'].append(D_loss.data[0])

if ((iter + 1) % 100) == 0:
print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))
((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.data[0], G_loss.data[0]))

self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
self.visualize_results((epoch+1))
Expand Down Expand Up @@ -238,4 +262,4 @@ def load(self):
save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
Loading

0 comments on commit 12ea10f

Please sign in to comment.