Commit 8f8677a
C implementation of LSTM -- REVIEWABLE

* New "sparse/size" representation
* Full LSTM in C
* VSeqLSTM to wrap this data representation + C implementation
* Augmentation of the VariableLength decorator with this data representation from an array of tensors
* Unit tests
* Speed tests

1 parent fee152c, commit 8f8677a
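For orientation before the per-file diffs, here is a hedged sketch of how the pieces are meant to compose. Everything concrete in it (the packed-tensor-plus-sizes layout, the VSeqLSTM constructor arguments, passing the pair as a table) is an assumption drawn from the bullet list above, not an API confirmed by this diff:

-- A minimal sketch, assuming "sparse/size" means: all timesteps of all sequences
-- packed into one contiguous tensor, plus a vector of per-sequence lengths.
require 'rnn'

local sizes  = torch.LongTensor{4, 2, 3}               -- three sequences
local packed = torch.Tensor(sizes:sum(), 10):uniform() -- 9 steps, 10-dim features

local lstm = nn.VSeqLSTM(10, 20)                       -- hypothetical arguments
local out  = lstm:forward({packed, sizes})             -- hypothetical input convention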

16 files changed: +2095 -54 lines

CMakeLists.txt (+11)

@@ -5,12 +5,20 @@ CMAKE_POLICY(VERSION 2.6)
 
 FIND_PACKAGE(Torch REQUIRED)
 
+ADD_SUBDIRECTORY(lib)
+
 SET(BUILD_STATIC YES) # makes sure static targets are enabled in ADD_TORCH_PACKAGE
 
 SET(CMAKE_C_FLAGS "--std=c99 -pedantic -Werror -Wall -Wextra -Wno-unused-function -D_GNU_SOURCE ${CMAKE_C_FLAGS}")
 SET(src
    init.c
 )
+
+FILE(STRINGS lib/THRNN/generic/THRNN.h THRNN_headers NEWLINE_CONSUME)
+FILE(WRITE THRNN_h.lua "return [[")
+FILE(APPEND THRNN_h.lua ${THRNN_headers})
+FILE(APPEND THRNN_h.lua "]]")
+
 SET(luasrc
    init.lua
    AbstractRecurrent.lua

@@ -36,6 +44,7 @@ SET(luasrc
    SeqBLSTM.lua
    SeqGRU.lua
    SeqLSTM.lua
+   VSeqLSTM.lua
    Sequencer.lua
    SequencerCriterion.lua
    test/bigtest.lua

@@ -75,6 +84,8 @@ SET(luasrc
    deprecated/FastLSTM.lua
    deprecated/GRU.lua
    deprecated/LSTM.lua
+   THRNN.lua
+   THRNN_h.lua
 )
 
 ADD_TORCH_PACKAGE(rnn "${src}" "${luasrc}" "An RNN library for Torch")
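The FILE() commands added above generate THRNN_h.lua at build time: a one-expression Lua module that returns the text of lib/THRNN/generic/THRNN.h as a long string, which THRNN.lua (further down) loads with require 'rnn.THRNN_h'. A sketch of the generated file's shape follows; the declaration inside the string is illustrative, not copied from the real header:

-- THRNN_h.lua as produced by CMake: just the header text wrapped in a long string
return [[
TH_API void THRNN_(SomeKernel_updateOutput)(
          THRNNState *state,
          THTensor *input,
          THTensor *output);
/* ...the remaining declarations from THRNN.h, verbatim... */
]]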

LookupTableMaskZero.lua (+3 -3; the changed lines differ only in whitespace)

@@ -5,17 +5,17 @@ function LookupTableMaskZero:__init(nIndex, nOutput)
 end
 
 function LookupTableMaskZero:updateOutput(input)
-   self.weight[1]:zero()
+   self.weight[1]:zero()
    if self.__input and (torch.type(self.__input) ~= torch.type(input)) then
      self.__input = nil -- fixes old casting bug
    end
    self.__input = self.__input or input.new()
    self.__input:resizeAs(input):add(input, 1)
-   return parent.updateOutput(self, self.__input)
+   return parent.updateOutput(self, self.__input)
 end
 
 function LookupTableMaskZero:accGradParameters(input, gradOutput, scale)
-   parent.accGradParameters(self, self.__input, gradOutput, scale)
+   parent.accGradParameters(self, self.__input, gradOutput, scale)
 end
 
 function LookupTableMaskZero:type(type, cache)
Module.lua (+9)

@@ -38,6 +38,15 @@ function Module:setZeroMask(zeroMask)
    end
 end
 
+function Module:setContext(context)
+   if self.modules then
+      for i, module in ipairs(self.modules) do
+         module:setContext(context)
+      end
+   end
+   self.__context = context
+end
+
 function Module:stepClone(shareParams, shareGradParams)
    return self:sharedClone(shareParams, shareGradParams, true)
 end
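setContext() mirrors setZeroMask() just above it: the call recurses through self.modules, so a single call on the outermost container reaches every nested module, and each module records the value in self.__context. A hedged sketch; what consumes __context (presumably the new VSeqLSTM) and what `context` contains are assumptions based on this commit's description, not shown in this hunk:

-- Minimal sketch: one setContext call on the container reaches all children.
-- 'context' stands in for the sparse/size metadata introduced by this commit;
-- its contents here are illustrative only.
local context = { sizes = torch.LongTensor{4, 2, 3} }   -- hypothetical contents

local model = nn.Sequential()
   :add(nn.LookupTableMaskZero(100, 16))
   :add(nn.VSeqLSTM(16, 32))   -- hypothetical constructor arguments

model:setContext(context)      -- every submodule now has __context == context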

SeqLSTM.lua (+1 -1; whitespace-only change to the closing end)

@@ -457,4 +457,4 @@ function SeqLSTM:toRecLSTM()
    end
 
    return lstm
-end
+end

THRNN.lua (new file, +147)

@@ -0,0 +1,147 @@
local ffi = require 'ffi'

local THRNN = {}

local generic_THRNN_h = require 'rnn.THRNN_h'
-- strip all lines starting with #
-- to remove preprocessor directives originally present
-- in THRNN.h
generic_THRNN_h = generic_THRNN_h:gsub("\n#[^\n]*", "")
generic_THRNN_h = generic_THRNN_h:gsub("^#[^\n]*\n", "")

-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h
local base_declarations = [[
typedef void THRNNState;

typedef struct {
  unsigned long the_initial_seed;
  int left;
  int seeded;
  unsigned long next;
  unsigned long state[624]; /* the array for the state vector 624 = _MERSENNE_STATE_N */
  double normal_x;
  double normal_y;
  double normal_rho;
  int normal_is_valid;
} THGenerator;
]]

-- polyfill for Lua 5.1
if not package.searchpath then
   local sep = package.config:sub(1,1)
   function package.searchpath(mod, path)
      mod = mod:gsub('%.', sep)
      for m in path:gmatch('[^;]+') do
         local nm = m:gsub('?', mod)
         local f = io.open(nm, 'r')
         if f then
            f:close()
            return nm
         end
      end
   end
end

-- load libTHRNN
THRNN.C = ffi.load(package.searchpath('libTHRNN', package.cpath))

ffi.cdef(base_declarations)

-- expand macros, allowing the original lines from lib/THRNN/generic/THRNN.h to be reused
local preprocessed = string.gsub(generic_THRNN_h, 'TH_API void THRNN_%(([%a%d_]+)%)', 'void THRNN_TYPE%1')

local replacements =
{
   {
      ['TYPE'] = 'Double',
      ['real'] = 'double',
      ['THTensor'] = 'THDoubleTensor',
      ['THIndexTensor'] = 'THLongTensor',
      ['THIntegerTensor'] = 'THIntTensor',
      ['THIndex_t'] = 'long',
      ['THInteger_t'] = 'int'
   },
   {
      ['TYPE'] = 'Float',
      ['real'] = 'float',
      ['THTensor'] = 'THFloatTensor',
      ['THIndexTensor'] = 'THLongTensor',
      ['THIntegerTensor'] = 'THIntTensor',
      ['THIndex_t'] = 'long',
      ['THInteger_t'] = 'int'
   }
}

-- gsub(s, 'real', 'float') changes accreal to accfloat;
-- typedef accfloat ahead of time.
ffi.cdef("typedef double accfloat;")
-- gsub(s, 'real', 'double') changes accreal to accdouble;
-- typedef accdouble ahead of time.
ffi.cdef("typedef double accdouble;")

for i=1,#replacements do
   local r = replacements[i]
   local s = preprocessed
   for k,v in pairs(r) do
      s = string.gsub(s, k, v)
   end
   ffi.cdef(s)
end
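-- To make the macro expansion above concrete, here is what one declaration goes
-- through; the function name and argument list are illustrative, not copied from
-- the real THRNN.h:
--
--   TH_API void THRNN_(SomeKernel_updateOutput)(THRNNState *state, THTensor *input, accreal scale);
--
-- after the gsub to 'void THRNN_TYPE%1' and the Float replacement table becomes
--
--   void THRNN_FloatSomeKernel_updateOutput(THRNNState *state, THFloatTensor *input, accfloat scale);
--
-- which matches the symbol exported by libTHRNN, so the ffi.cdef'd declaration
-- resolves against THRNN.C.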
THRNN.NULL = ffi.NULL or nil

function THRNN.getState()
   return ffi.NULL or nil
end

function THRNN.optionalTensor(t)
   return t and t:cdata() or THRNN.NULL
end

local function extract_function_names(s)
   local t = {}
   for n in string.gmatch(s, 'TH_API void THRNN_%(([%a%d_]+)%)') do
      t[#t+1] = n
   end
   return t
end

function THRNN.bind(lib, base_names, type_name, state_getter)
   local ftable = {}
   local prefix = 'THRNN_' .. type_name
   for i,n in ipairs(base_names) do
      -- use pcall since some libs might not support all functions (e.g. cunn)
      local ok,v = pcall(function() return lib[prefix .. n] end)
      if ok then
         ftable[n] = function(...) v(state_getter(), ...) end -- implicitly add state
      else
         print('not found: ' .. prefix .. n .. v)
      end
   end
   return ftable
end

-- build function table
local function_names = extract_function_names(generic_THRNN_h)

THRNN.kernels = {}
THRNN.kernels['torch.FloatTensor'] = THRNN.bind(THRNN.C, function_names, 'Float', THRNN.getState)
THRNN.kernels['torch.DoubleTensor'] = THRNN.bind(THRNN.C, function_names, 'Double', THRNN.getState)

torch.getmetatable('torch.FloatTensor').THRNN = THRNN.kernels['torch.FloatTensor']
torch.getmetatable('torch.DoubleTensor').THRNN = THRNN.kernels['torch.DoubleTensor']

function THRNN.runKernel(f, type, ...)
   local ftable = THRNN.kernels[type]
   if not ftable then
      error('Unsupported tensor type: '..type)
   end
   local kernel = ftable[f]
   if not kernel then
      error(string.format("Function '%s' not found for tensor type '%s'.", f, type))
   end
   kernel(...)
end

return THRNN
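Once loaded, the bound C functions are reachable either through the tensor metatable (tensor.THRNN.<name>, added just above) or through THRNN.runKernel. A hedged sketch of a call site; the kernel name, its argument list, and the 'rnn.THRNN' module path are assumptions for illustration, since the real signatures live in lib/THRNN/generic/THRNN.h:

-- Sketch of calling into libTHRNN from Lua (kernel name/arguments are illustrative)
local THRNN = require 'rnn.THRNN'   -- assumed install path within the rnn package

local input  = torch.FloatTensor(9, 10)   -- e.g. a packed "sparse" batch
local output = torch.FloatTensor()

-- via the tensor metatable; the THRNNState argument is added implicitly by bind():
--   input.THRNN.SomeKernel_updateOutput(input:cdata(), output:cdata())

-- or via the generic dispatcher keyed on tensor type:
--   THRNN.runKernel('SomeKernel_updateOutput', torch.type(input), input:cdata(), output:cdata())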
