-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathtrain.py
456 lines (347 loc) · 14.7 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
# coding: utf-8
import re
import random
import time
import datetime
import math
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_packed_sequence, pack_padded_sequence
from masked_cross_entropy import *
import matplotlib
matplotlib.use('TKAgg')
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import numpy as np
import pickle
from preprocess import Voc
from preprocess import MAX_LENGTH
# When True, models and freshly created tensors are moved to the GPU via .cuda()
USE_CUDA = True

# Vocabulary indexes used for the special tokens
PAD_token = 0  # fills sequences up to the batch length (see pad_seq)
SOS_token = 1  # first decoder input
EOS_token = 2  # appended to every indexed sentence

# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
# Load the sentence pairs and drop the last 250 of them.
with open("save/pairs.pickle", "rb") as f:
    pairs = pickle.load(f)[:-250]

# Load the vocabulary object (provides word2index / index2word / n_words).
with open("save/voc.pickle", "rb") as f:
    voc = pickle.load(f)
def indexes_from_sentence(voc, sentence):
    """Translate a space-separated sentence into a list of vocabulary indexes.

    Words that are not keys of ``voc.word2index`` are silently skipped, and
    ``EOS_token`` is always appended as the final element.
    """
    lookup = voc.word2index
    indexes = []
    for token in sentence.split(' '):
        if token in lookup:
            indexes.append(lookup[token])
    indexes.append(EOS_token)
    return indexes
def pad_seq(seq, max_length):
    """Return ``seq`` right-padded with ``PAD_token`` up to ``max_length``.

    A new list is always returned; the caller's sequence is never modified
    (the previous implementation extended the argument in place).

    Args:
        seq: sequence of word indexes.
        max_length: desired length after padding.

    Returns:
        A list of ``max(len(seq), max_length)`` items. A sequence that is
        already long enough is copied unchanged (it is never truncated).
    """
    missing = max_length - len(seq)
    if missing <= 0:
        # Nothing to add: hand back a copy so callers never share the list
        return list(seq)
    return list(seq) + [PAD_token] * missing
def random_batch(batch_size):
    """Draw ``batch_size`` random pairs and return them as padded index tensors.

    Pairs are sampled from the module-level ``pairs`` with ``random.choice``
    (so the same pair may appear more than once), indexed with ``voc``, and
    ordered by input length, longest first.

    Returns:
        (input_var, input_lengths, target_var, target_lengths) where the two
        ``*_var`` values are LongTensor Variables of shape
        (max_len x batch_size) and the two ``*_lengths`` values are lists of
        the unpadded lengths in batch order.
    """
    # Sample the pairs and index both sides
    sampled = []
    for _ in range(batch_size):
        chosen = random.choice(pairs)
        sampled.append((
            indexes_from_sentence(voc, chosen[0]),
            indexes_from_sentence(voc, chosen[1]),
        ))

    # Order by input length (descending), then split back into two tuples
    sampled.sort(key=lambda item: len(item[0]), reverse=True)
    input_seqs, target_seqs = zip(*sampled)

    # Record the true lengths before padding everything to the batch maximum
    input_lengths = [len(seq) for seq in input_seqs]
    target_lengths = [len(seq) for seq in target_seqs]
    longest_input = max(input_lengths)
    longest_target = max(target_lengths)
    input_padded = [pad_seq(seq, longest_input) for seq in input_seqs]
    target_padded = [pad_seq(seq, longest_target) for seq in target_seqs]

    # (batch_size x max_len) -> (max_len x batch_size)
    input_var = Variable(torch.LongTensor(input_padded)).transpose(0, 1)
    target_var = Variable(torch.LongTensor(target_padded)).transpose(0, 1)

    if USE_CUDA:
        input_var = input_var.cuda()
        target_var = target_var.cuda()

    return input_var, input_lengths, target_var, target_lengths
