Fix dep
r9y9 committed Aug 11, 2018
1 parent 5852640 commit 2e57916
Showing 3 changed files with 12 additions and 11 deletions.
gantts/seqloss.py: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ def sequence_mask(sequence_length, max_len=None):
 class MaskedMSELoss(nn.Module):
     def __init__(self):
         super(MaskedMSELoss, self).__init__()
-        self.criterion = nn.MSELoss(size_average=False)
+        self.criterion = nn.MSELoss(reduction="sum")
 
     def forward(self, input, target, lengths=None, mask=None, max_len=None):
         if lengths is None and mask is None:
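
This hunk tracks the PyTorch 0.4.x loss API: the older size_average/reduce flags were deprecated in favor of a single reduction argument, and size_average=False corresponds to reduction="sum". A minimal sketch of the equivalence (made-up tensors, assuming a torch build that still accepts the deprecated flag):

    # Both modules sum squared errors over all elements; only the spelling differs.
    import torch
    import torch.nn as nn

    input = torch.randn(4, 3)
    target = torch.randn(4, 3)

    old_loss = nn.MSELoss(size_average=False)   # deprecated spelling
    new_loss = nn.MSELoss(reduction="sum")      # 0.4.x spelling

    assert torch.allclose(old_loss(input, target), new_loss(input, target))
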
setup.py: 1 addition & 0 deletions
@@ -53,6 +53,7 @@ def run(self):
     },
     install_requires=[
         "numpy",
+        "torch >= 0.4.0",
     ],
     extras_require={
         "train": [
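
This is the dependency fix the commit title refers to: the code now relies on APIs from the torch 0.4 series (the reduction argument, tensor.item(), clip_grad_norm_), so torch is declared with a minimum version instead of being assumed. A hypothetical runtime guard with the same intent (not part of the repo, shown only to illustrate the constraint):

    # Hypothetical fail-fast check mirroring the setup.py requirement.
    from distutils.version import LooseVersion
    import torch

    if LooseVersion(torch.__version__) < LooseVersion("0.4.0"):
        raise RuntimeError(
            "gantts requires torch >= 0.4.0, found %s" % torch.__version__)
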
train.py: 10 additions & 10 deletions
@@ -255,15 +255,15 @@ def update_discriminator(model_d, optimizer_d, x, y_static, y_hat_static, length
     y_static_adv = torch.cat((x, y_static_adv), -1)
     y_hat_static_adv = torch.cat((x, y_hat_static_adv), -1)
 
-    T = mask.sum().data[0]
+    T = mask.sum().item()
 
     # Real
     D_real = model_d(y_static_adv, lengths=lengths)
-    real_correct_count = ((D_real > 0.5).float() * mask).sum().data[0]
+    real_correct_count = ((D_real > 0.5).float() * mask).sum().item()
 
     # Fake
     D_fake = model_d(y_hat_static_adv, lengths=lengths)
-    fake_correct_count = ((D_fake < 0.5).float() * mask).sum().data[0]
+    fake_correct_count = ((D_fake < 0.5).float() * mask).sum().item()
 
     # Loss
     loss_real_d = -(torch.log(D_real + eps) * mask).sum() / T
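
The .data[0] to .item() rewrites are needed because, since torch 0.4, reductions such as sum() return zero-dimensional tensors that can no longer be indexed with [0]; .item() is the supported way to extract the Python scalar. A sketch:

    # sum() yields a 0-dim tensor in torch >= 0.4; .item() unwraps it.
    import torch

    mask = torch.ones(2, 3)
    total = mask.sum()      # tensor(6.) with zero dimensions
    T = total.item()        # 6.0, a plain Python float
    # total.data[0] raises "invalid index of a 0-dim tensor" on 0.4+
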
@@ -272,18 +272,18 @@
 
     if phase == "train":
         loss_d.backward(retain_graph=True)
-        torch.nn.utils.clip_grad_norm(model_d.parameters(), 1.0)
+        torch.nn.utils.clip_grad_norm_(model_d.parameters(), 1.0)
         optimizer_d.step()
 
-    return loss_d.data[0], loss_fake_d.data[0], loss_real_d.data[0],\
+    return loss_d.item(), loss_fake_d.item(), loss_real_d.item(),\
         real_correct_count, fake_correct_count
 
 
 def update_generator(model_g, model_d, optimizer_g,
                      x, y, y_hat, y_static, y_hat_static,
                      adv_w, lengths, mask, phase,
                      mse_w=None, mge_w=None, eps=1e-20):
-    T = mask.sum().data[0]
+    T = mask.sum().item()
 
     criterion = MaskedMSELoss()
@@ -314,10 +314,10 @@ def update_generator(model_g, model_d, optimizer_g,
     loss_g = (mse_w * loss_mse + mge_w * loss_mge) + adv_w * loss_adv
     if phase == "train":
         loss_g.backward()
-        torch.nn.utils.clip_grad_norm(model_g.parameters(), 1.0)
+        torch.nn.utils.clip_grad_norm_(model_g.parameters(), 1.0)
         optimizer_g.step()
 
-    return loss_mse.data[0], loss_mge.data[0], loss_adv.data[0], loss_g.data[0]
+    return loss_mse.item(), loss_mge.item(), loss_adv.item(), loss_g.item()
 
 
 def exp_lr_scheduler(optimizer, epoch, nepoch, init_lr=0.0001, lr_decay_epoch=100):
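
Both gradient-clipping call sites switch to clip_grad_norm_, the renamed in-place variant introduced in the 0.4 series (the trailing underscore follows PyTorch's in-place naming convention; the old name survived for a while as a deprecated alias). A sketch of the call:

    # clip_grad_norm_ rescales each .grad in place and returns the pre-clip norm.
    import torch
    import torch.nn as nn

    model = nn.Linear(4, 2)
    model(torch.randn(8, 4)).sum().backward()
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
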
@@ -529,7 +529,7 @@ def train_loop(models, optimizers, dataset_loaders,
             y, len(hp.windows), hp.stream_sizes, hp.has_dynamic_features)
 
         # Num frames in batch
-        total_num_frames += sorted_lengths.float().sum().data[0]
+        total_num_frames += sorted_lengths.float().sum().item()
 
         # Mask
         mask = sequence_mask(sorted_lengths).unsqueeze(-1)
@@ -555,7 +555,7 @@
                 y_hat_static_ref, lengths=cpu_sorted_lengths)
             # Count samples classified as natural, while inputs are
             # actually generated.
-            regard_fake_as_natural += ((target > 0.5).float() * mask).sum().data[0]
+            regard_fake_as_natural += ((target > 0.5).float() * mask).sum().item()
 
         ### Update discriminator ###
         # Natural: 1, Genrated: 0
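
A recurring pattern in these hunks: a per-frame boolean decision is multiplied by the padding mask so padded frames never contribute, then reduced and unwrapped with .item(). A sketch with made-up shapes:

    # Count frames the discriminator judges "natural", ignoring padding.
    import torch

    D_out = torch.rand(2, 5, 1)                      # outputs in [0, 1]
    mask = torch.tensor([[1., 1., 1., 0., 0.],
                         [1., 1., 1., 1., 1.]]).unsqueeze(-1)
    natural_count = ((D_out > 0.5).float() * mask).sum().item()
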
