forked from graykode/nlp-tutorial
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathSeq2Seq(Attention).py
121 lines (95 loc) · 5.25 KB
/
Seq2Seq(Attention).py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# %%
# code by Tae Hwan Jung @graykode
# Reference : https://github.com/hunkim/PyTorchZeroToAll/blob/master/14_2_seq2seq_att.py
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
# S: Symbol that shows starting of decoding input
# E: Symbol that shows starting of decoding output
# P: Symbol that will fill in blank sequence if current batch data size is short than time steps
def make_batch():
input_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[0].split()]]]
output_batch = [np.eye(n_class)[[word_dict[n] for n in sentences[1].split()]]]
target_batch = [[word_dict[n] for n in sentences[2].split()]]
# make tensor
return torch.FloatTensor(input_batch), torch.FloatTensor(output_batch), torch.LongTensor(target_batch)
class Attention(nn.Module):
def __init__(self):
super(Attention, self).__init__()
self.enc_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
self.dec_cell = nn.RNN(input_size=n_class, hidden_size=n_hidden, dropout=0.5)
# Linear for attention
self.attn = nn.Linear(n_hidden, n_hidden)
self.out = nn.Linear(n_hidden * 2, n_class)
def forward(self, enc_inputs, hidden, dec_inputs):
enc_inputs = enc_inputs.transpose(0, 1) # enc_inputs: [n_step(=n_step, time step), batch_size, n_class]
dec_inputs = dec_inputs.transpose(0, 1) # dec_inputs: [n_step(=n_step, time step), batch_size, n_class]
# enc_outputs : [n_step, batch_size, num_directions(=1) * n_hidden], matrix F
# enc_hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
enc_outputs, enc_hidden = self.enc_cell(enc_inputs, hidden)
trained_attn = []
hidden = enc_hidden
n_step = len(dec_inputs)
model = torch.empty([n_step, 1, n_class])
for i in range(n_step): # each time step
# dec_output : [n_step(=1), batch_size(=1), num_directions(=1) * n_hidden]
# hidden : [num_layers(=1) * num_directions(=1), batch_size(=1), n_hidden]
dec_output, hidden = self.dec_cell(dec_inputs[i].unsqueeze(0), hidden)
attn_weights = self.get_att_weight(dec_output, enc_outputs) # attn_weights : [1, 1, n_step]
trained_attn.append(attn_weights.squeeze().data.numpy())
# matrix-matrix product of matrices [1,1,n_step] x [1,n_step,n_hidden] = [1,1,n_hidden]
context = attn_weights.bmm(enc_outputs.transpose(0, 1))
dec_output = dec_output.squeeze(0) # dec_output : [batch_size(=1), num_directions(=1) * n_hidden]
context = context.squeeze(1) # [1, num_directions(=1) * n_hidden]
model[i] = self.out(torch.cat((dec_output, context), 1))
# make model shape [n_step, n_class]
return model.transpose(0, 1).squeeze(0), trained_attn
def get_att_weight(self, dec_output, enc_outputs): # get attention weight one 'dec_output' with 'enc_outputs'
n_step = len(enc_outputs)
attn_scores = torch.zeros(n_step) # attn_scores : [n_step]
for i in range(n_step):
attn_scores[i] = self.get_att_score(dec_output, enc_outputs[i])
# Normalize scores to weights in range 0 to 1
return F.softmax(attn_scores).view(1, 1, -1)
def get_att_score(self, dec_output, enc_output): # enc_outputs [batch_size, num_directions(=1) * n_hidden]
score = self.attn(enc_output) # score : [batch_size, n_hidden]
return torch.dot(dec_output.view(-1), score.view(-1)) # inner product make scalar value
if __name__ == '__main__':
n_step = 5 # number of cells(= number of Step)
n_hidden = 128 # number of hidden units in one cell
sentences = ['ich mochte ein bier P', 'S i want a beer', 'i want a beer E']
word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
number_dict = {i: w for i, w in enumerate(word_list)}
n_class = len(word_dict) # vocab list
# hidden : [num_layers(=1) * num_directions(=1), batch_size, n_hidden]
hidden = torch.zeros(1, 1, n_hidden)
model = Attention()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
input_batch, output_batch, target_batch = make_batch()
# Train
for epoch in range(2000):
optimizer.zero_grad()
output, _ = model(input_batch, hidden, output_batch)
loss = criterion(output, target_batch.squeeze(0))
if (epoch + 1) % 400 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
loss.backward()
optimizer.step()
# Test
test_batch = [np.eye(n_class)[[word_dict[n] for n in 'SPPPP']]]
test_batch = torch.FloatTensor(test_batch)
predict, trained_attn = model(input_batch, hidden, test_batch)
predict = predict.data.max(1, keepdim=True)[1]
print(sentences[0], '->', [number_dict[n.item()] for n in predict.squeeze()])
# Show Attention
fig = plt.figure(figsize=(5, 5))
ax = fig.add_subplot(1, 1, 1)
ax.matshow(trained_attn, cmap='viridis')
ax.set_xticklabels([''] + sentences[0].split(), fontdict={'fontsize': 14})
ax.set_yticklabels([''] + sentences[2].split(), fontdict={'fontsize': 14})
plt.show()