# print(random_batch(2))
class EncoderRNN(nn.Module):
    """Embedding followed by a bidirectional GRU over a padded batch.

    The two directions of the GRU output are added together, so the returned
    outputs have ``hidden_size`` features rather than ``2 * hidden_size``.
    """

    def __init__(self, input_size, hidden_size, n_layers=1, dropout=0.1):
        super(EncoderRNN, self).__init__()

        # Keep the configuration for reference
        self.input_size, self.hidden_size = input_size, hidden_size
        self.n_layers, self.dropout = n_layers, dropout

        # Layers
        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(
            hidden_size,
            hidden_size,
            n_layers,
            dropout=self.dropout,
            bidirectional=True,
        )

    def forward(self, input_seqs, input_lengths, hidden=None):
        """Encode a whole padded batch in one call.

        Args:
            input_seqs: word indexes, (max_len x batch).
            input_lengths: unpadded length of every sequence, in batch order.
            hidden: optional initial GRU hidden state.

        Returns:
            (outputs, hidden): outputs is (max_len x batch x hidden_size) with
            both directions summed; hidden is the GRU's final hidden state.
        """
        embedded = self.embedding(input_seqs)
        packed = pack_padded_sequence(embedded, input_lengths)
        packed_outputs, hidden = self.gru(packed, hidden)
        # Unpack back to a padded tensor; the returned lengths are not needed
        outputs, _ = pad_packed_sequence(packed_outputs)

        # The last dimension holds [forward | backward]; add the two halves
        width = self.hidden_size
        forward_half = outputs[:, :, :width]
        backward_half = outputs[:, :, width:]
        return forward_half + backward_half, hidden
class Attn(nn.Module):
    """Attention weights over encoder outputs for one decoder step.

    Supported scoring methods:
        'dot'     -- hidden . encoder_output
        'general' -- hidden . W(encoder_output)
        'concat'  -- v . W([hidden ; encoder_output])

    Args:
        method: one of 'dot', 'general' or 'concat'.
        hidden_size: feature size of both the hidden state and encoder outputs.
    """

    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()

        self.method = method
        self.hidden_size = hidden_size

        if self.method == 'general':
            self.attn = nn.Linear(self.hidden_size, hidden_size)

        elif self.method == 'concat':
            self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
            # torch.FloatTensor(1, n) on its own is uninitialised memory and may
            # hold NaN/inf, so give v a small uniform initialisation instead.
            bound = 1.0 / (hidden_size ** 0.5)
            self.v = nn.Parameter(
                torch.FloatTensor(1, hidden_size).uniform_(-bound, bound)
            )

    def forward(self, hidden, encoder_outputs):
        """Return attention weights of shape (B x 1 x S).

        Args:
            hidden: current decoder state, (1 x B x N).
            encoder_outputs: encoder outputs, (S x B x N).

        Raises:
            ValueError: if ``self.method`` is not a supported scoring method.
        """
        # All energies are computed in one shot (S x B) instead of one Python
        # call per (batch, position) pair; the result lives on the same device
        # as the inputs, so no explicit .cuda() is needed.
        if self.method == 'dot':
            energies = (hidden * encoder_outputs).sum(dim=2)
        elif self.method == 'general':
            energies = (hidden * self.attn(encoder_outputs)).sum(dim=2)
        elif self.method == 'concat':
            repeated = hidden.expand_as(encoder_outputs)
            projected = self.attn(torch.cat((repeated, encoder_outputs), 2))
            energies = (self.v * projected).sum(dim=2)
        else:
            raise ValueError("unknown attention method: %r" % (self.method,))

        # Normalize over the source positions, then resize to B x 1 x S
        return F.softmax(energies.t(), dim=1).unsqueeze(1)

    def score(self, hidden, encoder_output):
        """Energy for a single (hidden, encoder_output) pair, both (1 x N).

        ``forward`` no longer calls this; it is kept so existing callers that
        score one pair at a time still work.

        Raises:
            ValueError: if ``self.method`` is not a supported scoring method.
        """
        if self.method == 'dot':
            return hidden.squeeze(0).dot(encoder_output.squeeze(0))

        if self.method == 'general':
            energy = self.attn(encoder_output)
            return hidden.squeeze(0).dot(energy.squeeze(0))

        if self.method == 'concat':
            energy = self.attn(torch.cat((hidden, encoder_output), 1))
            return self.v.squeeze(0).dot(energy.squeeze(0))

        raise ValueError("unknown attention method: %r" % (self.method,))
