import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


class generator(nn.Module):
    """ACGAN generator: maps (62-dim noise, 10-dim one-hot label) -> 1x28x28 image.

    Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    """
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition was `dataset == 'mnist' or 'fashion-mnist'`,
        # which is always true because the non-empty string 'fashion-mnist' is truthy.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62 + 10  # 62-dim noise + 10-dim one-hot class label
            self.output_dim = 1
        else:
            raise ValueError("unsupported dataset: " + str(dataset))

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input, label):
        """Return generated images for the given noise batch and one-hot labels."""
        # Condition the generator by concatenating noise and label along dim 1.
        x = torch.cat([input, label], 1)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x


class discriminator(nn.Module):
    """ACGAN discriminator with two heads: real/fake score and class logits.

    Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    """
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: same always-true `or 'fashion-mnist'` condition as in generator.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1
            self.class_num = 10
        else:
            raise ValueError("unsupported dataset: " + str(dataset))

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc1 = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
        )
        # dc: real/fake probability head; cl: auxiliary classifier head (raw logits,
        # consumed by CrossEntropyLoss which applies log-softmax itself).
        self.dc = nn.Sequential(
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        self.cl = nn.Sequential(
            nn.Linear(1024, self.class_num),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return (real/fake probability, class logits) for a batch of images."""
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc1(x)
        d = self.dc(x)
        c = self.cl(x)

        return d, c


class ACGAN(object):
    """Auxiliary-Classifier GAN trainer (Odena et al., https://arxiv.org/abs/1610.09585)."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100  # 10 rows (shared z per row) x 10 class columns
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
            self.CE_loss = nn.CrossEntropyLoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()
            self.CE_loss = nn.CrossEntropyLoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62
        self.y_dim = 10

        # fixed noise & condition: each block of y_dim samples shares one z so the
        # visualization grid varies the class along one axis and z along the other.
        self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
        for i in range(10):
            self.sample_z_[i * self.y_dim] = torch.rand(1, self.z_dim)
            for j in range(1, self.y_dim):
                self.sample_z_[i * self.y_dim + j] = self.sample_z_[i * self.y_dim]

        temp = torch.zeros((10, 1))
        for i in range(self.y_dim):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(10):
            temp_y[i * self.y_dim: (i + 1) * self.y_dim] = temp

        # One-hot encode the class indices via scatter_.
        self.sample_y_ = torch.zeros((self.sample_num, self.y_dim))
        self.sample_y_.scatter_(1, temp_y.type(torch.LongTensor), 1)
        # NOTE: volatile=True is the legacy (PyTorch <= 0.3) way of disabling autograd
        # for inference-only tensors.
        if self.gpu_mode:
            self.sample_z_, self.sample_y_ = Variable(self.sample_z_.cuda(), volatile=True), Variable(self.sample_y_.cuda(), volatile=True)
        else:
            self.sample_z_, self.sample_y_ = Variable(self.sample_z_, volatile=True), Variable(self.sample_y_, volatile=True)

    def train(self):
        """Run the full training loop, logging losses and sampling images per epoch."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        # Fixed real/fake target labels for the BCE adversarial loss.
        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            # G is switched to eval() in visualize_results, so re-enable train mode.
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter * self.batch_size:(iter + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                y_vec_ = self.data_Y[iter * self.batch_size:(iter + 1) * self.batch_size]

                if self.gpu_mode:
                    x_, z_, y_vec_ = Variable(x_.cuda()), Variable(z_.cuda()), Variable(y_vec_.cuda())
                else:
                    x_, z_, y_vec_ = Variable(x_), Variable(z_), Variable(y_vec_)

                # update D network: adversarial BCE + auxiliary classification CE
                # on both real and generated samples.
                self.D_optimizer.zero_grad()

                D_real, C_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)
                # torch.max(y_vec_, 1)[1] converts one-hot vectors back to class indices.
                C_real_loss = self.CE_loss(C_real, torch.max(y_vec_, 1)[1])

                G_ = self.G(z_, y_vec_)
                D_fake, C_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)
                C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])

                D_loss = D_real_loss + C_real_loss + D_fake_loss + C_fake_loss
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network: fool D (BCE against real labels) + class CE.
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_vec_)
                D_fake, C_fake = self.D(G_)

                G_loss = self.BCE_loss(D_fake, self.y_real_)
                C_fake_loss = self.CE_loss(C_fake, torch.max(y_vec_, 1)[1])

                G_loss += C_fake_loss
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        """Save a grid of generated samples for this epoch (fixed or random noise)."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_, self.sample_y_)
        else:
            """ random noise """
            temp = torch.LongTensor(self.batch_size, 1).random_() % 10
            sample_y_ = torch.FloatTensor(self.batch_size, 10)
            sample_y_.zero_()
            sample_y_.scatter_(1, temp, 1)
            if self.gpu_mode:
                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True), \
                                       Variable(sample_y_.cuda(), volatile=True)
            else:
                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True), \
                                       Variable(sample_y_, volatile=True)

            samples = self.G(sample_z_, sample_y_)

        # NCHW -> NHWC for image saving.
        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights from save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


class generator(nn.Module):
    """BEGAN generator: maps 62-dim noise -> 1x28x28 image.

    Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    """
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition was `dataset == 'mnist' or 'fashion-mnist'`,
        # which is always true (non-empty string literal is truthy).
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62
            self.output_dim = 1
        else:
            raise ValueError("unsupported dataset: " + str(dataset))

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return generated images for a batch of noise vectors."""
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x


class discriminator(nn.Module):
    """BEGAN discriminator: an auto-encoder whose reconstruction error is the energy.

    It must be Auto-Encoder style architecture
    Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
    """
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: same always-true `or 'fashion-mnist'` condition as in generator.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1
        else:
            raise ValueError("unsupported dataset: " + str(dataset))

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.ReLU(),
        )
        # Bottleneck to a 32-dim code, then expand back for the decoder.
        self.fc = nn.Sequential(
            nn.Linear(64 * (self.input_height // 2) * (self.input_width // 2), 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, 64 * (self.input_height // 2) * (self.input_width // 2)),
            nn.BatchNorm1d(64 * (self.input_height // 2) * (self.input_width // 2)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return the auto-encoder reconstruction of the input batch."""
        x = self.conv(input)
        x = x.view(x.size()[0], -1)
        x = self.fc(x)
        x = x.view(-1, 64, (self.input_height // 2), (self.input_width // 2))
        x = self.deconv(x)

        return x


class BEGAN(object):
    """Boundary Equilibrium GAN trainer (https://arxiv.org/abs/1703.10717)."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 64
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # BEGAN parameters: gamma is the diversity ratio, lambda_ the learning
        # rate for the equilibrium variable k (updated each iteration in train()).
        self.gamma = 0.75
        self.lambda_ = 0.001
        self.k = 0.

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            # self.L1_loss = torch.nn.L1loss().cuda() # BEGAN does not work well when using L1loss().
        # else:
        #     self.L1_loss = torch.nn.L1loss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62

        # fixed noise (volatile=True: legacy inference-only flag, PyTorch <= 0.3)
        if self.gpu_mode:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
        else:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

    def train(self):
        """Run the BEGAN training loop with the k-balancing update and M metric."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            # G is switched to eval() in visualize_results, so re-enable train mode.
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter * self.batch_size:(iter + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network: energy = L1 reconstruction error of the
                # auto-encoder discriminator; D minimizes real energy while
                # maximizing (k-weighted) fake energy.
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_err = torch.mean(torch.abs(D_real - x_))

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_err = torch.mean(torch.abs(D_fake - G_))

                D_loss = D_real_err - self.k * D_fake_err
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network: G minimizes the fake reconstruction energy.
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_err = torch.mean(torch.abs(D_fake - G_))

                G_loss = D_fake_err
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                # convergence metric M = real energy + |gamma*real - fake|
                temp_M = D_real_err + torch.abs(self.gamma * D_real_err - D_fake_err)

                # operation for updating k, clamped to [0, 1]
                temp_k = self.k + self.lambda_ * (self.gamma * D_real_err - D_fake_err)
                temp_k = temp_k.data[0]

                self.k = min(max(temp_k, 0), 1)
                self.M = temp_M.data[0]

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, M: %.8f, k: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0], self.M, self.k))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        """Save a grid of generated samples for this epoch (fixed or random noise)."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
            else:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

            samples = self.G(sample_z_)

        # NCHW -> NHWC for image saving.
        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights from save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable


class generator(nn.Module):
    """CGAN generator: maps (62-dim noise, 10-dim one-hot label) -> 1x28x28 image.

    Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    """
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition was `dataset == 'mnist' or 'fashion-mnist'`,
        # which is always true (non-empty string literal is truthy).
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62 + 10  # 62-dim noise + 10-dim one-hot class label
            self.output_dim = 1
        else:
            raise ValueError("unsupported dataset: " + str(dataset))

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input, label):
        """Return generated images conditioned on one-hot labels."""
        x = torch.cat([input, label], 1)
        x = self.fc(x)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x


class discriminator(nn.Module):
    """CGAN discriminator: scores (image, label-fill-maps) pairs as real/fake.

    Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    """
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: same always-true `or 'fashion-mnist'` condition as in generator.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1 + 10  # 1 image channel + 10 label channels
            self.output_dim = 1
        else:
            raise ValueError("unsupported dataset: " + str(dataset))

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input, label):
        """Return real/fake probability; label is a per-class channel map."""
        # Condition D by stacking the label fill-maps as extra image channels.
        x = torch.cat([input, label], 1)
        x = self.conv(x)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x


class CGAN(object):
    """Conditional GAN trainer (Mirza & Osindero, https://arxiv.org/abs/1411.1784)."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100  # 10 rows (shared z per row) x 10 class columns
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62
        self.y_dim = 10

        # fixed noise & condition: each block of y_dim samples shares one z so the
        # visualization grid varies the class along one axis and z along the other.
        self.sample_z_ = torch.zeros((self.sample_num, self.z_dim))
        for i in range(10):
            self.sample_z_[i * self.y_dim] = torch.rand(1, self.z_dim)
            for j in range(1, self.y_dim):
                self.sample_z_[i * self.y_dim + j] = self.sample_z_[i * self.y_dim]

        temp = torch.zeros((10, 1))
        for i in range(self.y_dim):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(10):
            temp_y[i * self.y_dim: (i + 1) * self.y_dim] = temp

        # One-hot encode the class indices via scatter_.
        self.sample_y_ = torch.zeros((self.sample_num, self.y_dim))
        self.sample_y_.scatter_(1, temp_y.type(torch.LongTensor), 1)
        # NOTE: volatile=True is the legacy (PyTorch <= 0.3) inference-only flag.
        if self.gpu_mode:
            self.sample_z_, self.sample_y_ = Variable(self.sample_z_.cuda(), volatile=True), Variable(self.sample_y_.cuda(), volatile=True)
        else:
            self.sample_z_, self.sample_y_ = Variable(self.sample_z_, volatile=True), Variable(self.sample_y_, volatile=True)

    def train(self):
        """Run the full conditional-GAN training loop."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        # fill[c] is an all-ones map in channel c: the spatial label encoding fed to D.
        self.fill = torch.zeros([10, 10, self.data_X.size()[2], self.data_X.size()[3]])
        for i in range(10):
            self.fill[i, i, :, :] = 1

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            # G is switched to eval() in visualize_results, so re-enable train mode.
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter * self.batch_size:(iter + 1) * self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))
                y_vec_ = self.data_Y[iter * self.batch_size:(iter + 1) * self.batch_size]
                # Expand one-hot labels to per-class spatial fill maps for D.
                y_fill_ = self.fill[torch.max(y_vec_, 1)[1].squeeze()]

                if self.gpu_mode:
                    x_, z_, y_vec_, y_fill_ = Variable(x_.cuda()), Variable(z_.cuda()), \
                                              Variable(y_vec_.cuda()), Variable(y_fill_.cuda())
                else:
                    x_, z_, y_vec_, y_fill_ = Variable(x_), Variable(z_), Variable(y_vec_), Variable(y_fill_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_, y_fill_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_, y_vec_)
                D_fake = self.D(G_, y_fill_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch + 1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        """Save a grid of generated samples for this epoch (fixed or random noise)."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        image_frame_dim = int(np.floor(np.sqrt(self.sample_num)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_, self.sample_y_)
        else:
            """ random noise """
            temp = torch.LongTensor(self.batch_size, 1).random_() % 10
            sample_y_ = torch.FloatTensor(self.batch_size, 10)
            sample_y_.zero_()
            sample_y_.scatter_(1, temp, 1)
            if self.gpu_mode:
                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True), \
                                       Variable(sample_y_.cuda(), volatile=True)
            else:
                sample_z_, sample_y_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True), \
                                       Variable(sample_y_, volatile=True)

            samples = self.G(sample_z_, sample_y_)

        # NCHW -> NHWC for image saving.
        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights from save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
