Skip to content

Commit

Permalink
Update to Pytorch 0.4
Browse files Browse the repository at this point in the history
  • Loading branch information
znxlwm authored Jun 12, 2018
1 parent c19ab22 commit 0d183bb
Showing 1 changed file with 2 additions and 3 deletions.
5 changes: 2 additions & 3 deletions ACGAN.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,9 +161,8 @@ def train(self):
break
z_ = torch.rand((self.batch_size, self.z_dim))
y_vec_ = torch.zeros((self.batch_size, self.class_num)).scatter_(1, y_.type(torch.LongTensor).unsqueeze(1), 1)
y_fill_ = y_vec_.unsqueeze(2).unsqueeze(3).expand(self.batch_size, self.class_num, self.input_size, self.input_size)
if self.gpu_mode:
x_, z_, y_vec_, y_fill_ = x_.cuda(), z_.cuda(), y_vec_.cuda(), y_fill_.cuda()
x_, z_, y_vec_ = x_.cuda(), z_.cuda(), y_vec_.cuda()

# update D network
self.D_optimizer.zero_grad()
Expand Down Expand Up @@ -261,4 +260,4 @@ def load(self):
save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_G.pkl')))
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))
self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_D.pkl')))

0 comments on commit 0d183bb

Please sign in to comment.