-
Notifications
You must be signed in to change notification settings - Fork 29
/
Copy pathcreate_model.lua
37 lines (29 loc) · 1.02 KB
/
create_model.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
require 'nn'
require 'requ'
--- Build the classifier network and the loss used to train it.
--
-- Architecture:
--   Linear(n_inputs -> embedding_dim) -> ReQU | Sigmoid
--     -> Linear(embedding_dim -> n_classes) -> LogSoftMax
-- The loss is nn.ClassNLLCriterion, which is paired here with the
-- LogSoftMax output layer.
--
-- NOTE(review): this defines a *global* `create_model` (as the original file
-- did). It is kept global so any caller relying on the global name, rather
-- than on the value returned by the file, keeps working.
--
-- @tparam table opt options table (required)
-- @tparam string opt.nonlinearity_type either `'requ'` or `'sigmoid'`
-- @tparam[opt=4] number opt.n_inputs width of the first Linear layer's input
-- @tparam[opt=2] number opt.embedding_dim width of the hidden layer
-- @tparam[opt=3] number opt.n_classes width of the final Linear layer's output
-- @return model the nn.Sequential network
-- @return criterion the nn.ClassNLLCriterion loss
-- @raise error if `opt` is nil or `opt.nonlinearity_type` is not recognised
function create_model(opt)
  if opt == nil then
    -- Level 2: blame the caller that forgot to pass options.
    error('create_model: opt table is required', 2)
  end

  ----------------------------------------------------------------------------
  -- MODEL
  ----------------------------------------------------------------------------
  -- Sizes fall back to the values that used to be hard-coded (4 -> 2 -> 3).
  local n_inputs = opt.n_inputs or 4
  local embedding_dim = opt.embedding_dim or 2
  local n_classes = opt.n_classes or 3

  -- linear -> sigmoid/requ -> linear -> log-softmax
  local model = nn.Sequential()
  model:add(nn.Linear(n_inputs, embedding_dim))
  if opt.nonlinearity_type == 'requ' then
    -- nn.ReQU is presumably registered by the 'requ' module required at the
    -- top of this file (not visible from this block) — confirm.
    model:add(nn.ReQU())
  elseif opt.nonlinearity_type == 'sigmoid' then
    model:add(nn.Sigmoid())
  else
    error('undefined nonlinearity_type ' .. tostring(opt.nonlinearity_type), 2)
  end
  model:add(nn.Linear(embedding_dim, n_classes))
  model:add(nn.LogSoftMax())

  ----------------------------------------------------------------------------
  -- LOSS FUNCTION
  ----------------------------------------------------------------------------
  local criterion = nn.ClassNLLCriterion()
  return model, criterion
end
-- Module value: the model constructor itself, so a caller can capture it
-- directly, e.g. `local create_model = require 'create_model'`.
local exported = create_model
return exported