class LuongAttnDecoderRNN(nn.Module):
    """Single-step GRU decoder with Luong-style attention.

    Each call embeds the previous word, advances the GRU by one step, attends
    over the encoder outputs and projects the result onto the vocabulary.
    When ``attn_model`` is 'none' no attention module is created.
    """

    def __init__(self, attn_model, hidden_size, output_size, n_layers=1, dropout=0.1):
        super(LuongAttnDecoderRNN, self).__init__()

        # Configuration, kept for reference
        self.attn_model = attn_model
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.n_layers = n_layers
        self.dropout = dropout

        # Layers (creation order is preserved so parameter order is unchanged)
        self.embedding = nn.Embedding(output_size, hidden_size)
        self.embedding_dropout = nn.Dropout(dropout)
        self.gru = nn.GRU(hidden_size, hidden_size, n_layers, dropout=dropout)
        self.concat = nn.Linear(hidden_size * 2, hidden_size)
        self.out = nn.Linear(hidden_size, output_size)

        # Attention module
        if attn_model != 'none':
            self.attn = Attn(attn_model, hidden_size)

    def forward(self, input_seq, last_hidden, encoder_outputs):
        """Decode one time step for the whole batch.

        Args:
            input_seq: previous output word index for every sequence, (B).
            last_hidden: GRU hidden state from the previous step.
            encoder_outputs: encoder outputs, (S x B x N).

        Returns:
            (output, hidden, attn_weights): un-normalized vocabulary scores
            (B x output_size), the new GRU hidden state, and the attention
            weights (B x 1 x S).
        """
        n_sequences = input_seq.size(0)

        # Embed the previous word and shape it as a length-1 sequence: 1 x B x N
        word_vectors = self.embedding_dropout(self.embedding(input_seq))
        step_input = word_vectors.view(1, n_sequences, self.hidden_size)

        # Advance the GRU by one step
        rnn_output, hidden = self.gru(step_input, last_hidden)

        # Attention weights, then the weighted sum of encoder outputs: B x 1 x N
        attn_weights = self.attn(rnn_output, encoder_outputs)
        context = attn_weights.bmm(encoder_outputs.transpose(0, 1))

        # Luong eq. 5: tanh(W [rnn_output ; context]) on B x N tensors
        features = torch.cat((rnn_output.squeeze(0), context.squeeze(1)), 1)
        attentional = F.tanh(self.concat(features))

        # Luong eq. 6 without the softmax
        output = self.out(attentional)

        return output, hidden, attn_weights
def train(input_batches, input_lengths, target_batches, target_lengths, encoder, decoder, encoder_optimizer, decoder_optimizer, criterion, max_length=MAX_LENGTH):
    """Run one teacher-forced training step on a single batch.

    Args:
        input_batches: input word indexes, (max_input_len x batch).
        input_lengths: unpadded input lengths, in batch order.
        target_batches: target word indexes, (max_target_len x batch).
        target_lengths: unpadded target lengths, in batch order.
        encoder, decoder: the models to train.
        encoder_optimizer, decoder_optimizer: their optimizers.
        criterion: unused; the loss is computed with masked_cross_entropy.
        max_length: unused; decoding runs for max(target_lengths) steps.

    Returns:
        (loss, ec, dc): the batch loss value and the values returned by
        clip_grad_norm for the encoder and the decoder.

    Relies on the module-level ``clip`` and ``USE_CUDA`` settings.
    """
    # Zero gradients of both optimizers
    encoder_optimizer.zero_grad()
    decoder_optimizer.zero_grad()

    # Run words through encoder
    encoder_outputs, encoder_hidden = encoder(input_batches, input_lengths, None)

    # Take the batch size from the data itself rather than from the
    # module-level ``batch_size`` setting, so a batch of any size works.
    this_batch_size = input_batches.size(1)

    # Prepare input and output variables
    decoder_input = Variable(torch.LongTensor([SOS_token] * this_batch_size))
    # First decoder.n_layers entries of the encoder's final hidden state.
    # NOTE(review): for a bidirectional GRU these entries may not all be
    # forward-direction states -- confirm against the nn.GRU h_n layout. Left
    # unchanged so training stays consistent with evaluate() and saved weights.
    decoder_hidden = encoder_hidden[:decoder.n_layers]

    max_target_length = max(target_lengths)
    all_decoder_outputs = Variable(torch.zeros(max_target_length, this_batch_size, decoder.output_size))

    # Move new Variables to CUDA
    if USE_CUDA:
        decoder_input = decoder_input.cuda()
        all_decoder_outputs = all_decoder_outputs.cuda()

    # Run through decoder one time step at a time
    for t in range(max_target_length):
        decoder_output, decoder_hidden, decoder_attn = decoder(
            decoder_input, decoder_hidden, encoder_outputs
        )

        all_decoder_outputs[t] = decoder_output
        decoder_input = target_batches[t]  # Teacher forcing: next input is current target

    # Loss calculation and backpropagation
    loss = masked_cross_entropy(
        all_decoder_outputs.transpose(0, 1).contiguous(),  # -> batch x seq
        target_batches.transpose(0, 1).contiguous(),  # -> batch x seq
        target_lengths
    )
    loss.backward()

    # Clip gradient norms
    ec = torch.nn.utils.clip_grad_norm(encoder.parameters(), clip)
    dc = torch.nn.utils.clip_grad_norm(decoder.parameters(), clip)

    # Update parameters with optimizers
    encoder_optimizer.step()
    decoder_optimizer.step()

    return loss.data[0], ec, dc
