Update to Pytorch 0.4
znxlwm authored Jun 12, 2018
1 parent 5207b2f commit 4caaace
Showing 12 changed files with 563 additions and 755 deletions.
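Across these files, the commit applies the standard PyTorch 0.4 migration: `Variable` wrappers and the `volatile` flag are dropped in favor of plain tensors and `torch.no_grad()`, scalar losses are read with `.item()` instead of `.data[0]`, and dataset loading moves into a `dataloader` module. A minimal sketch of the 0.4 idioms involved (illustrative only, not code from this repository):

```python
import torch
import torch.nn as nn

# pre-0.4 style removed by this commit:
#   x = Variable(x.cuda()); value = loss.data[0]
#   z = Variable(z, volatile=True)      # gradient-free inference
# 0.4 style used after this commit:
model = nn.Linear(10, 1)
x = torch.rand(4, 10)                   # tensors carry autograd state directly
loss = ((model(x) - 1) ** 2).mean()
loss.backward()
print(loss.item())                      # Python float, replaces loss.data[0]

with torch.no_grad():                   # replaces volatile=True for evaluation
    preds = model(torch.rand(4, 10))
```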
119 changes: 55 additions & 64 deletions ACGAN.py
@@ -2,55 +2,52 @@
 import numpy as np
 import torch.nn as nn
 import torch.optim as optim
-from torch.autograd import Variable
+from dataloader import dataloader
 
 class generator(nn.Module):
     # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
     # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
-    def __init__(self, dataset = 'mnist'):
+    def __init__(self, input_dim=100, output_dim=1, input_size=32, class_num=10):
         super(generator, self).__init__()
-        if dataset == 'mnist' or 'fashion-mnist':
-            self.input_height = 28
-            self.input_width = 28
-            self.input_dim = 62 + 10
-            self.output_dim = 1
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.input_size = input_size
+        self.class_num = class_num
 
         self.fc = nn.Sequential(
-            nn.Linear(self.input_dim, 1024),
+            nn.Linear(self.input_dim + self.class_num, 1024),
             nn.BatchNorm1d(1024),
             nn.ReLU(),
-            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
-            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
+            nn.Linear(1024, 128 * (self.input_size // 4) * (self.input_size // 4)),
+            nn.BatchNorm1d(128 * (self.input_size // 4) * (self.input_size // 4)),
             nn.ReLU(),
         )
         self.deconv = nn.Sequential(
             nn.ConvTranspose2d(128, 64, 4, 2, 1),
             nn.BatchNorm2d(64),
             nn.ReLU(),
             nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
-            nn.Sigmoid(),
+            nn.Tanh(),
         )
         utils.initialize_weights(self)
 
     def forward(self, input, label):
         x = torch.cat([input, label], 1)
         x = self.fc(x)
-        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
+        x = x.view(-1, 128, (self.input_size // 4), (self.input_size // 4))
         x = self.deconv(x)
 
         return x
 
 class discriminator(nn.Module):
     # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
     # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
-    def __init__(self, dataset='mnist'):
+    def __init__(self, input_dim=1, output_dim=1, input_size=32, class_num=10):
         super(discriminator, self).__init__()
-        if dataset == 'mnist' or 'fashion-mnist':
-            self.input_height = 28
-            self.input_width = 28
-            self.input_dim = 1
-            self.output_dim = 1
-            self.class_num = 10
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.input_size = input_size
+        self.class_num = class_num
 
         self.conv = nn.Sequential(
             nn.Conv2d(self.input_dim, 64, 4, 2, 1),
@@ -60,7 +57,7 @@ def __init__(self, dataset='mnist'):
             nn.LeakyReLU(0.2),
         )
         self.fc1 = nn.Sequential(
-            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
+            nn.Linear(128 * (self.input_size // 4) * (self.input_size // 4), 1024),
             nn.BatchNorm1d(1024),
             nn.LeakyReLU(0.2),
         )
@@ -75,7 +72,7 @@ def __init__(self, dataset='mnist'):
 
     def forward(self, input):
         x = self.conv(input)
-        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
+        x = x.view(-1, 128 * (self.input_size // 4) * (self.input_size // 4))
         x = self.fc1(x)
         d = self.dc(x)
         c = self.cl(x)
@@ -94,10 +91,18 @@ def __init__(self, args):
         self.log_dir = args.log_dir
         self.gpu_mode = args.gpu_mode
         self.model_name = args.gan_type
+        self.input_size = args.input_size
+        self.z_dim = 62
+        self.class_num = 10
+        self.sample_num = self.class_num ** 2
 
+        # load dataset
+        self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size)
+        data = self.data_loader.__iter__().__next__()[0]
+
         # networks init
-        self.G = generator(self.dataset)
-        self.D = discriminator(self.dataset)
+        self.G = generator(input_dim=self.z_dim, output_dim=data.shape[1], input_size=self.input_size)
+        self.D = discriminator(input_dim=data.shape[1], output_dim=1, input_size=self.input_size)
         self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
         self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))
 
@@ -115,32 +120,24 @@ def __init__(self, args):
         utils.print_network(self.D)
         print('-----------------------------------------------')
 
-        # load mnist
-        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
-        self.z_dim = 62
-        self.y_dim = 10
-
         # fixed noise & condition
         self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
-        for i in range(10):
-            self.sample_z_[i*self.y_dim] = torch.rand(1, self.z_dim)
-            for j in range(1, self.y_dim):
-                self.sample_z_[i*self.y_dim + j] = self.sample_z_[i*self.y_dim]
+        for i in range(self.class_num):
+            self.sample_z_[i*self.class_num] = torch.rand(1, self.z_dim)
+            for j in range(1, self.class_num):
+                self.sample_z_[i*self.class_num + j] = self.sample_z_[i*self.class_num]
 
-        temp = torch.zeros((10, 1))
-        for i in range(self.y_dim):
+        temp = torch.zeros((self.class_num, 1))
+        for i in range(self.class_num):
             temp[i, 0] = i
 
         temp_y = torch.zeros((self.sample_num, 1))
-        for i in range(10):
-            temp_y[i*self.y_dim: (i+1)*self.y_dim] = temp
+        for i in range(self.class_num):
+            temp_y[i*self.class_num: (i+1)*self.class_num] = temp
 
-        self.sample_y_ = torch.zeros((self.sample_num, self.y_dim))
-        self.sample_y_.scatter_(1, temp_y.type(torch.LongTensor), 1)
+        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
         if self.gpu_mode:
-            self.sample_z_, self.sample_y_ = Variable(self.sample_z_.cuda(), volatile=True), Variable(self.sample_y_.cuda(), volatile=True)
-        else:
-            self.sample_z_, self.sample_y_ = Variable(self.sample_z_, volatile=True), Variable(self.sample_y_, volatile=True)
+            self.sample_z_, self.sample_y_ = self.sample_z_.cuda(), self.sample_y_.cuda()
 
     def train(self):
         self.train_hist = {}
@@ -149,26 +146,24 @@ def train(self):
         self.train_hist['per_epoch_time'] = []
         self.train_hist['total_time'] = []
 
+        self.y_real_, self.y_fake_ = torch.ones(self.batch_size, 1), torch.zeros(self.batch_size, 1)
         if self.gpu_mode:
-            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
-        else:
-            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))
+            self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()
 
         self.D.train()
         print('training start!!')
         start_time = time.time()
         for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
-            for iter in range(len(self.data_X) // self.batch_size):
-                x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
+            for iter, (x_, y_) in enumerate(self.data_loader):
+                if iter == self.data_loader.dataset.__len__() // self.batch_size:
+                    break
                 z_ = torch.rand((self.batch_size, self.z_dim))
-                y_vec_ = self.data_Y[iter*self.batch_size:(iter+1)*self.batch_size]
-
+                y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
+                y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
                 if self.gpu_mode:
-                    x_, z_, y_vec_ = Variable(x_.cuda()), Variable(z_.cuda()), Variable(y_vec_.cuda())
-                else:
-                    x_, z_, y_vec_ = Variable(x_), Variable(z_), Variable(y_vec_)
+                    x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()
 
                 # update D network
                 self.D_optimizer.zero_grad()
@@ -183,7 +178,7 @@ def train(self):
                 C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
 
                 D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
-                self.train_hist['D_loss'].append(D_loss.data[0])
+                self.train_hist['D_loss'].append(D_loss.item())
 
                 D_loss.backward()
                 self.D_optimizer.step()
@@ -198,17 +193,18 @@ def train(self):
                 C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])
 
                 G_loss += C_fake_loss
-                self.train_hist['G_loss'].append(G_loss.data[0])
+                self.train_hist['G_loss'].append(G_loss.item())
 
                 G_loss.backward()
                 self.G_optimizer.step()
 
                 if ((iter + 1) % 100) == 0:
                     print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
-                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))
+                          ((epoch + 1), (iter + 1), self.data_loader.dataset.__len__() // self.batch_size, D_loss.item(), G_loss.item()))
 
             self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
-            self.visualize_results((epoch+1))
+            with torch.no_grad():
+                self.visualize_results((epoch+1))
 
         self.train_hist['total_time'].append(time.time() - start_time)
         print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
@@ -233,16 +229,10 @@ def visualize_results(self, epoch, fix=True):
             samples = self.G(self.sample_z_, self.sample_y_)
         else:
             """ random noise """
-            temp = torch.LongTensor(self.batch_size, 1).random_() % 10
-            sample_y_ = torch.FloatTensor(self.batch_size, 10)
-            sample_y_.zero_()
-            sample_y_.scatter_(1, temp, 1)
+            sample_y_ = torch.zeros(self.batch_size, self.class_num).scatter_(1, torch.randint(0, self.class_num - 1, (self.batch_size, 1)).type(torch.LongTensor), 1)
+            sample_z_ = torch.rand((self.batch_size, self.z_dim))
             if self.gpu_mode:
-                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True), \
-                                       Variable(sample_y_.cuda(), volatile=True)
-            else:
-                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True), \
-                                       Variable(sample_y_, volatile=True)
+                sample_z_, sample_y_ = sample_z_.cuda(), sample_y_.cuda()
 
         samples = self.G(sample_z_, sample_y_)
 
@@ -251,6 +241,7 @@ def visualize_results(self, epoch, fix=True):
         else:
             samples = samples.data.numpy().transpose(0, 2, 3, 1)
 
+        samples = (samples + 1) / 2
         utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                           self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')
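The updated training loop and the fixed sample grid both build one-hot class vectors with `scatter_` instead of the old `FloatTensor(...).zero_()` sequence, and the loop additionally expands those vectors into per-pixel label maps (`y_fill_`). A small sketch of that pattern (illustrative shapes and names, not code from the repo):

```python
import torch

batch_size, class_num, input_size = 4, 10, 32
labels = torch.randint(0, class_num, (batch_size,))   # e.g. tensor([3, 7, 0, 9])

# one-hot encode: write a 1 into each row at the label's column
y_vec = torch.zeros(batch_size, class_num).scatter_(1, labels.unsqueeze(1), 1)

# broadcast the same labels into (N, class_num, H, W) feature maps,
# as done for y_fill_ inside the training loop
y_fill = y_vec.unsqueeze(2).unsqueeze(3).expand(batch_size, class_num, input_size, input_size)

print(y_vec.shape, y_fill.shape)   # torch.Size([4, 10]) torch.Size([4, 10, 32, 32])
```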

