|
| 1 | +local ffi = require 'ffi' |
| 2 | + |
| 3 | +local THRNN = {} |
| 4 | + |
| 5 | + |
| 6 | +local generic_THRNN_h = require 'rnn.THRNN_h' |
| 7 | +-- strip all lines starting with # |
| 8 | +-- to remove preprocessor directives originally present |
| 9 | +-- in THRNN.h |
| 10 | +generic_THRNN_h = generic_THRNN_h:gsub("\n#[^\n]*", "") |
| 11 | +generic_THRNN_h = generic_THRNN_h:gsub("^#[^\n]*\n", "") |
| 12 | + |
| 13 | +-- THGenerator struct declaration copied from torch7/lib/TH/THRandom.h |
| 14 | +local base_declarations = [[ |
| 15 | +typedef void THRNNState; |
| 16 | +
|
| 17 | +typedef struct { |
| 18 | + unsigned long the_initial_seed; |
| 19 | + int left; |
| 20 | + int seeded; |
| 21 | + unsigned long next; |
| 22 | + unsigned long state[624]; /* the array for the state vector 624 = _MERSENNE_STATE_N */ |
| 23 | + double normal_x; |
| 24 | + double normal_y; |
| 25 | + double normal_rho; |
| 26 | + int normal_is_valid; |
| 27 | +} THGenerator; |
| 28 | +]] |
| 29 | + |
| 30 | +-- polyfill for LUA 5.1 |
| 31 | +if not package.searchpath then |
| 32 | + local sep = package.config:sub(1,1) |
| 33 | + function package.searchpath(mod, path) |
| 34 | + mod = mod:gsub('%.', sep) |
| 35 | + for m in path:gmatch('[^;]+') do |
| 36 | + local nm = m:gsub('?', mod) |
| 37 | + local f = io.open(nm, 'r') |
| 38 | + if f then |
| 39 | + f:close() |
| 40 | + return nm |
| 41 | + end |
| 42 | + end |
| 43 | + end |
| 44 | +end |
| 45 | + |
| 46 | +-- load libTHRNN |
| 47 | +THRNN.C = ffi.load(package.searchpath('libTHRNN', package.cpath)) |
| 48 | + |
| 49 | +ffi.cdef(base_declarations) |
| 50 | + |
| 51 | +-- expand macros, allow to use original lines from lib/THRNN/generic/THRNN.h |
| 52 | +local preprocessed = string.gsub(generic_THRNN_h, 'TH_API void THRNN_%(([%a%d_]+)%)', 'void THRNN_TYPE%1') |
| 53 | + |
| 54 | +local replacements = |
| 55 | +{ |
| 56 | + { |
| 57 | + ['TYPE'] = 'Double', |
| 58 | + ['real'] = 'double', |
| 59 | + ['THTensor'] = 'THDoubleTensor', |
| 60 | + ['THIndexTensor'] = 'THLongTensor', |
| 61 | + ['THIntegerTensor'] = 'THIntTensor', |
| 62 | + ['THIndex_t'] = 'long', |
| 63 | + ['THInteger_t'] = 'int' |
| 64 | + }, |
| 65 | + { |
| 66 | + ['TYPE'] = 'Float', |
| 67 | + ['real'] = 'float', |
| 68 | + ['THTensor'] = 'THFloatTensor', |
| 69 | + ['THIndexTensor'] = 'THLongTensor', |
| 70 | + ['THIntegerTensor'] = 'THIntTensor', |
| 71 | + ['THIndex_t'] = 'long', |
| 72 | + ['THInteger_t'] = 'int' |
| 73 | + } |
| 74 | +} |
| 75 | + |
| 76 | +-- gsub(s, 'real', 'float') changes accreal to accfloat. |
| 77 | +-- typedef accfloat ahead of time. |
| 78 | +ffi.cdef("typedef double accfloat;") |
| 79 | +-- gsub(s, 'real', 'double') changes accreal to accfloat. |
| 80 | +-- typedef accdouble ahead of time |
| 81 | +ffi.cdef("typedef double accdouble;") |
| 82 | + |
| 83 | +for i=1,#replacements do |
| 84 | + local r = replacements[i] |
| 85 | + local s = preprocessed |
| 86 | + for k,v in pairs(r) do |
| 87 | + s = string.gsub(s, k, v) |
| 88 | + end |
| 89 | + ffi.cdef(s) |
| 90 | +end |
| 91 | + |
| 92 | +THRNN.NULL = ffi.NULL or nil |
| 93 | + |
| 94 | +function THRNN.getState() |
| 95 | + return ffi.NULL or nil |
| 96 | +end |
| 97 | + |
| 98 | +function THRNN.optionalTensor(t) |
| 99 | + return t and t:cdata() or THRNN.NULL |
| 100 | +end |
| 101 | + |
| 102 | +local function extract_function_names(s) |
| 103 | + local t = {} |
| 104 | + for n in string.gmatch(s, 'TH_API void THRNN_%(([%a%d_]+)%)') do |
| 105 | + t[#t+1] = n |
| 106 | + end |
| 107 | + return t |
| 108 | +end |
| 109 | + |
| 110 | +function THRNN.bind(lib, base_names, type_name, state_getter) |
| 111 | + local ftable = {} |
| 112 | + local prefix = 'THRNN_' .. type_name |
| 113 | + for i,n in ipairs(base_names) do |
| 114 | + -- use pcall since some libs might not support all functions (e.g. cunn) |
| 115 | + local ok,v = pcall(function() return lib[prefix .. n] end) |
| 116 | + if ok then |
| 117 | + ftable[n] = function(...) v(state_getter(), ...) end -- implicitely add state |
| 118 | + else |
| 119 | + print('not found: ' .. prefix .. n .. v) |
| 120 | + end |
| 121 | + end |
| 122 | + return ftable |
| 123 | +end |
| 124 | + |
| 125 | +-- build function table |
| 126 | +local function_names = extract_function_names(generic_THRNN_h) |
| 127 | + |
| 128 | +THRNN.kernels = {} |
| 129 | +THRNN.kernels['torch.FloatTensor'] = THRNN.bind(THRNN.C, function_names, 'Float', THRNN.getState) |
| 130 | +THRNN.kernels['torch.DoubleTensor'] = THRNN.bind(THRNN.C, function_names, 'Double', THRNN.getState) |
| 131 | + |
| 132 | +torch.getmetatable('torch.FloatTensor').THRNN = THRNN.kernels['torch.FloatTensor'] |
| 133 | +torch.getmetatable('torch.DoubleTensor').THRNN = THRNN.kernels['torch.DoubleTensor'] |
| 134 | + |
| 135 | +function THRNN.runKernel(f, type, ...) |
| 136 | + local ftable = THRNN.kernels[type] |
| 137 | + if not ftable then |
| 138 | + error('Unsupported tensor type: '..type) |
| 139 | + end |
| 140 | + local f = ftable[f] |
| 141 | + if not f then |
| 142 | + error(string.format("Function '%s' not found for tensor type '%s'.", f, type)) |
| 143 | + end |
| 144 | + f(...) |
| 145 | +end |
| 146 | + |
| 147 | +return THRNN |
0 commit comments