# ---- Model hyper-parameters ----
attn_model = 'dot'  # attention scoring method given to the decoder
hidden_size = 500
n_layers = 2
dropout = 0.1
batch_size = 100  # 100

# ---- Training / optimization hyper-parameters ----
clip = 50.0  # max gradient norm used by train()
teacher_forcing_ratio = 1  # 0.5 -- only logged below; train() always feeds the target
learning_rate = 0.0001
decoder_learning_ratio = 5.0  # decoder lr = learning_rate * decoder_learning_ratio
n_epochs = 10000  # 50000
epoch = 0
plot_every = 50  # 20
print_every = 50  # 100

# ---- Models (encoder first, then decoder) ----
encoder = EncoderRNN(voc.n_words, hidden_size, n_layers, dropout=dropout)
decoder = LuongAttnDecoderRNN(attn_model, hidden_size, voc.n_words, n_layers, dropout=dropout)

# ---- Optimizers and criterion ----
encoder_optimizer = optim.Adam(encoder.parameters(), lr=learning_rate)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=learning_rate * decoder_learning_ratio)
criterion = nn.CrossEntropyLoss()

# ---- Move models to GPU ----
if USE_CUDA:
    encoder.cuda()
    decoder.cuda()

# ---- Experiment logging job, tagged with the settings above ----
import sconce
job = sconce.Job('seq2seq-translate', dict(
    attn_model=attn_model,
    n_layers=n_layers,
    dropout=dropout,
    hidden_size=hidden_size,
    learning_rate=learning_rate,
    clip=clip,
    teacher_forcing_ratio=teacher_forcing_ratio,
    decoder_learning_ratio=decoder_learning_ratio,
))
job.plot_every = plot_every
job.log_every = print_every

# ---- Elapsed-time reference and running loss sums ----
start = time.time()
plot_losses = []
print_loss_total = 0  # reset every print_every
plot_loss_total = 0  # reset every plot_every
def as_minutes(s):
    """Format a duration of ``s`` seconds as '<minutes>m <seconds>s'."""
    minutes = math.floor(s / 60)
    seconds = s - minutes * 60
    return '%dm %ds' % (minutes, seconds)
def time_since(since, percent):
    """Return 'elapsed (- remaining)' for a run started at ``since``.

    ``percent`` is the completed fraction of the run; the total duration is
    extrapolated as elapsed / percent.
    """
    elapsed = time.time() - since
    estimated_total = elapsed / percent
    remaining = estimated_total - elapsed
    return '%s (- %s)' % (as_minutes(elapsed), as_minutes(remaining))
def evaluate(input_seq, max_length=MAX_LENGTH):
    """Greedily decode a reply for one space-separated input sentence.

    Uses the module-level ``encoder``, ``decoder`` and ``voc``. Both models
    are switched to evaluation mode while decoding and are always switched
    back to training mode afterwards, even if decoding raises.

    Args:
        input_seq: the input sentence, words separated by single spaces.
        max_length: maximum number of decoding steps.

    Returns:
        (decoded_words, attentions): the generated words (ending with '<EOS>'
        when the end token was produced) and the attention weights, one row
        per generated step and one column per input position.
    """
    input_seqs = [indexes_from_sentence(voc, input_seq)]
    # Length of the indexed sequence (tokens + EOS), not len() of the raw
    # string, which would count characters and overstate the packed length.
    input_lengths = [len(seq) for seq in input_seqs]
    input_batches = Variable(torch.LongTensor(input_seqs), volatile=True).transpose(0, 1)

    if USE_CUDA:
        input_batches = input_batches.cuda()

    # Set to not-training mode to disable dropout
    encoder.train(False)
    decoder.train(False)

    try:
        # Run through encoder
        encoder_outputs, encoder_hidden = encoder(input_batches, input_lengths, None)

        # Create starting vectors for decoder
        decoder_input = Variable(torch.LongTensor([SOS_token]), volatile=True)  # SOS
        # Same slice of the encoder hidden state that train() uses
        decoder_hidden = encoder_hidden[:decoder.n_layers]

        if USE_CUDA:
            decoder_input = decoder_input.cuda()

        # Store output words and attention states
        decoded_words = []
        decoder_attentions = torch.zeros(max_length + 1, max_length + 1)

        # Run through decoder
        for di in range(max_length):
            decoder_output, decoder_hidden, decoder_attention = decoder(
                decoder_input, decoder_hidden, encoder_outputs
            )
            decoder_attentions[di, :decoder_attention.size(2)] += decoder_attention.squeeze(0).squeeze(0).cpu().data

            # Choose top word from output; int() makes it a plain Python int
            # so the comparison and the index2word lookup behave the same
            # whether topk yields ints or 0-dim tensors.
            topv, topi = decoder_output.data.topk(1)
            ni = int(topi[0][0])
            if ni == EOS_token:
                decoded_words.append('<EOS>')
                break
            decoded_words.append(voc.index2word[ni])

            # Next input is chosen word
            decoder_input = Variable(torch.LongTensor([ni]))
            if USE_CUDA:
                decoder_input = decoder_input.cuda()
    finally:
        # Set back to training mode
        encoder.train(True)
        decoder.train(True)

    return decoded_words, decoder_attentions[:di + 1, :len(encoder_outputs)]
