trainable
Related to the module require("aprilann.trainable").
This package implements the class trainable.supervised_trainer, which is a powerful tool when you want to train ANNs following standard algorithms. Additionally, it implements generic functions and iterators to build training loops (dataset iterators and training functions). If you want to do some specific tricks, you can modify the methods described here in your script, or re-implement the functionality that you need.
The class trainable.supervised_trainer is the most important piece of this package. This class knows the standard API of ANN components, loss functions, optimizers, datasets, and matrix objects, so it uses them correctly to train ANNs following standard algorithms.
The construction of a trainer needs at least an ANN component object; optionally, it is possible to indicate a loss function, a bunch size (mini-batch) and an optimizer object:
> -- a linear component (w1 and b1 are the weight names)
> c = ann.components.hyperplane{ input=10, output=10,
                                 name = "hyperplane",
                                 dot_product_name = "c-w1",
                                 bias_name = "c-b1",
                                 bias_weights = "b1",
                                 dot_product_weights = "w1" }
> -- a trainer
> trainer = trainable.supervised_trainer(c)
The arguments of the constructor are positional, described here:
- ANN component: an instance of ann.components.base() or any sub-class of it. The trainer uses a reference to the given object. It is a mandatory argument.
- Loss function: an instance of a sub-class of ann.loss. The trainer uses a reference to the given object. If not given, it will be mandatory in train_* and validate_* methods. By default it is nil.
- Bunch size (mini-batch): a number with the batch size used to compute gradients. Values between 32 and 1024 are usual, depending on the optimizer and on the task. If not given, it will be mandatory in train_*, validate_*, or use_* methods. By default it is nil.
- Optimizer: an instance of a sub-class of ann.optimizer. The trainer uses a reference to the given object. If not given, the default optimizer is ann.optimizer.sgd().
- Smooth flag: a boolean value indicating whether the gradients must be smoothed. If true, gradients will be multiplied by 1/sqrt(bunch_size+K), where K is the number of times the corresponding weights matrix is shared across the ANN. By default this parameter is true.
> -- without bunch-size, without loss function,
> -- using default optimizer, smooth is true
> trainer = trainable.supervised_trainer(c)
> -- without bunch-size, using default optimizer, smooth is true
> trainer = trainable.supervised_trainer(c, ann.loss.mse())
> -- using default optimizer, smooth is true
> trainer = trainable.supervised_trainer(c, ann.loss.mse(), 64)
> -- smooth is true
> trainer = trainable.supervised_trainer(c, ann.loss.mse(), 64,
                                         ann.optimizer.rprop())
> -- all arguments given
> trainer = trainable.supervised_trainer(c, ann.loss.mse(), 64,
                                         ann.optimizer.rprop(), false)
trainer:set_loss_function(loss)
loss = trainer:get_loss_function()
trainer:set_optimizer(optimizer)
optimizer = trainer:get_optimizer()
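These setters and getters allow replacing the loss function or the optimizer after construction. A minimal sketch, reusing only objects already shown in this page:
> -- replace the default optimizer and set a loss function afterwards
> trainer:set_optimizer(ann.optimizer.rprop())
> trainer:set_loss_function(ann.loss.mse())
> print(trainer:get_optimizer(), trainer:get_loss_function())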
Once the trainer is instantiated, it is mandatory to build the ANN by calling the build method.
> trainer:build()
The build method can be called without arguments, or with a table containing the following optional fields:
- input: a number with the input size of the model. If not given, it will be set to the input size of the given ANN component. If given, this size will be used as a sanity check, so it must be equal to the input size of the ANN component.
- output: a number with the output size of the model. Idem as the previous field, but with the ANN component output.
- weights: a dictionary (table) with strings as keys and instances of matrix as values. If given, the connection objects used by the components will be taken from this dictionary, so the sizes must be equal. It can be used to initialize certain components with given initial parameters.
> -- giving a table, initialized using its constructor
> trainer:build{
    input = 10, output = 10,
    weights = {
      -- ann.connections allocates memory for weight matrices
      w1 = ann.connections{ input=10, output=10 },
      b1 = ann.connections{ input=1, output=10 },
    }
  }
Once the build method has been called, it is recommended not to modify the structure of the ANN component; if you need to do so, the build method must be called again after your modifications. The connection weights must be modified after calling build, initializing them in some way. Another possibility is to deserialize a previously serialized trainer.
It is possible to save a trainer object, always in built state, using the util.serialize() and util.deserialize() functions.
> util.serialize(trainer, "mytrainer-binary.net")
> trainer2 = util.deserialize("mytrainer-binary.net")
trainer:randomize_weights{ ... }
The connection weights can be initialized randomly by using the method trainer:randomize_weights. This method receives a table with the following fields:
- inf: a number with the lower bound of the range. It is a mandatory field.
- sup: a number with the upper bound of the range. It is a mandatory field.
- random: a random object. It is a mandatory field.
- use_fanin: a boolean value indicating whether to apply a fan-in factor for each layer in its initialization. It is an optional field. By default it is false.
- use_fanout: a boolean value indicating whether to apply a fan-out factor for each layer in its initialization. It is an optional field. By default it is false.
- name_match: a string with a Lua pattern used to filter which connection objects will be initialized. It is an optional field. By default it is .*.
The weights will be initialized in the range [c*inf, c*sup], where c is a factor which depends on the use_fanin and use_fanout arguments:
- If none is given, then c=1 for all weight layers.
- If only use_fanin=true is given, then c is computed depending on the fan-in of each layer, being c = 1/sqrt(fanin).
- If only use_fanout=true is given, then c is computed depending on the fan-out of each layer, being c = 1/sqrt(fanout).
- If both use_fanin=true and use_fanout=true are given, then c is computed depending on the fan-in and fan-out of each layer, being c = 1/sqrt(fanin + fanout).
> -- initialize only bias weights
> trainer:randomize_weights{
inf = -0.1,
sup = 0.1,
random = random(213924),
name_match = "b.*",
}
> -- initialize only non-bias weights
> trainer:randomize_weights{
inf = -0.1,
sup = 0.1,
random = random(213924),
name_match = "w.*",
use_fanin = true,
use_fanout = true,
}
> -- initialize all the weights
> trainer:randomize_weights{
inf = -0.1,
sup = 0.1,
random = random(213924),
}
Once the trainer is built, it is possible to do some introspection in order to modify connection weights or execute their methods.
number = trainer:count_weights( [pattern=.*] )
This method returns the number of connection weights in the current trainer.
Optionally the method receives a Lua pattern, restricting the count to the weights whose name matches the pattern.
> = trainer:count_weights()
2
matrix = trainer:weights(name)
This method returns the matrix object with the given name.
> w1 = trainer:weights("w1")
> = type(w1)
matrix
> b1 = trainer:weights("b1")
> = type(b1)
matrix
table = trainer:get_components_of(weights_name)
This method returns a table with the ann.components objects which share the given weights_name connection weights. It returns a table because a connection weights object can be shared by more than one ANN component.
> iterator(ipairs( trainer:get_components_of("w1") )):apply(print)
1 instance 0xfcf5b0 of ann.components.base
> iterator(ipairs( trainer:get_components_of("b1") )):apply(print)
1 instance 0xff3c90 of ann.components.base
table = trainer:get_weights_table()
This method returns a dictionary (table) with all the connection weights. This
dictionary has the same structure as the weights
field of trainer:build(...)
method.
> weights_table = trainer:get_weights_table()
> print(weights_table)
table: 0x1310a00
... = trainer:iterate_weights( [pattern=.*] )
This method returns a Lua iterator function which iterates over all the connection weights whose name matches the given pattern.
> for cnn_name,cnn in trainer:iterate_weights() do print(cnn_name) end
b1
w1
> iterator( trainer:iterate_weights("w.*") ):select(1):apply(print)
w1
number = trainer:norm2([pattern=.*])
This method computes the 2-norm of the connection weight objects whose name matches the given pattern.
> = trainer:norm2()
0.24416591227055
> = trainer:norm2("b.")
0
> = trainer:norm2("w.")
0.24416591227055
number = trainer:size()
This method returns the number of parameters (weights) of the current ANN component. For the 10x10 hyperplane above, this is 10*10 weights plus 10 biases, that is, 110 parameters.
> = trainer:size()
110
By using these methods it is possible to manipulate the connection weights by hand, as in the following example, which initializes the bias connections to zero:
> for _,cnn in trainer:iterate_weights("b.*") do cnn:zeros() end
REMEMBER that connection weight objects are matrix
instances.
The following code prints all the weight matrices:
> iterator(trainer:iterate_weights()):apply(print)
0
0
0
0
0
0
0
0
0
0
# Matrix of size [10,1] [0x10aca60 data= 0x10acb20]
0.0575503 0.0516265 -0.0306808 0.035404 0.0118243 ...
0.0281929 0.0877731 0.0842627 -0.0379949 -0.091877 ...
-0.0332023 0.0576623 -0.0335078 0.0251189 -0.0578111 ...
-0.0335119 0.0162495 -0.00910386 -0.0949801 0.00303258 ...
-0.0361652 -0.0389352 0.0628194 -0.0622919 -0.0206459 ...
0.0583717 0.0910834 -0.0889903 -0.0142328 -0.0750175 ...
-0.0895628 0.0412171 0.0308301 -0.0680314 0.0948681 ...
-0.00439932 0.0975324 0.00736945 0.013484 -0.079681 ...
-0.0859327 0.0332012 0.0374489 -0.0555631 -0.0308727 ...
0.0375495 -0.0474079 0.0450424 -0.0822513 -0.00803252 ...
# Matrix of size [10,10] [0x10fe150 data= 0x10fe210]
Once the trainer is built, it is also possible to do introspection for getting or modifying ANN components.
number = trainer:count_components( [pattern=.*] )
This method returns the number of ANN components in the current trainer.
Optionally the method receives a Lua pattern, restricting the count to the components whose name matches the pattern.
> = trainer:count_components()
3
object = trainer:get_component()
This method returns the root ANN component, which was given to the constructor.
> = trainer:get_component()
instance 0x1175430 of ann.components.hyperplane
object = trainer:component(name)
This method returns an ANN component object given its name.
> = trainer:component("hyperplane")
instance 0x1175430 of ann.components.base
object = trainer:get_weights_of(component_name)
This method returns the matrix object which belongs to the given component_name.
> = trainer:get_weights_of("hyperplane")
nil
> = type( trainer:get_weights_of("c-w1") )
matrix
> = type( trainer:get_weights_of("c-b1") )
matrix
... = trainer:iterate_components( [pattern=.*] )
This method returns a Lua iterator function which iterates over all the ANN components whose name matches the given pattern.
> for name,c in trainer:iterate_components() do print(name,c) end
hyperplane instance 0x1175430 of ann.components.base
c-w1 instance 0xfcf5b0 of ann.components.base
c-b1 instance 0xff3c90 of ann.components.base
> iterator( trainer:iterate_components("c-w.*") ):apply(print)
c-w1 instance 0xfcf5b0 of ann.components.base
The following methods are shortcuts to modify hyperparameters of the optimizer
object.
desc = trainer:has_option(option)
Returns the description of the given option if it exists in the optimizer. Otherwise it returns nil.
trainer:set_option(option, value)
Sets the given option name to the given value. Throws an error if the option doesn't exist in the optimizer.
value = trainer:get_option(option)
Returns the value of the given option name, or throws an error if the option doesn't exist in the optimizer.
trainer:set_layerwise_option(pattern, option, value)
This method requires the trainer to be in built state. It traverses all the connection weight objects whose name matches the given pattern string, and sets their layer-wise option name to the given value.
value = trainer:get_option_of(name, option)
This method returns the option value which applies to the given connection weight object name.
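As an illustration, the following sketch assumes the default SGD optimizer; the option names learning_rate and weight_decay are an assumption not stated in this page:
> -- assuming ann.optimizer.sgd(), which would define "learning_rate"
> trainer:set_option("learning_rate", 0.01)
> = trainer:get_option("learning_rate")
0.01
> -- set weight decay only for non-bias weight matrices (layer-wise)
> if trainer:has_option("weight_decay") then
    trainer:set_layerwise_option("w.*", "weight_decay", 1e-04)
  end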
The following methods are prepared to work with a bunch of patterns (mini-batch). They do one batch step of the algorithm, and can be rewritten to do specific things.
mu, matrix = trainer:train_step(input, target, loss, optimizer)
This method executes one training step, using the given data:
- input is a token with a bunch (mini-batch) of data; usually it is a matrix instance, where rows are patterns and columns are features.
- target is a token with a bunch (mini-batch) of data. Usually it is a matrix instance.
- loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated in the trainer object is used.
- optimizer is an ann.optimizer object. It is optional; if not given, the optimizer object instantiated in the trainer object is used.
The method returns two values:
- mu is the mean of the loss function over the given batch of patterns.
- matrix is a one-dimensional matrix with the loss of every pattern.
> mean,loss_matrix = trainer:train_step(input, target)
mu, matrix = trainer:validate_step(input, target, loss)
This method executes one validation step, using the given data:
- input is a token with a bunch (mini-batch) of data. It can be a matrix instance.
- target is a token with a bunch (mini-batch) of data. It can be a matrix instance.
- loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated in the trainer object is used.
The method returns two values:
- mu is the mean of the loss function over the given batch of patterns.
- matrix is a one-dimensional matrix with the loss of every pattern.
The validate step evaluates the performance of the ANN component using the given loss function, but it doesn't train the parameters.
> mean,loss_matrix = trainer:validate_step(input, target)
g, mu, mat = trainer:compute_gradients_step(i, t, l [, g ])
This method computes the gradients for the given data, but doesn't train the parameters. The returned gradients can be used to perform a manual training step or gradient checking.
The arguments are:
- input is a token with a bunch (mini-batch) of data. It is usually a matrix instance.
- target is a token with a bunch (mini-batch) of data. It is usually a matrix instance.
- loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated in the trainer object is used.
- gradients is a dictionary with the gradient matrices of every connection weights object. It is optional; if not given, the matrices will be allocated; if given, the allocation can be avoided.
The method returns three values:
- gradients is the gradients dictionary.
- mu is the mean of the loss function over the given batch of patterns.
- matrix is a one-dimensional matrix with the loss of every pattern.
> gradients,mean,loss_mat = trainer:compute_gradients_step(input, target)
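The returned gradients dictionary can drive a hand-written update. The following sketch applies a plain gradient-descent rule; the axpy method of matrix objects and the 0.01 learning rate are assumptions for illustration:
> -- hypothetical manual update: w = w - 0.01 * gradient
> gradients,mean,loss_mat = trainer:compute_gradients_step(input, target)
> for name,grad in pairs(gradients) do
    trainer:weights(name):axpy(-0.01, grad)
  end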
boolean = trainer:grad_check_step(i, t [, boolean [, loss ] ] )
This method computes the gradients for the given data, and executes a gradient checking algorithm using numerical differentiation. The arguments are:
- input is a token with a bunch (mini-batch) of data. It is usually a matrix instance.
- target is a token with a bunch (mini-batch) of data. It is usually a matrix instance.
- boolean: if true, it indicates high verbosity. It is optional; by default it is false.
- loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated in the trainer object is used.
The method returns a boolean indicating whether the gradient checking algorithm succeeds or fails.
> trainer:grad_check_step(input, target) or error("Gradients checking fails")
The following methods perform a traversal over a dataset with a large number of patterns, training or evaluating the model. The dataset is divided into batches of bunch_size size.
loss_mu,loss_var = trainer:train_dataset{ ... }
This method is used to train with a given dataset. Different training schedules are possible, depending on the parameters given to the method (a minimal usage sketch is shown after the list of schedules). In any case, this method returns two values:
- The mean of the loss function over all the patterns.
- The variance of the loss function over all the patterns.
The following training schedules are available:
- Training with all the patterns, in sequential order; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - bunch_size: the mini-batch size, an optional parameter. If not given, the bunch size instantiated in the trainer will be used.
  - loss: a loss function object, an optional parameter. If not given, the loss instantiated in the trainer will be used.
  - optimizer: an optimizer object, an optional parameter. If not given, the optimizer instantiated in the trainer will be used.
- Training with all the patterns, in shuffled order; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - random: a random object instance, used to shuffle the patterns.
  - bunch_size: the mini-batch size, an optional parameter. If not given, the bunch size instantiated in the trainer will be used.
  - loss: a loss function object, an optional parameter. If not given, the loss instantiated in the trainer will be used.
  - optimizer: an optimizer object, an optional parameter. If not given, the optimizer instantiated in the trainer will be used.
- Training with replacement; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - random: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size, an optional parameter. If not given, the bunch size instantiated in the trainer will be used.
  - loss: a loss function object, an optional parameter. If not given, the loss instantiated in the trainer will be used.
  - optimizer: an optimizer object, an optional parameter. If not given, the optimizer instantiated in the trainer will be used.
- Training with distribution; the fields of the given table are:
  - distribution: an array of tables, where each table contains:
    - input_dataset: a dataset with the input data of the ANN components.
    - output_dataset: a dataset with the target outputs of the ANN components (supervision).
    - prob: a number with the probability of taking a pattern from this data source.
  - random: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size, an optional parameter. If not given, the bunch size instantiated in the trainer will be used.
  - loss: a loss function object, an optional parameter. If not given, the loss instantiated in the trainer will be used.
  - optimizer: an optimizer object, an optional parameter. If not given, the optimizer instantiated in the trainer will be used.
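For instance, a minimal sketch of the sequential schedule could be the following, where my_input_ds and my_output_ds are hypothetical dataset objects:
> -- one epoch over all the patterns, in sequential order
> tr_mu,tr_var = trainer:train_dataset{
    input_dataset  = my_input_ds,
    output_dataset = my_output_ds,
  }
> print(tr_mu, tr_var)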
loss_mu,loss_var = trainer:validate_dataset{ ... }
This method is used to validate the model with a given dataset. Different validation schedules are possible, depending on the parameters given to the method (a minimal usage sketch is shown after the list of schedules). In any case, this method returns two values:
- The mean of the loss function over all the patterns.
- The variance of the loss function over all the patterns.
The following validation schedules are available:
- Validation with all the patterns, in sequential order; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - bunch_size: the mini-batch size, an optional parameter. If not given, the bunch size instantiated in the trainer will be used.
  - loss: a loss function object, an optional parameter. If not given, the loss instantiated in the trainer will be used.
- Validation with replacement; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - random: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size, an optional parameter. If not given, the bunch size instantiated in the trainer will be used.
  - loss: a loss function object, an optional parameter. If not given, the loss instantiated in the trainer will be used.
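A minimal sketch of the sequential schedule, with hypothetical validation datasets my_va_input_ds and my_va_output_ds:
> va_mu,va_var = trainer:validate_dataset{
    input_dataset  = my_va_input_ds,
    output_dataset = my_va_output_ds,
  }
> print(va_mu, va_var)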
output_ds = trainer:use_dataset{ ... }
This method receives a table with one or two datasets and computes the output of the ANN component for every pattern. Note that this method uses a large amount of memory, because it needs a dataset in which to store the ANN output for every pattern, so be careful when using it. It receives a table with these fields:
- input_dataset: a dataset with the input data for the ANN component.
- output_dataset: a dataset with enough space to store the output of the ANN component for every pattern in the input_dataset. If not given, the output_dataset will be allocated automatically with the required size.
This method returns the output_dataset
with the produced data.
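A minimal sketch, letting the method allocate the output dataset automatically; my_input_ds is a hypothetical dataset object:
> -- computes the ANN output for every pattern of my_input_ds
> out_ds = trainer:use_dataset{ input_dataset = my_input_ds }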
boolean = trainer:grad_check_dataset{ ... }
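This page gives no further description of this method; by analogy with grad_check_step and the other *_dataset methods, a call could look like the following sketch, where the accepted fields are an assumption:
> -- assumed fields, by analogy with the other *_dataset methods
> trainer:grad_check_dataset{
    input_dataset  = my_input_ds,
    output_dataset = my_output_ds,
  } or error("Gradients checking fails")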
These two classes are useful to build a training loop with a default stopping criterion.
trainable.train_wo_validation class
This class implements a training function without validation, using a stopping criterion based on the percentage of improvement in training loss or on the number of epochs. The following methods are defined.
train_func = trainable.train_wo_validation{ ... }
The constructor, which receives a table with the following fields:
- min_epochs=1: the minimum number of training epochs. It is optional.
- max_epochs: the maximum number of training epochs. If min_epochs==max_epochs, then the stopping criterion will be the number of epochs.
- percentage_stopping_criterion=0.01: a number in the range [0,1] indicating the threshold for the percentage of improvement in training loss between two consecutive epochs. If the training loss improvement is less than this number, the training will stop. It is an optional field.
- first_epoch=1: indicates the number of the first epoch. It is optional.
> -- instance using percentage_stopping_criterion=0.01 (default value)
> train_func = trainable.train_wo_validation{ max_epochs = 100 }
boolean = train_func:execute(epoch_function)
This method executes one epoch step. It is the most important method. It receives an epoch function, which is a closure with the responsibility of performing the training of one epoch, and it must return two values: the trained model and the training loss. The method returns true or false depending on whether the stopping criterion is satisfied or not.
> while train_func:execute(function()
    local tr = trainer:train_dataset(training_data)
    return trainer,tr
  end) do
    print(train_func:get_state_string())
  end
train_func:set_param(name,value)
This method modifies the value of a parameter previously given at the constructor (or used with its default value).
value = train_func:get_param(name)
This method returns the value of a parameter previously given at the constructor (or used with its default value).
epoch,tr_loss,tr_improvement,last=train_func:get_state()
This method returns the internal state of the object. last is the trained model returned by the last call to the epoch function.
state = train_func:get_state_table()
This method returns a table with the following fields:
- state.current_epoch: the current epoch.
- state.train_error: the training loss at the last epoch.
- state.train_improvement: the relative improvement of the training loss.
- state.last: the trained model returned by the last call to the epoch function.
string = train_func:get_state_string()
This method returns a string for printing purposes, with the following format:
string.format("%5d %.6f %.6f",
state.current_epoch,
state.train_error,
state.train_improvement)
Finally, here is a code example showing how to use this class:
> trainer = trainable.supervised_trainer(thenet, ann.loss.mse(), 64)
> train_func = trainable.train_wo_validation{ max_epochs = 100 }
> while train_func:execute(function()
local tr = trainer:train_dataset(training_data)
return trainer,tr
end) do
print(train_func:get_state_string())
util.serialize({ train_func, training_data.shuffle }, "training.lua")
end
The following is an example of loading a previously saved object:
> train_func,shuffle = util.deserialize("training.lua")
> trainer = train_func:get_state_table().last
> thenet = trainer:get_component()
> training_data.shuffle = shuffle
> while train_func:execute(function()
local tr = trainer:train_dataset(training_data)
return trainer,tr
end) do
print(train_func:get_state_string())
util.serialize({ train_func, training_data.shuffle }, "training.lua")
end
trainable.train_holdout_validation class
This class implements a training function with a holdout validation set, using a stopping criterion based on the validation loss or on the number of epochs. This object follows the Pocket Algorithm, so it keeps the model which has the best validation loss during the training. A tolerance in the relative error can be used to decide the minimum improvement required to take a model as the best. The following methods are defined.
train_func = trainable.train_holdout_validation{ ... }
The constructor, which receives a table with the following fields:
- min_epochs=1: the minimum number of training epochs. It is optional.
- max_epochs: the maximum number of training epochs. If min_epochs==max_epochs, then the stopping criterion will be the number of epochs.
- epochs_wo_validation=0: the number of epochs during which the validation loss is ignored, so the best model is the last given one. It is optional.
- stopping_criterion=function() return false end: a stopping criterion function, used to determine when the training must be stopped. The function is called with a table containing the output of the get_state_table() method. It is optional. Basic criterion functions are defined in the trainable table, and are described below this section.
- first_epoch=1: indicates the number of the first epoch. It is optional.
- tolerance=0: the tolerance>=0 is the minimum relative difference required to take the current validation loss as the best. It is optional.
> criterion = trainable.stopping_criteria.make_max_epochs_wo_imp_relative(2)
> train_func = trainable.train_holdout_validation{
stopping_criterion = criterion,
max_epochs = max_epochs
}
boolean = train_func:execute(epoch_function)
This method executes one epoch step. It is the most important method. It receives an epoch function, which is a closure with the responsibility of performing the training of one epoch, and it must return three values: the trained model, the training loss and the validation loss. The method returns true or false depending on whether the stopping criterion is satisfied or not.
> while train_func:execute(function()
local tr = trainer:train_dataset(training_data)
local va = trainer:validate_dataset(validation_data)
return trainer,tr,va
end) do
print(train_func:get_state_string())
local state = train_func:get_state_table()
if state.best_epoch == state.current_epoch then
util.serialize({ train_func, training_data.shuffle }, "training.lua")
end
end
train_func:set_param(name,value)
This method modifies the value of a parameter previously given at the constructor (or used with its default value).
value = train_func:get_param(name)
This method returns the value of a parameter previously given at the constructor (or used with its default value).
epoch,tr_loss,va_loss,...=train_func:get_state()
This method returns the internal state of the object: epoch, training loss, validation loss, best epoch, best validation loss, best model clone, last given model.
state = train_func:get_state_table()
This method returns a table with the following fields:
- state.current_epoch: the current epoch.
- state.train_error: the training loss at the last epoch.
- state.validation_error: the validation loss at the last epoch.
- state.best_epoch: the epoch where the best validation loss was found.
- state.best_val_error: the validation loss at the best epoch.
- state.best: the trained model which achieves the best validation loss.
- state.last: the trained model returned by the last call to the epoch function.
string = train_func:get_state_string()
This method returns a string for printing purposes, with the following format:
string.format("%5d %.6f %.6f %5d %.6f",
state.current_epoch,
state.train_error,
state.validation_error,
state.best_epoch,
state.best_val_error)
Finally, here is a code example showing how to use this class:
> trainer = trainable.supervised_trainer(thenet, ann.loss.mse(), 64)
> criterion = trainable.stopping_criteria.make_max_epochs_wo_imp_relative(2)
> train_func = trainable.train_holdout_validation{
stopping_criterion = criterion,
max_epochs = max_epochs
}
> while train_func:execute(function()
local tr = trainer:train_dataset(training_data)
local va = trainer:validate_dataset(validation_data)
return trainer,tr,va
end) do
print(train_func:get_state_string())
local state = train_func:get_state_table()
if state.best_epoch == state.current_epoch then
util.serialize({ train_func, training_data.shuffle }, "training.lua")
end
end
For the holdout-validation scheme, there are two predefined stopping criteria, which are function builders (they return the function used as criterion):
- trainable.stopping_criteria.make_max_epochs_wo_imp_absolute: receives a constant indicating the maximum number of epochs without improving the validation loss. A typical value is between 10 and 20, depending on the task.
- trainable.stopping_criteria.make_max_epochs_wo_imp_relative: receives a constant indicating the maximum value for current_epoch/best_epoch. A typical value is 2.
These two criteria can be used like this:
train_func = trainable.train_holdout_validation{
...
stopping_criterion = trainable.stopping_criteria.make_max_epochs_wo_imp_relative(2),
...
}
You can also create your own stopping criterion, which is a function receiving a table:
train_func = trainable.train_holdout_validation{
...
stopping_criterion = function(t)
-- t contains these fields:
-- * current_epoch
-- * best_epoch
-- * best_val_error
-- * train_error
-- * validation_error
return true WHEN YOUR STOPPING CRITERION OVER t IS SATISFIED
end,
...
}
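For example, a hand-written patience-like criterion, stopping when more than a hypothetical number of epochs have passed since the best one, could be sketched as:
> max_epochs_wo_imp = 10 -- hypothetical patience value
> train_func = trainable.train_holdout_validation{
    max_epochs = 1000,
    stopping_criterion = function(t)
      -- stop when too many epochs have passed since the best epoch
      return t.current_epoch - t.best_epoch > max_epochs_wo_imp
    end,
  }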
The class trainable.supervised_trainer uses some generic dataset iterator functions, which are available to the user if needed. Two functions are available: trainable.dataset_pair_iterator and trainable.dataset_multiple_iterator. The first is a wrapper around the second one. These iterators can perform different traversal modes, depending on the given parameters.
Lua iterator = trainable.dataset_pair_iterator{ ... }
This iterator performs a synchronized traversal of two given datasets (normally an input/output pair). The function returns a Lua iterator which returns three values every time it is called: the input pattern (a token, usually a matrix instance), the output pattern (a token, usually a matrix instance), and a Lua table with the indexes of the patterns taken in the bunch.
The available traversal modes are:
- Traverse all the patterns, in sequential order; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - bunch_size: the mini-batch size.
- Traverse all the patterns, in shuffled order; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - shuffle: a random object instance, used to shuffle the patterns.
  - bunch_size: the mini-batch size.
- Traverse with replacement; the fields of the given table are:
  - input_dataset: a dataset with the input data of the ANN components.
  - output_dataset: a dataset with the target outputs of the ANN components (supervision).
  - shuffle: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size.
- Traverse with distribution; the fields of the given table are:
  - distribution: an array of tables, where each table contains:
    - input_dataset: a dataset with the input data of the ANN components.
    - output_dataset: a dataset with the target outputs of the ANN components (supervision).
    - prob: a number with the probability of taking a pattern from this data source.
  - shuffle: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size.
> ds_params = { input_dataset = my_input_ds, output_dataset = my_output_ds }
> for input,output,idxs in trainable.dataset_pair_iterator(ds_params) do
-- you can give the input/output to an ANN and loss function
print(input,output,idxs)
end
Lua iterator = trainable.dataset_multiple_iterator{ ... }
This iterator performs a synchronized traversal of any number of given datasets. The function returns a Lua iterator which returns as many values as the number of given datasets plus one: one pattern for each dataset, plus a Lua table with the indexes of the patterns taken in the bunch.
The available traversal modes are:
- Traverse all the patterns, in sequential order; the fields of the given table are:
  - datasets: a Lua table with the list of datasets for the traversal.
  - bunch_size: the mini-batch size.
- Traverse all the patterns, in shuffled order; the fields of the given table are:
  - datasets: a Lua table with the list of datasets for the traversal.
  - shuffle: a random object instance, used to shuffle the patterns.
  - bunch_size: the mini-batch size.
- Traverse with replacement; the fields of the given table are:
  - datasets: a Lua table with the list of datasets for the traversal.
  - shuffle: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size.
- Traverse with distribution; the fields of the given table are:
  - distribution: an array of tables, where each table contains:
    - datasets: a Lua table with the list of datasets for the traversal.
    - prob: a number with the probability of taking a pattern from this data source.
  - shuffle: a random object instance, used to shuffle the patterns.
  - replacement: a number with the size of the replacement.
  - bunch_size: the mini-batch size.
> ds_params = { datasets = { my_ds1, my_ds2, my_ds3 } }
> for token1,token2,token3,idxs in trainable.dataset_multiple_iterator(ds_params) do
-- you can give the token1,token2,token3 to an ANN or a loss function
print(token1,token2,token3,idxs)
end