-
Notifications
You must be signed in to change notification settings - Fork 81
/
Copy pathLSTM.lua
34 lines (26 loc) · 1.06 KB
/
LSTM.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
-- adapted from: wojciechz/learning_to_execute on github
--
-- NOTE(review): this file uses the globals `nn` (Torch nn) and the nngraph
-- extensions (`module()(input)` node syntax, `nn.gModule`) without requiring
-- them; the caller is presumably expected to have loaded 'nn'/'nngraph'.
local LSTM = {}

--- Creates one timestep of one LSTM as an nngraph module.
-- The returned gModule takes the inputs {x, prev_c, prev_h} (in that order)
-- and produces the outputs {next_c, next_h}:
--   i = sigmoid(W_xi x + W_hi h_prev)      (input gate)
--   f = sigmoid(W_xf x + W_hf h_prev)      (forget gate)
--   o = sigmoid(W_xo x + W_ho h_prev)      (output gate)
--   g = tanh(W_xg x + W_hg h_prev)         (candidate cell input)
--   next_c = f * prev_c + i * g
--   next_h = o * tanh(next_c)
-- Every gate gets its own pair of nn.Linear layers (no parameter sharing).
-- @tparam table opt options table
-- @tparam number opt.rnn_size width of the cell state and hidden state
-- @tparam[opt] number opt.input_size width of x; defaults to opt.rnn_size
--   (which assumes x was already projected/embedded to rnn_size)
-- @return nn.gModule mapping {x, prev_c, prev_h} -> {next_c, next_h}
function LSTM.lstm(opt)
    assert(type(opt) == "table" and opt.rnn_size ~= nil,
        "LSTM.lstm: opt.rnn_size is required")
    local rnn_size = opt.rnn_size
    local input_size = opt.input_size or rnn_size

    local x = nn.Identity()()
    local prev_c = nn.Identity()()
    local prev_h = nn.Identity()()

    -- Builds (Linear(x) + Linear(prev_h)) from a fresh pair of Linear layers
    -- each time it is called, so every gate has separate parameters.
    -- Declared `local` so it does not leak into the global table.
    local function new_input_sum()
        -- transforms input
        local i2h = nn.Linear(input_size, rnn_size)(x)
        -- transforms previous timestep's output
        local h2h = nn.Linear(rnn_size, rnn_size)(prev_h)
        return nn.CAddTable()({i2h, h2h})
    end

    -- The call order below fixes the order in which parameters are created
    -- (and hence random initialisation and flattened-parameter layout);
    -- keep it stable.
    local in_gate = nn.Sigmoid()(new_input_sum())
    local forget_gate = nn.Sigmoid()(new_input_sum())
    local out_gate = nn.Sigmoid()(new_input_sum())
    local in_transform = nn.Tanh()(new_input_sum())

    -- next_c = forget_gate * prev_c + in_gate * in_transform (elementwise)
    local next_c = nn.CAddTable()({
        nn.CMulTable()({forget_gate, prev_c}),
        nn.CMulTable()({in_gate, in_transform}),
    })
    -- next_h = out_gate * tanh(next_c) (elementwise)
    local next_h = nn.CMulTable()({out_gate, nn.Tanh()(next_c)})

    return nn.gModule({x, prev_c, prev_h}, {next_c, next_h})
end

return LSTM