self.fc = nn.Sequential( + nn.Linear(self.input_dim, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(), + nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)), + nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)), + nn.ReLU(), + ) + self.deconv = nn.Sequential( + nn.ConvTranspose2d(128, 64, 4, 2, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), + nn.Sigmoid(), + ) + utils.initialize_weights(self) + + def forward(self, input): + x = self.fc(input) + x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4)) + x = self.deconv(x) + + return x + +class discriminator(nn.Module): + # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) + # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S + def __init__(self, dataset = 'mnist'): + super(discriminator, self).__init__() + if dataset == 'mnist' or 'fashion-mnist': + self.input_height = 28 + self.input_width = 28 + self.input_dim = 1 + self.output_dim = 1 + + self.conv = nn.Sequential( + nn.Conv2d(self.input_dim, 64, 4, 2, 1), + nn.LeakyReLU(0.2), + nn.Conv2d(64, 128, 4, 2, 1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + ) + self.fc = nn.Sequential( + nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024), + nn.BatchNorm1d(1024), + nn.LeakyReLU(0.2), + nn.Linear(1024, self.output_dim), + nn.Sigmoid(), + ) + utils.initialize_weights(self) + + def forward(self, input): + x = self.conv(input) + x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4)) + x = self.fc(x) + + return x + +class DRAGAN(object): + def __init__(self, args): + # parameters + self.epoch = args.epoch + self.sample_num = 64 + self.batch_size = args.batch_size + self.save_dir = args.save_dir + self.result_dir = args.result_dir + self.dataset = args.dataset + self.log_dir = args.log_dir + self.gpu_mode = args.gpu_mode + self.model_name = args.gan_type + self.lambda_ = 0.25 + + # networks init + 
self.G = generator(self.dataset) + self.D = discriminator(self.dataset) + self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) + self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) + + if self.gpu_mode: + self.G.cuda() + self.D.cuda() + self.BCE_loss = nn.BCELoss().cuda() + else: + self.BCE_loss = nn.BCELoss() + + print('---------- Networks architecture -------------') + utils.print_network(self.G) + utils.print_network(self.D) + print('-----------------------------------------------') + + # load mnist + self.data_X, self.data_Y = utils.load_mnist(args.dataset) + self.z_dim = 62 + + # fixed noise + if self.gpu_mode: + self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True) + else: + self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True) + + def train(self): + self.train_hist = {} + self.train_hist['D_loss'] = [] + self.train_hist['G_loss'] = [] + self.train_hist['per_epoch_time'] = [] + self.train_hist['total_time'] = [] + + if self.gpu_mode: + self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda()) + else: + self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1)) + + self.D.train() + print('training start!!') + start_time = time.time() + for epoch in range(self.epoch): + epoch_start_time = time.time() + self.G.train() + for iter in range(len(self.data_X) // self.batch_size): + x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + z_ = torch.rand((self.batch_size, self.z_dim)) + + if self.gpu_mode: + x_, z_ = Variable(x_.cuda()), Variable(z_.cuda()) + else: + x_, z_ = Variable(x_), Variable(z_) + + # update D network + self.D_optimizer.zero_grad() + + D_real = self.D(x_) + D_real_loss = self.BCE_loss(D_real, self.y_real_) + + G_ = self.G(z_) + D_fake = self.D(G_) + D_fake_loss = 
self.BCE_loss(D_fake, self.y_fake_) + + """ DRAGAN Loss (Gradient penalty) """ + # This is borrowed from https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py + alpha = torch.rand(x_.size()).cuda() + x_hat = Variable(alpha * x_.data + (1 - alpha) * (x_.data + 0.5 * x_.data.std() * torch.rand(x_.size()).cuda()), + requires_grad=True) + pred_hat = self.D(x_hat) + gradients = grad(outputs=pred_hat, inputs=x_hat, grad_outputs=torch.ones(pred_hat.size()).cuda(), + create_graph=True, retain_graph=True, only_inputs=True)[0] + + # gradients_penalty = self.lambda_ * (gradients.norm(2) - 1) ** 2 # DRAGAN does not work well when using norm(2). + gradient_penalty = self.lambda_ * ((torch.sqrt(torch.sum(torch.sum(torch.sum(gradients**2, 1), 1), 1)) - 1) ** 2).mean() + + D_loss = D_real_loss + D_fake_loss + gradient_penalty + self.train_hist['D_loss'].append(D_loss.data[0]) + D_loss.backward() + self.D_optimizer.step() + + # update G network + self.G_optimizer.zero_grad() + + G_ = self.G(z_) + D_fake = self.D(G_) + + G_loss = self.BCE_loss(D_fake, self.y_real_) + self.train_hist['G_loss'].append(G_loss.data[0]) + + G_loss.backward() + self.G_optimizer.step() + + if ((iter + 1) % 100) == 0: + print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" % + ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0])) + + self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) + self.visualize_results((epoch+1)) + + self.train_hist['total_time'].append(time.time() - start_time) + print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), + self.epoch, self.train_hist['total_time'][0])) + print("Training finish!... 
save training results") + + self.save() + utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, self.epoch) + utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) + + def visualize_results(self, epoch, fix=True): + self.G.eval() + + if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): + os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) + + tot_num_samples = min(self.sample_num, self.batch_size) + image_frame_dim = int(np.floor(np.sqrt(tot_num_samples))) + + if fix: + """ fixed noise """ + samples = self.G(self.sample_z_) + else: + """ random noise """ + if self.gpu_mode: + sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True) + else: + sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True) + + samples = self.G(sample_z_) + + if self.gpu_mode: + samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) + else: + samples = samples.data.numpy().transpose(0, 2, 3, 1) + + utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], + self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') + + def save(self): + save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) + torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) + + with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: + pickle.dump(self.train_hist, f) + + def load(self): + save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) + + self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) + 
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition `dataset == 'mnist' or 'fashion-mnist'` was
        # always True; membership test expresses the intended check.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62
            self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Map a (batch, 62) noise vector to a (batch, 1, 28, 28) image."""
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x