# Begin!
# Per-plot-interval averages of the values train() returns from clip_grad_norm
ecs = []
dcs = []
# Running sums of those values for the current plot interval
eca = 0
dca = 0

import os

# Resume from saved weights when a checkpoint file exists (encoder, then decoder)
for _model, _path, _message in (
    (encoder, "save/encoder.pkl", "loaded encoder state_dict!"),
    (decoder, "save/decoder.pkl", "loaded decoder state_dict!"),
):
    if os.path.isfile(_path):
        _model.load_state_dict(torch.load(_path))
        print(_message)
# Main loop: one randomly sampled batch per "epoch", counted from epoch + 1
for epoch in range(epoch + 1, n_epochs + 1):
    # Sample a fresh batch and take one optimization step on it
    input_batches, input_lengths, target_batches, target_lengths = random_batch(batch_size)
    loss, ec, dc = train(
        input_batches, input_lengths, target_batches, target_lengths,
        encoder, decoder,
        encoder_optimizer, decoder_optimizer, criterion
    )

    # Accumulate the running sums
    print_loss_total += loss
    plot_loss_total += loss
    eca += ec
    dca += dc

    job.record(epoch, loss)

    is_print_step = epoch % print_every == 0
    is_plot_step = epoch % plot_every == 0

    if is_print_step:
        # Report progress with the average loss since the last report
        print_loss_avg = print_loss_total / print_every
        print_loss_total = 0
        progress = epoch / n_epochs
        print_summary = '%s (%d %d%%) %.4f' % (time_since(start, progress), epoch, progress * 100, print_loss_avg)
        print(print_summary)

        # Checkpoint both models
        torch.save(encoder.state_dict(), "save/encoder.pkl")
        torch.save(decoder.state_dict(), "save/decoder.pkl")
        print("saved...")

    if is_plot_step:
        # Store the interval averages for plotting, then reset the sums
        plot_loss_avg = plot_loss_total / plot_every
        plot_losses.append(plot_loss_avg)
        plot_loss_total = 0

        # TODO: Running average helper
        ecs.append(eca / plot_every)
        dcs.append(dca / plot_every)
        eca = 0
        dca = 0
def show_plot(points):
    """Plot ``points`` as a line chart and save it to save/outputs/train_losses.png.

    The y axis gets a major tick every 0.2. The figure is closed after it has
    been written so repeated calls do not accumulate open figures; the extra
    empty figure the old ``plt.figure()`` call created is no longer made.
    """
    fig, ax = plt.subplots()
    # Put ticks at regular intervals
    ax.yaxis.set_major_locator(ticker.MultipleLocator(base=0.2))
    ax.plot(points)
    fig.savefig('save/outputs/train_losses.png', dpi=300)
    plt.close(fig)
# Save the training-loss curve collected during the loop
show_plot(plot_losses)
# Decode one fixed sample sentence with the trained models
output_words, attentions = evaluate("很 懷念 住 宿舍 的 日子 … …")
# Print the sample input, its reference reply and the generated words
print("input: 很 懷念 住 宿舍 的 日子 … …")
print("target: 現在 不 懷念 , 以後 一定 會 懷念")
print(output_words)
# Render the attention weights returned by evaluate() as a heat map and save it
cax = plt.matshow(attentions.numpy(), cmap='bone')
plt.colorbar(cax)
plt.savefig('save/outputs/matshow.png', dpi=300)