diff --git a/14_1_seq2seq.py b/14_1_seq2seq.py index d92aebc..db07dfc 100644 --- a/14_1_seq2seq.py +++ b/14_1_seq2seq.py @@ -114,7 +114,7 @@ def translate(enc_input='thisissungkim.iloveyou.', predict_len=100, temperature= for i, (srcs, targets) in enumerate(train_loader): train_loss = train(srcs[0], targets[0]) # Batch is 1 - if i % 100 is 0: + if i % 100 == 0: print('[(%d %d%%) %.4f]' % (epoch, epoch / N_EPOCH * 100, train_loss)) print(translate(srcs[0]), '\n')