class discriminator(nn.Module):
    # It must be Auto-Encoder style architecture
    # Architecture : (64)4c2s-FC32-FC64*14*14_BR-(1)4dc2s_S
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: membership test instead of the always-True `== 'mnist' or '...'`.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.ReLU(),
        )
        self.code = nn.Sequential(
            nn.Linear(64 * (self.input_height // 2) * (self.input_width // 2), 32),   # bn and relu are excluded since code is used in pullaway_loss
        )
        self.fc = nn.Sequential(
            nn.Linear(32, 64 * (self.input_height // 2) * (self.input_width // 2)),
            nn.BatchNorm1d(64 * (self.input_height // 2) * (self.input_width // 2)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            # nn.Sigmoid(), # EBGAN does not work well when using Sigmoid().
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return (reconstruction, latent code) of the auto-encoder critic."""
        x = self.conv(input)
        x = x.view(x.size()[0], -1)
        code = self.code(x)
        x = self.fc(code)
        x = x.view(-1, 64, (self.input_height // 2), (self.input_width // 2))
        x = self.deconv(x)

        return x, code

class EBGAN(object):
    """Energy-Based GAN: the critic is an auto-encoder and its reconstruction
    error acts as the energy; trained with a margin loss plus a pull-away term."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 64
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # EBGAN parameters
        self.pt_loss_weight = 0.1
        self.margin = max(1, self.batch_size / 64.)  # margin for loss function
        # usually margin of 1 is enough, but for large batch size it must be larger than 1

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.MSE_loss = nn.MSELoss().cuda()
        else:
            self.MSE_loss = nn.MSELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62

        # fixed noise, reused every epoch so sample grids are comparable over time
        if self.gpu_mode:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
        else:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

    def train(self):
        """Energy-based training loop: D minimises real reconstruction error and
        pushes fake error above the margin; G minimises fake error + pull-away."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real, D_real_code = self.D(x_)
                D_real_err = self.MSE_loss(D_real, x_)

                G_ = self.G(z_)
                D_fake, D_fake_code = self.D(G_)
                # detach G output so the margin term only trains D
                G_ = Variable(G_.data, requires_grad=False)
                D_fake_err = self.MSE_loss(D_fake, G_)
                # hinge: only penalise fakes whose energy is still below the margin
                if list(self.margin-D_fake_err.data)[0] > 0:
                    D_loss = D_real_err + (self.margin - D_fake_err)
                else:
                    D_loss = D_real_err
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake, D_fake_code = self.D(G_)
                G_ = Variable(G_.data, requires_grad=False)
                D_fake_err = self.MSE_loss(D_fake, G_)
                G_loss = D_fake_err + self.pt_loss_weight * self.pullaway_loss(D_fake_code)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def pullaway_loss(self, embeddings):
        """ pullaway_loss tensorflow version code

        norm = tf.sqrt(tf.reduce_sum(tf.square(embeddings), 1, keep_dims=True))
        normalized_embeddings = embeddings / norm
        similarity = tf.matmul(
            normalized_embeddings, normalized_embeddings, transpose_b=True)
        batch_size = tf.cast(tf.shape(embeddings)[0], tf.float32)
        pt_loss = (tf.reduce_sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
        return pt_loss

        """
        norm = torch.sqrt(torch.sum(embeddings ** 2, 1, keepdim=True))
        normalized_embeddings = embeddings / norm
        similarity = torch.matmul(normalized_embeddings, normalized_embeddings.transpose(1, 0))
        batch_size = embeddings.size()[0]
        # mean pairwise cosine similarity over distinct pairs (diagonal removed)
        pt_loss = (torch.sum(similarity) - batch_size) / (batch_size * (batch_size - 1))
        return pt_loss


    def visualize_results(self, epoch, fix=True):
        """Dump a grid of generated samples for this epoch to result_dir."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
            else:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

            samples = self.G(sample_z_)

        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights previously written by save()."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition `dataset == 'mnist' or 'fashion-mnist'` was
        # always True; membership test expresses the intended check.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62
            self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Map a (batch, 62) noise vector to a (batch, 1, 28, 28) image."""
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x

class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: membership test instead of the always-True `== 'mnist' or '...'`.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return the probability (batch, 1) that the input images are real."""
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x

class GAN(object):
    """Vanilla GAN trained with the standard non-saturating BCE objective."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 64
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62

        # fixed noise, reused every epoch so sample grids are comparable over time
        if self.gpu_mode:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
        else:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

    def train(self):
        """Alternate one D step and one G step per minibatch for self.epoch epochs."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.BCE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.BCE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network (non-saturating loss: label fakes as real)
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = self.BCE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        """Dump a grid of generated samples for this epoch to result_dir."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
            else:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

            samples = self.G(sample_z_)

        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights previously written by save()."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition `dataset == 'mnist' or 'fashion-mnist'` was
        # always True; membership test expresses the intended check.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62
            self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Map a (batch, 62) noise vector to a (batch, 1, 28, 28) image."""
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x

class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: membership test instead of the always-True `== 'mnist' or '...'`.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return the realness score (batch, 1) of the input images."""
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x

class LSGAN(object):
    """Least-Squares GAN: identical to GAN but with MSE instead of BCE losses."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 64
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()
            self.MSE_loss = nn.MSELoss().cuda()
        else:
            self.MSE_loss = nn.MSELoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62

        # fixed noise, reused every epoch so sample grids are comparable over time
        if self.gpu_mode:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
        else:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

    def train(self):
        """Alternate one D step and one G step per minibatch using least-squares losses."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size]
                z_ = torch.rand((self.batch_size, self.z_dim))

                if self.gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = self.MSE_loss(D_real, self.y_real_)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = self.MSE_loss(D_fake, self.y_fake_)

                D_loss = D_real_loss + D_fake_loss
                self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = self.MSE_loss(D_fake, self.y_real_)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        """Dump a grid of generated samples for this epoch to result_dir."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
            else:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

            samples = self.G(sample_z_)

        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights previously written by save()."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
import utils, torch, time, os, pickle, random
import numpy as np
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self, dataset='mnist'):
        super(generator, self).__init__()
        # BUGFIX: original condition `dataset == 'mnist' or 'fashion-mnist'` was
        # always True; membership test expresses the intended check.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62
            self.output_dim = 1

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Map a (batch, 62) noise vector to a (batch, 1, 28, 28) image."""
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x

class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self, dataset='mnist'):
        super(discriminator, self).__init__()
        # BUGFIX: membership test instead of the always-True `== 'mnist' or '...'`.
        if dataset in ('mnist', 'fashion-mnist'):
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        utils.initialize_weights(self)

    def forward(self, input):
        """Return the critic score (batch, 1) of the input images."""
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x

class WGAN(object):
    """Wasserstein GAN with weight clipping and n_critic critic steps per G step."""

    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 64
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.c = 0.01                   # clipping value
        self.n_critic = 5               # the number of iterations of the critic per generator iteration

        # networks init
        self.G = generator(self.dataset)
        self.D = discriminator(self.dataset)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            self.D.cuda()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        utils.print_network(self.D)
        print('-----------------------------------------------')

        # load mnist
        self.data_X, self.data_Y = utils.load_mnist(args.dataset)
        self.z_dim = 62

        # fixed noise, reused every epoch so sample grids are comparable over time
        if self.gpu_mode:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
        else:
            self.sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

    def train(self):
        """Run n_critic clipped critic updates per generator update."""
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if self.gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epoch in range(self.epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter in range(len(self.data_X) // self.batch_size):
                # update D network
                D_losses = []
                for _ in range(self.n_critic):
                    # fresh random minibatch for every critic step
                    inds = random.sample(range(0, len(self.data_X)), self.batch_size)
                    x_ = self.data_X[inds, :, :, :]
                    z_ = torch.rand((self.batch_size, self.z_dim))

                    if self.gpu_mode:
                        x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                    else:
                        x_, z_ = Variable(x_), Variable(z_)

                    self.D_optimizer.zero_grad()

                    D_real = self.D(x_)
                    D_real_loss = torch.mean(D_real)

                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    D_fake_loss = torch.mean(D_fake)

                    D_loss = D_real_loss - D_fake_loss
                    D_losses.append(D_loss.data[0])

                    D_loss.backward()
                    self.D_optimizer.step()

                    # clipping D to enforce the Lipschitz constraint
                    for p in self.D.parameters():
                        p.data.clamp_(-self.c, self.c)

                self.train_hist['D_loss'].append(np.mean(D_losses))

                # update G network
                self.G_optimizer.zero_grad()

                G_ = self.G(z_)
                D_fake = self.D(G_)
                G_loss = torch.mean(D_fake)
                self.train_hist['G_loss'].append(G_loss.data[0])

                G_loss.backward()
                self.G_optimizer.step()

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epoch+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
                                 self.epoch)
        utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, fix=True):
        """Dump a grid of generated samples for this epoch to result_dir."""
        self.G.eval()

        if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

        tot_num_samples = min(self.sample_num, self.batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)).cuda(), volatile=True)
            else:
                sample_z_ = Variable(torch.rand((self.batch_size, self.z_dim)), volatile=True)

            samples = self.G(sample_z_)

        if self.gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png')

    def save(self):
        """Persist G/D weights and the training history to save_dir."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore G/D weights previously written by save().

        NOTE(review): the tail of this method lies past the visible chunk; the body
        below is reconstructed from the byte-identical load() of the sibling
        GAN/LSGAN/EBGAN classes — confirm against the full file.
        """
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
= os.path.join(self.save_dir, self.dataset, self.model_name) + + self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) + self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) \ No newline at end of file diff --git a/infoGAN.py b/infoGAN.py new file mode 100644 index 0000000..2ebd30d --- /dev/null +++ b/infoGAN.py @@ -0,0 +1,336 @@ +import utils, torch, time, os, pickle, itertools +import numpy as np +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import matplotlib.pyplot as plt +from torch.autograd import Variable + +class generator(nn.Module): + # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) + # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S + def __init__(self, dataset = 'mnist'): + super(generator, self).__init__() + if dataset == 'mnist' or 'fashion-mnist': + self.input_height = 28 + self.input_width = 28 + self.input_dim = 62 + 12 + self.output_dim = 1 + + self.fc = nn.Sequential( + nn.Linear(self.input_dim, 1024), + nn.BatchNorm1d(1024), + nn.ReLU(), + nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)), + nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)), + nn.ReLU(), + ) + self.deconv = nn.Sequential( + nn.ConvTranspose2d(128, 64, 4, 2, 1), + nn.BatchNorm2d(64), + nn.ReLU(), + nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1), + nn.Sigmoid(), + ) + utils.initialize_weights(self) + + def forward(self, input, cont_code, dist_code): + x = torch.cat([input, cont_code, dist_code], 1) + x = self.fc(x) + x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4)) + x = self.deconv(x) + + return x + +class discriminator(nn.Module): + # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657) + # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S + def __init__(self, dataset='mnist'): + super(discriminator, self).__init__() 
+ if dataset == 'mnist' or 'fashion-mnist': + self.input_height = 28 + self.input_width = 28 + self.input_dim = 1 + self.output_dim = 1 + self.len_discrete_code = 10 # categorical distribution (i.e. label) + self.len_continuous_code = 2 # gaussian distribution (e.g. rotation, thickness) + + self.conv = nn.Sequential( + nn.Conv2d(self.input_dim, 64, 4, 2, 1), + nn.LeakyReLU(0.2), + nn.Conv2d(64, 128, 4, 2, 1), + nn.BatchNorm2d(128), + nn.LeakyReLU(0.2), + ) + self.fc = nn.Sequential( + nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024), + nn.BatchNorm1d(1024), + nn.LeakyReLU(0.2), + nn.Linear(1024, self.output_dim + self.len_continuous_code + self.len_discrete_code), + nn.Sigmoid(), + ) + utils.initialize_weights(self) + + def forward(self, input): + x = self.conv(input) + x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4)) + x = self.fc(x) + a = F.sigmoid(x[:, self.output_dim]) + b = x[:, self.output_dim:self.output_dim + self.len_continuous_code] + c = x[:, self.output_dim + self.len_continuous_code:] + + return a, b, c + +class infoGAN(object): + def __init__(self, args, SUPERVISED=True): + # parameters + self.epoch = args.epoch + self.sample_num = 100 + self.batch_size = args.batch_size + self.save_dir = args.save_dir + self.result_dir = args.result_dir + self.dataset = args.dataset + self.log_dir = args.log_dir + self.gpu_mode = args.gpu_mode + self.model_name = args.gan_type + self.SUPERVISED = SUPERVISED # if it is true, label info is directly used for code + self.len_discrete_code = 10 # categorical distribution (i.e. label) + self.len_continuous_code = 2 # gaussian distribution (e.g. 
rotation, thickness) + + # networks init + self.G = generator(self.dataset) + self.D = discriminator(self.dataset) + self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2)) + self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2)) + self.info_optimizer = optim.Adam(itertools.chain(self.G.parameters(), self.D.parameters()), lr=args.lrD, betas=(args.beta1, args.beta2)) + + if self.gpu_mode: + self.G.cuda() + self.D.cuda() + self.BCE_loss = nn.BCELoss().cuda() + self.CE_loss = nn.CrossEntropyLoss().cuda() + self.MSE_loss = nn.MSELoss().cuda() + else: + self.BCE_loss = nn.BCELoss() + self.CE_loss = nn.CrossEntropyLoss() + self.MSE_loss = nn.MSELoss() + + print('---------- Networks architecture -------------') + utils.print_network(self.G) + utils.print_network(self.D) + print('-----------------------------------------------') + + # load mnist + self.data_X, self.data_Y = utils.load_mnist(args.dataset) + self.z_dim = 62 + self.y_dim = 10 + + # fixed noise & condition + self.sample_z_ = torch.zeros((self.sample_num, self.z_dim)) + for i in range(10): + self.sample_z_[i*self.y_dim] = torch.rand(1, self.z_dim) + for j in range(1, self.y_dim): + self.sample_z_[i*self.y_dim + j] = self.sample_z_[i*self.y_dim] + + temp = torch.zeros((10, 1)) + for i in range(self.y_dim): + temp[i, 0] = i + + temp_y = torch.zeros((self.sample_num, 1)) + for i in range(10): + temp_y[i*self.y_dim: (i+1)*self.y_dim] = temp + + self.sample_y_ = torch.zeros((self.sample_num, self.y_dim)) + self.sample_y_.scatter_(1, temp_y.type(torch.LongTensor), 1) + self.sample_c_ = torch.zeros((self.sample_num, self.len_continuous_code)) + + # manipulating two continuous code + temp_z_ = torch.rand((1, self.z_dim)) + self.sample_z2_ = temp_z_ + for i in range(self.sample_num - 1): + self.sample_z2_ = torch.cat([self.sample_z2_, temp_z_]) + + y = np.zeros(self.sample_num, dtype=np.int64) + y_one_hot = np.zeros((self.sample_num, 
self.len_discrete_code)) + y_one_hot[np.arange(self.sample_num), y] = 1 + self.sample_y2_ = torch.from_numpy(y_one_hot).type(torch.FloatTensor) + + temp_c = torch.linspace(-1, 1, 10) + self.sample_c2_ = torch.zeros((self.sample_num, 2)) + for i in range(10): + for j in range(10): + self.sample_c2_[i*10+j, 0] = temp_c[i] + self.sample_c2_[i*10+j, 1] = temp_c[j] + + if self.gpu_mode: + self.sample_z_, self.sample_y_, self.sample_c_, self.sample_z2_, self.sample_y2_, self.sample_c2_ = \ + Variable(self.sample_z_.cuda(), volatile=True), Variable(self.sample_y_.cuda(), volatile=True), \ + Variable(self.sample_c_.cuda(), volatile=True), Variable(self.sample_z2_.cuda(), volatile=True), \ + Variable(self.sample_y2_.cuda(), volatile=True), Variable(self.sample_c2_.cuda(), volatile=True) + else: + self.sample_z_, self.sample_y_, self.sample_c_, self.sample_z2_, self.sample_y2_, self.sample_c2_ = \ + Variable(self.sample_z_, volatile=True), Variable(self.sample_y_, volatile=True), \ + Variable(self.sample_c_, volatile=True), Variable(self.sample_z2_, volatile=True), \ + Variable(self.sample_y2_, volatile=True), Variable(self.sample_c2_, volatile=True) + + def train(self): + self.train_hist = {} + self.train_hist['D_loss'] = [] + self.train_hist['G_loss'] = [] + self.train_hist['info_loss'] = [] + self.train_hist['per_epoch_time'] = [] + self.train_hist['total_time'] = [] + + if self.gpu_mode: + self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1).cuda()), Variable(torch.zeros(self.batch_size, 1).cuda()) + else: + self.y_real_, self.y_fake_ = Variable(torch.ones(self.batch_size, 1)), Variable(torch.zeros(self.batch_size, 1)) + + self.D.train() + print('training start!!') + start_time = time.time() + for epoch in range(self.epoch): + self.G.train() + epoch_start_time = time.time() + for iter in range(len(self.data_X) // self.batch_size): + x_ = self.data_X[iter*self.batch_size:(iter+1)*self.batch_size] + z_ = torch.rand((self.batch_size, self.z_dim)) + if 
self.SUPERVISED == True: + y_disc_ = self.data_Y[iter*self.batch_size:(iter+1)*self.batch_size] + else: + y_disc_ = torch.from_numpy( + np.random.multinomial(1, self.len_discrete_code * [float(1.0 / self.len_discrete_code)], + size=[self.batch_size])).type(torch.FloatTensor) + + y_cont_ = torch.from_numpy(np.random.uniform(-1, 1, size=(self.batch_size, 2))).type(torch.FloatTensor) + + if self.gpu_mode: + x_, z_, y_disc_, y_cont_ = Variable(x_.cuda()), Variable(z_.cuda()), \ + Variable(y_disc_.cuda()), Variable(y_cont_.cuda()) + else: + x_, z_, y_disc_, y_cont_ = Variable(x_), Variable(z_), Variable(y_disc_), Variable(y_cont_) + + # update D network + self.D_optimizer.zero_grad() + + D_real, _, _ = self.D(x_) + D_real_loss = self.BCE_loss(D_real, self.y_real_) + + G_ = self.G(z_, y_cont_, y_disc_) + D_fake, _, _ = self.D(G_) + D_fake_loss = self.BCE_loss(D_fake, self.y_fake_) + + D_loss = D_real_loss + D_fake_loss + self.train_hist['D_loss'].append(D_loss.data[0]) + + D_loss.backward(retain_graph=True) + self.D_optimizer.step() + + # update G network + self.G_optimizer.zero_grad() + + G_ = self.G(z_, y_cont_, y_disc_) + D_fake, D_cont, D_disc = self.D(G_) + + G_loss = self.BCE_loss(D_fake, self.y_real_) + self.train_hist['G_loss'].append(G_loss.data[0]) + + G_loss.backward(retain_graph=True) + self.G_optimizer.step() + + # information loss + disc_loss = self.CE_loss(D_disc, torch.max(y_disc_, 1)[1]) + cont_loss = self.MSE_loss(D_cont, y_cont_) + info_loss = disc_loss + cont_loss + self.train_hist['info_loss'].append(info_loss.data[0]) + + info_loss.backward() + self.info_optimizer.step() + + + if ((iter + 1) % 100) == 0: + print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f, info_loss: %.8f" % + ((epoch + 1), (iter + 1), len(self.data_X) // self.batch_size, D_loss.data[0], G_loss.data[0], info_loss.data[0])) + + self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time) + self.visualize_results((epoch+1)) + + 
self.train_hist['total_time'].append(time.time() - start_time) + print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']), + self.epoch, self.train_hist['total_time'][0])) + print("Training finish!... save training results") + + self.save() + utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name, + self.epoch) + utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont', + self.epoch) + self.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name) + + def visualize_results(self, epoch): + self.G.eval() + + if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name): + os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name) + + image_frame_dim = int(np.floor(np.sqrt(self.sample_num))) + + """ style by class """ + samples = self.G(self.sample_z_, self.sample_c_, self.sample_y_) + if self.gpu_mode: + samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) + else: + samples = samples.data.numpy().transpose(0, 2, 3, 1) + + utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], + self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%03d' % epoch + '.png') + + """ manipulating two continous codes """ + samples = self.G(self.sample_z2_, self.sample_c2_, self.sample_y2_) + if self.gpu_mode: + samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1) + else: + samples = samples.data.numpy().transpose(0, 2, 3, 1) + + utils.save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim], + self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_cont_epoch%03d' % epoch + '.png') + + def save(self): + save_dir = os.path.join(self.save_dir, self.dataset, 
self.model_name) + + if not os.path.exists(save_dir): + os.makedirs(save_dir) + + torch.save(self.G.state_dict(), os.path.join(save_dir, self.model_name + '_G.pkl')) + torch.save(self.D.state_dict(), os.path.join(save_dir, self.model_name + '_D.pkl')) + + with open(os.path.join(save_dir, self.model_name + '_history.pkl'), 'wb') as f: + pickle.dump(self.train_hist, f) + + def load(self): + save_dir = os.path.join(self.save_dir, self.dataset, self.model_name) + + self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl'))) + self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl'))) + + def loss_plot(self, hist, path='Train_hist.png', model_name=''): + x = range(len(hist['D_loss'])) + + y1 = hist['D_loss'] + y2 = hist['G_loss'] + y3 = hist['info_loss'] + + plt.plot(x, y1, label='D_loss') + plt.plot(x, y2, label='G_loss') + plt.plot(x, y3, label='info_loss') + + plt.xlabel('Iter') + plt.ylabel('Loss') + + plt.legend(loc=4) + plt.grid(True) + plt.tight_layout() + + path = os.path.join(path, model_name + '_loss.png') + + plt.savefig(path) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..ffdbe83 --- /dev/null +++ b/main.py @@ -0,0 +1,104 @@ +import argparse, os +from GAN import GAN +from CGAN import CGAN +from LSGAN import LSGAN +from DRAGAN import DRAGAN +from ACGAN import ACGAN +from WGAN import WGAN +from infoGAN import infoGAN +from EBGAN import EBGAN +from BEGAN import BEGAN + +"""parsing and configuration""" +def parse_args(): + desc = "Pytorch implementation of GAN collections" + parser = argparse.ArgumentParser(description=desc) + + parser.add_argument('--gan_type', type=str, default='EBGAN', + choices=['GAN', 'CGAN', 'infoGAN', 'ACGAN', 'EBGAN', 'BEGAN', 'WGAN', 'DRAGAN', 'LSGAN'], + help='The type of GAN')#, required=True) + parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist', 'fashion-mnist', 'celebA'], + help='The name of dataset') + 
parser.add_argument('--epoch', type=int, default=25, help='The number of epochs to run') + parser.add_argument('--batch_size', type=int, default=64, help='The size of batch') + parser.add_argument('--save_dir', type=str, default='models', + help='Directory name to save the model') + parser.add_argument('--result_dir', type=str, default='results', + help='Directory name to save the generated images') + parser.add_argument('--log_dir', type=str, default='logs', + help='Directory name to save training logs') + parser.add_argument('--lrG', type=float, default=0.0002) + parser.add_argument('--lrD', type=float, default=0.0002) + parser.add_argument('--beta1', type=float, default=0.5) + parser.add_argument('--beta2', type=float, default=0.999) + parser.add_argument('--gpu_mode', type=bool, default=True) + + return check_args(parser.parse_args()) + +"""checking arguments""" +def check_args(args): + # --save_dir + if not os.path.exists(args.save_dir): + os.makedirs(args.save_dir) + + # --result_dir + if not os.path.exists(args.result_dir): + os.makedirs(args.result_dir) + + # --result_dir + if not os.path.exists(args.log_dir): + os.makedirs(args.log_dir) + + # --epoch + try: + assert args.epoch >= 1 + except: + print('number of epochs must be larger than or equal to one') + + # --batch_size + try: + assert args.batch_size >= 1 + except: + print('batch size must be larger than or equal to one') + + return args + +"""main""" +def main(): + # parse arguments + args = parse_args() + if args is None: + exit() + + # declare instance for GAN + if args.gan_type == 'GAN': + gan = GAN(args) + elif args.gan_type == 'CGAN': + gan = CGAN(args) + elif args.gan_type == 'ACGAN': + gan = ACGAN(args) + elif args.gan_type == 'infoGAN': + gan = infoGAN(args, SUPERVISED = True) + elif args.gan_type == 'EBGAN': + gan = EBGAN(args) + elif args.gan_type == 'WGAN': + gan = WGAN(args) + elif args.gan_type == 'DRAGAN': + gan = DRAGAN(args) + elif args.gan_type == 'LSGAN': + gan = LSGAN(args) + elif 
args.gan_type == 'BEGAN': + gan = BEGAN(args) + else: + raise Exception("[!] There is no option for " + args.gan_type) + + # launch the graph in a session + gan.train() + print(" [*] Training finished!") + + # visualize learned generator + gan.visualize_results(args.epoch) + print(" [*] Testing finished!") + +if __name__ == '__main__': + main() diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..d3c97b4 --- /dev/null +++ b/utils.py @@ -0,0 +1,124 @@ +import os, gzip, torch +import torch.nn as nn +import numpy as np +import scipy.misc +import imageio +import matplotlib.pyplot as plt + +def load_mnist(dataset): + data_dir = os.path.join("./data", dataset) + + def extract_data(filename, num_data, head_size, data_size): + with gzip.open(filename) as bytestream: + bytestream.read(head_size) + buf = bytestream.read(data_size * num_data) + data = np.frombuffer(buf, dtype=np.uint8).astype(np.float) + return data + + data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28) + trX = data.reshape((60000, 28, 28, 1)) + + data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1) + trY = data.reshape((60000)) + + data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28) + teX = data.reshape((10000, 28, 28, 1)) + + data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1) + teY = data.reshape((10000)) + + trY = np.asarray(trY).astype(np.int) + teY = np.asarray(teY) + + X = np.concatenate((trX, teX), axis=0) + y = np.concatenate((trY, teY), axis=0).astype(np.int) + + seed = 547 + np.random.seed(seed) + np.random.shuffle(X) + np.random.seed(seed) + np.random.shuffle(y) + + y_vec = np.zeros((len(y), 10), dtype=np.float) + for i, label in enumerate(y): + y_vec[i, y[i]] = 1 + + X = X.transpose(0, 3, 1, 2) / 255. 
+ # y_vec = y_vec.transpose(0, 3, 1, 2) + + X = torch.from_numpy(X).type(torch.FloatTensor) + y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor) + return X, y_vec + +def print_network(net): + num_params = 0 + for param in net.parameters(): + num_params += param.numel() + print(net) + print('Total number of parameters: %d' % num_params) + +def save_images(images, size, image_path): + return imsave(images, size, image_path) + +def imsave(images, size, path): + image = np.squeeze(merge(images, size)) + return scipy.misc.imsave(path, image) + +def merge(images, size): + h, w = images.shape[1], images.shape[2] + if (images.shape[3] in (3,4)): + c = images.shape[3] + img = np.zeros((h * size[0], w * size[1], c)) + for idx, image in enumerate(images): + i = idx % size[1] + j = idx // size[1] + img[j * h:j * h + h, i * w:i * w + w, :] = image + return img + elif images.shape[3]==1: + img = np.zeros((h * size[0], w * size[1])) + for idx, image in enumerate(images): + i = idx % size[1] + j = idx // size[1] + img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0] + return img + else: + raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4') + +def generate_animation(path, num): + images = [] + for e in range(num): + img_name = path + '_epoch%03d' % (e+1) + '.png' + images.append(imageio.imread(img_name)) + imageio.mimsave(path + '_generate_animation.gif', images, fps=5) + +def loss_plot(hist, path = 'Train_hist.png', model_name = ''): + x = range(len(hist['D_loss'])) + + y1 = hist['D_loss'] + y2 = hist['G_loss'] + + plt.plot(x, y1, label='D_loss') + plt.plot(x, y2, label='G_loss') + + plt.xlabel('Iter') + plt.ylabel('Loss') + + plt.legend(loc=4) + plt.grid(True) + plt.tight_layout() + + path = os.path.join(path, model_name + '_loss.png') + + plt.savefig(path) + +def initialize_weights(net): + for m in net.modules(): + if isinstance(m, nn.Conv2d): + m.weight.data.normal_(0, 0.02) + m.bias.data.zero_() + elif isinstance(m, 
nn.ConvTranspose2d): + m.weight.data.normal_(0, 0.02) + m.bias.data.zero_() + elif isinstance(m, nn.Linear): + m.weight.data.normal_(0, 0.02) + m.bias.data.zero_() \ No newline at end of file