
Introduction

Related to the module require("aprilann.trainable")

This package implements the class trainable.supervised_trainer, a powerful tool for training ANNs following standard algorithms. Additionally, it implements generic functions and iterators to build training loops (dataset iterators and training-loop functions).

If you want to do some specific tricks, it is possible to modify the methods described here in your script, or to re-implement the functionality that you need.

The supervised trainer class

The class trainable.supervised_trainer is the most important piece of this package. This class knows the standard API of ANN components, loss functions, optimizers, datasets, and matrix objects, so it can use them in the correct way to train ANNs following standard algorithms.

Constructor

The construction of a trainer needs at least an ANN component object; optionally, it is possible to indicate the loss function, the bunch size (mini-batch size), and the optimizer object:

> -- a linear component (w1 and b1 are the weight names)
> c = ann.components.hyperplane{ input=10, output=10,
                                 name                = "hyperplane",
                                 dot_product_name    = "c-w1",
                                 bias_name           = "c-b1",
                                 bias_weights        = "b1",
                                 dot_product_weights = "w1" }
> -- a trainer
> trainer = trainable.supervised_trainer(c)

The arguments of the constructor are positional, described here:

  • ANN component: an instance of ann.components.base() or any sub-class of it. The trainer uses a reference to the given object. It is a mandatory argument.

  • Loss function: an instance of a sub-class of ann.loss. The trainer uses a reference to the given object. If not given, it will be mandatory in train_* and validate_* methods. By default it is nil.

  • Bunch-size (mini-batch): a number with the batch size used to compute gradients. Values between 32 and 1024 are usual, depending on the optimizer and the task. If not given, it will be mandatory in train_*, validate_*, or use_* methods. By default it is nil.

  • Optimizer: an instance of a sub-class of ann.optimizer. The trainer uses a reference to the given object. If not given, the default optimizer is ann.optimizer.sgd().

  • Smooth flag: a boolean value indicating whether the gradients must be smoothed. If true, gradients will be multiplied by 1/sqrt(bunch_size+K), where K is the number of times the corresponding weights matrix is shared across the ANN. By default this parameter is true.

> -- without bunch-size, without loss function,
> -- using default optimizer, smooth is true
> trainer = trainable.supervised_trainer(c)
> -- without bunch-size, using default optimizer, smooth is true
> trainer = trainable.supervised_trainer(c, ann.loss.mse())
> -- using default optimizer, smooth is true
> trainer = trainable.supervised_trainer(c, ann.loss.mse(), 64)
> -- smooth is true
> trainer = trainable.supervised_trainer(c, ann.loss.mse(), 64,
                                         ann.optimizer.rprop())
> -- all arguments given
> trainer = trainable.supervised_trainer(c, ann.loss.mse(), 64,
                                         ann.optimizer.rprop(), false)

Get and set of loss function and optimizer

set_loss_function

trainer:set_loss_function(loss)

get_loss_function

loss = trainer:get_loss_function()

set_optimizer

trainer:set_optimizer(optimizer)

get_optimizer

optimizer = trainer:get_optimizer()
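
Both objects can also be replaced after construction. The following is a minimal sketch using the ann.loss.mse and ann.optimizer.rprop classes shown elsewhere on this page:

> -- replace the loss function and the optimizer of an existing trainer
> trainer:set_loss_function(ann.loss.mse())
> trainer:set_optimizer(ann.optimizer.rprop())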

Building the ANN

Once the trainer is instantiated, it is mandatory to build the ANN by calling the build method.

> trainer:build()

The build method can be called without arguments, or giving a table with the following optional fields:

  • input: a number with the input size of the model. If not given, it will be set to the input size of the given ANN component. If given, this size will be used as a sanity check, so it must be equal to the input size of the ANN component.

  • output: a number with the output size of the model. Same as the previous field, but with the ANN component output.

  • weights: a dictionary (table) with strings as keys and instances of matrix as values. If given, the connection objects used by the components will be taken from this dictionary, so the sizes must be equal. It can be used to initialize certain components with given initial parameters.

> -- giving a table, initialized using its constructor
> trainer:build{
                 input = 10, output = 10,
                 weights = {
                   -- ann.connections allocates memory for weight matrices
                   w1 = ann.connections{ input=10, output=10 },
                   b1 = ann.connections{ input=1,  output=10 },
                 }
               }

Once the build method is called, it is recommended not to modify the structure of the ANN component; if you need to do so, the build method must be called again after your modifications. The connection weights must be modified after calling build, initializing them in some way. Another possibility is to deserialize a previously serialized trainer.

Serialization/deserialization

It is possible to save a trainer object, always in built state, using util.serialize() and util.deserialize() functions.

> util.serialize(trainer, "mytrainer-binary.net")
> trainer2 = util.deserialize("mytrainer-binary.net")

Connection weight methods

randomize_weights

trainer:randomize_weights{ ... }

The connection weights can be initialized randomly by using the method trainer:randomize_weights. This method receives a table with the following fields:

  • inf: a number with the inferior range bound. It is a mandatory field.

  • sup: a number with the superior range bound. It is a mandatory field.

  • random: a random object. It is a mandatory field.

  • use_fanin: a boolean value indicating whether to apply a fan-in factor for each layer during its initialization. It is an optional field. By default it is false.

  • use_fanout: a boolean value indicating whether to apply a fan-out factor for each layer during its initialization. It is an optional field. By default it is false.

  • name_match: a string with a Lua pattern used to filter which connection objects will be initialized. It is an optional field; by default it is .*.

The weights will be initialized in the range [c*inf, c*sup], where c is a factor which depends on the use_fanin and use_fanout arguments:

  • If none is given, then c=1 for all weight layers.

  • If only use_fanin=true is given, then c is computed depending on the fan-in of each layer, being c = 1/sqrt(fanin).

  • If only use_fanout=true is given, then c is computed depending on the fan-out of each layer, being c = 1/sqrt(fanout).

  • If both use_fanin=true and use_fanout=true are given, then c is computed depending on the fan-in and fan-out of each layer, being c = 1/sqrt(fanin + fanout).

> -- initialize only bias weights
> trainer:randomize_weights{
                             inf = -0.1,
                             sup =  0.1,
                             random = random(213924),
                             name_match = "b.*",
                           }
> -- initialize only non-bias weights
> trainer:randomize_weights{
                             inf = -0.1,
                             sup =  0.1,
                             random = random(213924),
                             name_match = "w.*",
                             use_fanin   = true,
                             use_fanout  = true,
                           }
> -- initialize all the weights
> trainer:randomize_weights{
                             inf = -0.1,
                             sup =  0.1,
                             random = random(213924),
                           }

Once the trainer is built, it is possible to do some introspection in order to modify or execute methods of connection weights.

count_weights

number = trainer:count_weights( [pattern=.*] )

This method returns the number of connection weight objects in the current trainer. Optionally the method receives a Lua pattern, restricting the count to the weights whose name matches the pattern.

> = trainer:count_weights()
2

weights

matrix = trainer:weights(name)

This method returns the matrix object with the given name.

> w1 = trainer:weights("w1")
> = type(w1)
matrix
> b1 = trainer:weights("b1")
> = type(b1)
matrix

get_components_of

table = trainer:get_components_of(weights_name)

This method returns a table with the ann.components objects which share the given weights_name connection weights. It returns a table because a connection weights object can be shared by more than one ANN component.

> iterator(ipairs( trainer:get_components_of("w1") )):apply(print)
1   instance 0xfcf5b0 of ann.components.base
> iterator(ipairs( trainer:get_components_of("b1") )):apply(print)
1   instance 0xff3c90 of ann.components.base

get_weights_table

table = trainer:get_weights_table()

This method returns a dictionary (table) with all the connection weights. This dictionary has the same structure as the weights field of the trainer:build(...) method.

> weights_table = trainer:get_weights_table()
> print(weights_table)
table: 0x1310a00

iterate_weights

... = trainer:iterate_weights( [pattern=.*] )

This method returns a Lua iterator function which iterates over all the connection weights whose name matches the given pattern.

> for cnn_name,cnn in trainer:iterate_weights() do print(cnn_name) end
b1
w1
> iterator( trainer:iterate_weights("w.*") ):select(1):apply(print)
w1

norm2

number = trainer:norm2([pattern=.*])

This method computes the 2-norm of the connection weight objects whose name matches the given pattern.

> = trainer:norm2()
0.24416591227055
> = trainer:norm2("b.")
0
> = trainer:norm2("w.")
0.24416591227055

size

number = trainer:size()

This method returns the number of parameters (weights) in the current ANN component.

> = trainer:size()
110

Example

By using these methods it is possible to manipulate the connection weights by hand, as in the following example, which initializes the bias connections to zero:

> for _,cnn in trainer:iterate_weights("b.*") do cnn:zeros() end

REMEMBER that connection weight objects are matrix instances.

The following code prints all the weight matrices to the screen:

> iterator(trainer:iterate_weights()):apply(print)
 0         
 0         
 0         
 0         
 0         
 0         
 0         
 0         
 0         
 0         
# Matrix of size [10,1] [0x10aca60 data= 0x10acb20]
 0.0575503   0.0516265  -0.0306808   0.035404    0.0118243  ...
 0.0281929   0.0877731   0.0842627  -0.0379949  -0.091877   ...
-0.0332023   0.0576623  -0.0335078   0.0251189  -0.0578111  ...
-0.0335119   0.0162495  -0.00910386 -0.0949801   0.00303258 ...
-0.0361652  -0.0389352   0.0628194  -0.0622919  -0.0206459  ...
 0.0583717   0.0910834  -0.0889903  -0.0142328  -0.0750175  ...
-0.0895628   0.0412171   0.0308301  -0.0680314   0.0948681  ...
-0.00439932  0.0975324   0.00736945  0.013484   -0.079681   ...
-0.0859327   0.0332012   0.0374489  -0.0555631  -0.0308727  ...
 0.0375495  -0.0474079   0.0450424  -0.0822513  -0.00803252 ...
# Matrix of size [10,10] [0x10fe150 data= 0x10fe210]

Component methods

Once the trainer is built, it is also possible to do introspection for getting or modify ANN components.

count_components

number = trainer:count_components( [pattern=.*] )

This method returns the number of ANN components in the current trainer. Optionally the method receives a Lua pattern, restricting the count to the components whose name matches the pattern.

> = trainer:count_components()
3

get_component

object = trainer:get_component()

This method returns the root ANN component, which was given to the constructor.

> = trainer:get_component()
instance 0x1175430 of ann.components.hyperplane

component

object = trainer:component(name)

This method returns an ANN component object given its name.

> = trainer:component("hyperplane")
instance 0x1175430 of ann.components.base

get_weights_of

object = trainer:get_weights_of(component_name)

This method returns the weights matrix object which belongs to the component with the given component_name.

> = trainer:get_weights_of("hyperplane")
nil
> = type( trainer:get_weights_of("c-w1") )
matrix
> = type( trainer:get_weights_of("c-b1") )
matrix

iterate_components

... = trainer:iterate_components( [pattern=.*] )

This method returns a Lua iterator function which iterates over all the ANN components whose name matches the given pattern.

> for name,c in trainer:iterate_components() do print(name,c) end
hyperplane  instance 0x1175430 of ann.components.base
c-w1    instance 0xfcf5b0 of ann.components.base
c-b1    instance 0xff3c90 of ann.components.base
> iterator( trainer:iterate_components("c-w.*") ):apply(print)
c-w1    instance 0xfcf5b0 of ann.components.base

Optimizer methods

The following methods are shortcuts to modify hyperparameters of the optimizer object.

has_option

desc = trainer:has_option(option)

Returns the description of the given option if it exists in the optimizer. Otherwise it returns nil.

set_option

trainer:set_option(option, value)

Sets the given option name to the given value. It throws an error in case the option doesn't exist in the optimizer.

get_option

value = trainer:get_option(option)

Returns the value of a given option name, or throws an error if the option doesn't exist in the optimizer.

set_layerwise_option

trainer:set_layerwise_option(pattern, option, value)

This method requires the trainer to be in built state. It traverses all the connection weight objects whose name matches the given pattern string, and sets the given layer-wise option name to the given value for them.

get_option_of

value = trainer:get_option_of(name, option)

This method returns the option value which applies to the given connection weight object name.
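
For illustration, the following is a minimal sketch combining these shortcuts. The option name learning_rate is an assumption here; the set of valid options depends on the concrete optimizer instantiated at the trainer.

> -- learning_rate is an assumed option name; check for it before setting it
> if trainer:has_option("learning_rate") then
    trainer:set_option("learning_rate", 0.01)
    -- a smaller layer-wise value for weights whose name matches "w.*"
    trainer:set_layerwise_option("w.*", "learning_rate", 0.001)
    print(trainer:get_option_of("w1", "learning_rate")) -- prints 0.001
  end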

One batch training, validation and gradients check

The following methods are prepared to work with a bunch of patterns (mini-batch). They perform one batch step of the algorithms, and can be rewritten to do specific things.

train_step

mu, matrix = trainer:train_step(input, target, loss, optimizer)

This method executes one training step, using the given data:

  • input is a token with a bunch (mini-batch) of data; usually it is a matrix instance, where rows are patterns and columns are features.

  • target is a token with a bunch (mini-batch) of data. Usually it is a matrix instance.

  • loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated at the trainer is used.

  • optimizer is an ann.optimizer object. It is optional; if not given, the optimizer object instantiated at the trainer is used.

The method returns two values:

  • mu is the mean of the loss function at the given batch of patterns.

  • matrix is a one-dimensional matrix with the loss of every pattern.

> mean,loss_matrix = trainer:train_step(input, target)

validate_step

mu, matrix = trainer:validate_step(input, target, loss)

This method executes one validate step, using the given data:

  • input is a token with a bunch (mini-batch) of data. It can be a matrix instance.

  • target is a token with a bunch (mini-batch) of data. It can be a matrix instance.

  • loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated at the trainer is used.

The method returns two values:

  • mu is the mean of the loss function at the given batch of patterns.

  • matrix is a one-dimensional matrix with the loss of every pattern.

The validate step evaluates the performance of the ANN component using the given loss function, but it doesn't train the parameters.

> mean,loss_matrix = trainer:validate_step(input, target)

compute_gradients_step

g, mu, mat = trainer:compute_gradients_step(i, t, l [, g ])

This method computes the gradients for the given data, but doesn't train the parameters. The given gradients can be used to perform manual training or gradient checking.

The arguments are:

  • input is a token with a bunch (mini-batch) of data. It is usually a matrix instance.

  • target is a token with a bunch (mini-batch) of data. It is usually a matrix instance.

  • loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated at the trainer is used.

  • gradients is a dictionary with the gradient matrices of every connection weights object. It is optional; if not given, the matrices will be allocated, and if given, the allocation can be avoided.

The method returns three values:

  • gradients the gradients dictionary.

  • mu is the mean of the loss function at the given batch of patterns.

  • matrix is a one-dimensional matrix with the loss of every pattern.

> gradients,mean,loss_mat = trainer:compute_gradients_step(input, target)

grad_check_step

boolean = trainer:grad_check_step(i, t [, boolean [, loss ] ] )

This method computes the gradients for the given data, and executes a gradient checking algorithm using numerical differentiation. The arguments are:

  • input is a token with a bunch (mini-batch) of data. It is usually a matrix instance.

  • target is a token with a bunch (mini-batch) of data. It is usually a matrix instance.

  • boolean: if true, it indicates high verbosity. It is optional; by default it is false.

  • loss is an ann.loss function object. It is optional; if not given, the loss function object instantiated at the trainer is used.

The method returns a boolean indicating whether the gradient checking algorithm succeeds or fails.

> assert(trainer:grad_check_step(input, target), "Gradients checking fails")

Dataset methods training, validation and gradients check

These methods perform a traversal over a dataset with a large number of patterns, training or evaluating the model. The dataset is divided into batches of size bunch_size.

train_dataset

loss_mu,loss_var = trainer:train_dataset{ ... }

This method is used to train by using a given dataset. Different training schedules are possible, depending on the parameters given to the method (a minimal sketch follows the list below). In any case, this method returns two values:

  • The mean of the loss function over all the patterns.

  • The variance of the loss function over all the patterns.

The following training schedules are available:

  • Training with all the patterns, in a sequential way: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • bunch_size the mini-batch size, optional parameter. If not given, the bunch size instantiated at the trainer will be used.
    • loss a loss function object, optional parameter. If not given, the loss instantiated at the trainer will be used.
    • optimizer an optimizer object, optional parameter. If not given, the optimizer instantiated at the trainer will be used.
  • Training with all the patterns, in a shuffled way: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • random a random object instance, used to shuffle the patterns.
    • bunch_size the mini-batch size, optional parameter. If not given, the bunch size instantiated at the trainer will be used.
    • loss a loss function object, optional parameter. If not given, the loss instantiated at the trainer will be used.
    • optimizer an optimizer object, optional parameter. If not given, the optimizer instantiated at the trainer will be used.
  • Training with replacement: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • random a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size, optional parameter. If not given, the bunch size instantiated at the trainer will be used.
    • loss a loss function object, optional parameter. If not given, the loss instantiated at the trainer will be used.
    • optimizer an optimizer object, optional parameter. If not given, the optimizer instantiated at the trainer will be used.
  • Training with distribution: the fields of the given table are:

    • distribution is an array of tables, where each table contains:

      • input_dataset a dataset with the data for input of ANN components.
      • output_dataset a dataset with the data for target outputs of ANN components (supervision).
      • prob a number with the probability of taking a pattern from this data source.
    • random a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size, optional parameter. If not given, the bunch size instantiated at the trainer will be used.
    • loss a loss function object, optional parameter. If not given, the loss instantiated at the trainer will be used.
    • optimizer an optimizer object, optional parameter. If not given, the optimizer instantiated at the trainer will be used.
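
As noted above, the following is a minimal sketch of one training epoch with the shuffled schedule; my_input_ds and my_output_ds are assumed dataset objects built elsewhere:

> -- one epoch over all patterns, visited in shuffled order
> tr_mu,tr_var = trainer:train_dataset{
                   input_dataset  = my_input_ds,
                   output_dataset = my_output_ds,
                   random         = random(1234),
                 }
> print(tr_mu, tr_var)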

validate_dataset

loss_mu,loss_var = trainer:validate_dataset{ ... }

This method is used to validate the model by using a given dataset. Different validation schedules are possible, depending on the parameters given to the method (a minimal sketch follows the list below). In any case, this method returns two values:

  • The mean of the loss function over all the patterns.

  • The variance of the loss function over all the patterns.

The following validation schedules are available:

  • Validate with all the patterns, in a sequential way: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • bunch_size the mini-batch size, optional parameter. If not given, the bunch size instantiated at the trainer will be used.
    • loss a loss function object, optional parameter. If not given, the loss instantiated at the trainer will be used.
  • Validate with replacement: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • random a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size, optional parameter. If not given, the bunch size instantiated at the trainer will be used.
    • loss a loss function object, optional parameter. If not given, the loss instantiated at the trainer will be used.
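
As noted above, the following is a minimal sketch of a sequential validation pass; my_va_input_ds and my_va_output_ds are assumed dataset objects:

> -- evaluates the loss without modifying the parameters
> va_mu,va_var = trainer:validate_dataset{
                   input_dataset  = my_va_input_ds,
                   output_dataset = my_va_output_ds,
                 }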

use_dataset

output_ds = trainer:use_dataset{ ... }

This method receives a table with one or two datasets and computes the output of the ANN component for every pattern. Note that this method uses a large amount of memory, because it needs a dataset in which to store the ANN output for every pattern. Please be careful when using it.

It receives a table with fields:

  • input_dataset a dataset with the input data for the ANN component.

  • output_dataset a dataset with enough space to store the output of the ANN component for every pattern in the input_dataset. If not given, the output_dataset will be allocated automatically with the required size.

This method returns the output_dataset with the produced data.
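
The following is a minimal sketch, with my_input_ds as an assumed input dataset; since output_dataset is not given, it is allocated automatically and returned:

> -- compute the ANN output for every pattern of the input dataset
> out_ds = trainer:use_dataset{ input_dataset = my_input_ds }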

grad_check_dataset

boolean = trainer:grad_check_dataset{ ... }
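
A minimal usage sketch follows, assuming the argument table takes the same input_dataset/output_dataset convention as the other dataset methods in this section:

> -- assumed fields, mirroring the other *_dataset methods
> assert(trainer:grad_check_dataset{
          input_dataset  = my_input_ds,
          output_dataset = my_output_ds,
        }, "Gradients checking fails")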

Training function loop classes

These two classes are useful to build a training loop with a default stopping criterion.

train_wo_validation

trainable.train_wo_validation class

This class implements a training function without validation, using a stopping criterion based on the percentage of improvement in training loss or on the number of epochs. The following methods are defined.

Constructor

train_func = trainable.train_wo_validation{ ... }

The constructor receives a table with the following fields:

  • min_epochs=1: the minimum number of epochs of the training. It is optional.

  • max_epochs: the maximum number of epochs of the training. If min_epochs==max_epochs, then the stopping criterion will be the number of epochs.

  • percentage_stopping_criterion=0.01: a number in the range [0,1] indicating the threshold for the percentage of improvement in training loss between two consecutive epochs. If the training loss improvement is less than this number, the training will stop. It is an optional field.

  • first_epoch=1: indicates the number of the first epoch. It is optional.

> -- instance using percentage_stopping_criterion=0.01 (default value)
> train_func = trainable.train_wo_validation{ max_epochs = 100 }

execute

boolean = train_func:execute(epoch_function)

This method executes one epoch step. It is the most important method. It receives an epoch function, which is a closure responsible for performing one epoch of training, and it must return two values: the trained model and the training loss. The method returns true or false depending on whether the stopping criterion is satisfied or not.

> while train_func:execute(function()
               local tr = trainer:train_dataset(training_data)
               return trainer,tr
             end) do
  print(train_func:get_state_string())
end

set_param

train_func:set_param(name,value)

This method modifies the value of a parameter previously given at the constructor (or used with its default value).

get_param

value = train_func:get_param(name)

This method returns the value of a parameter previously given at the constructor (or used with its default value).
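
For example, a minimal sketch using the max_epochs field given at the constructor:

> train_func:set_param("max_epochs", 200)
> = train_func:get_param("max_epochs")
200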

get_state

epoch,tr_loss,tr_improvement,last=train_func:get_state()

This method returns the internal state of the object. last is the trained model returned by the last call to epoch function.

get_state_table

state = train_func:get_state_table()

This method returns a table with the following fields:

  • state.current_epoch: the current epoch.

  • state.train_error: the train loss at last epoch.

  • state.train_improvement: the train loss relative improvement.

  • state.last: the trained model returned by the last call to epoch function.

get_state_string

string = train_func:get_state_string()

This method returns a string for printing purposes, with the following format:

  string.format("%5d %.6f    %.6f",
                state.current_epoch,
                state.train_error,
                state.train_improvement)

Code example

Finally, here is a code example showing how to use this class:

> trainer    = trainable.supervised_trainer(thenet, ann.loss.mse(), 64)
> train_func = trainable.train_wo_validation{ max_epochs = 100 }
> while train_func:execute(function()
               local tr = trainer:train_dataset(training_data)
               return trainer,tr
             end) do
  print(train_func:get_state_string())
  util.serialize({ train_func, training_data.shuffle }, "training.lua")
end

The following is an example of loading a previously saved object:

> train_func,shuffle    = util.deserialize("training.lua")
> trainer               = train_func:get_state_table().last
> thenet                = trainer:get_component()
> training_data.shuffle = shuffle
> while train_func:execute(function()
               local tr = trainer:train_dataset(training_data)
               return trainer,tr
             end) do
  print(train_func:get_state_string())
  util.serialize({ train_func, training_data.shuffle }, "training.lua")
end

train_holdout_validation

trainable.train_holdout_validation class

This class implements a training function with a holdout validation set, using a stopping criterion based on validation loss or on the number of epochs. This object follows the Pocket Algorithm, so it keeps the model which has the best validation loss during the training. A tolerance in the relative error can be used to decide the minimum improvement required to take a model as the best. The following methods are defined.

Constructor

train_func = trainable.train_holdout_validation{ ... }

The constructor receives a table with the following fields:

  • min_epochs=1: the minimum number of epochs of the training. It is optional.

  • max_epochs: the maximum number of epochs of the training. If min_epochs==max_epochs, then the stopping criterion will be the number of epochs.

  • epochs_wo_validation=0: the number of epochs where validation loss is ignored, so the best model is the last given. It is optional.

  • stopping_criterion=function() return false end: a stopping criterion function. It will be used to determine when the training must be stopped. The given function is called with a table containing the output of the get_state_table() method. It is optional. Basic criterion functions are defined in the trainable table, and described below this section.

  • first_epoch=1: indicates the number of the first epoch. It is optional.

  • tolerance=0: the tolerance>=0 is the minimum relative difference required to take the current validation loss as the best. It is optional.

> criterion  = trainable.stopping_criteria.make_max_epochs_wo_imp_relative(2)
> train_func = trainable.train_holdout_validation{
                 stopping_criterion = criterion,
                 max_epochs         = max_epochs
               }

execute

boolean = train_func:execute(epoch_function)

This method executes one epoch step. It is the most important method. It receives an epoch function, which is a closure responsible for performing one epoch of training, and it must return three values: the trained model, the training loss, and the validation loss. The method returns true or false depending on whether the stopping criterion is satisfied or not.

> while train_func:execute(function()
               local tr = trainer:train_dataset(training_data)
               local va = trainer:validate_dataset(validation_data)
               return trainer,tr,va
             end) do
  print(train_func:get_state_string())
  local state = train_func:get_state_table()
  if state.best_epoch == state.current_epoch then
    util.serialize({ train_func, training_data.shuffle }, "training.lua")
  end
end

set_param

train_func:set_param(name,value)

This method modifies the value of a parameter previously given at the constructor (or used with its default value).

get_param

value = train_func:get_param(name)

This method returns the value of a parameter previously given at the constructor (or used with its default value).

get_state

epoch,tr_loss,va_loss,...=train_func:get_state()

This method returns the internal state of the object: epoch, training loss, validation loss, best epoch, best validation loss, best model clone, last given model.

get_state_table

state = train_func:get_state_table()

This method returns a table with the following fields:

  • state.current_epoch: the current epoch.

  • state.train_error: the train loss at last epoch.

  • state.validation_error: the validation loss at last epoch.

  • state.best_epoch: the epoch where the best validation loss was found.

  • state.best_val_error: the validation loss at the best epoch.

  • state.best: the trained model which achieves the best validation error.

  • state.last: the trained model returned by the last call to epoch function.

get_state_string

string = train_func:get_state_string()

This method returns a string for printing purposes, with the following format:

  string.format("%5d %.6f %.6f    %5d %.6f",
                state.current_epoch,
                state.train_error,
                state.validation_error,
                state.best_epoch,
                state.best_val_error)

Code example

Finally, here is a code example showing how to use this class:

> trainer    = trainable.supervised_trainer(thenet, ann.loss.mse(), 64)
> criterion  = trainable.stopping_criteria.make_max_epochs_wo_imp_relative(2)
> train_func = trainable.train_holdout_validation{
                 stopping_criterion = criterion,
                 max_epochs         = max_epochs
               }
> while train_func:execute(function()
               local tr = trainer:train_dataset(training_data)
               local va = trainer:validate_dataset(validation_data)
               return trainer,tr,va
             end) do
  print(train_func:get_state_string())
  local state = train_func:get_state_table()
  if state.best_epoch == state.current_epoch then
    util.serialize({ train_func, training_data.shuffle }, "training.lua")
  end
end

Stopping criteria

For the holdout-validation scheme, there are two predefined stopping criteria, which are function builders (they return the function used as criterion):

  • trainable.stopping_criteria.make_max_epochs_wo_imp_absolute: which receives a constant indicating the maximum number of epochs without improvement in validation loss. A typical value is between 10 and 20, depending on the task.

  • trainable.stopping_criteria.make_max_epochs_wo_imp_relative: which receives a constant indicating the maximum value for current_epoch/best_epoch. A typical value for this is 2.

These two criteria can be used like this:

train_func = trainable.train_holdout_validation{
  ...
  stopping_criterion = trainable.stopping_criteria.make_max_epochs_wo_imp_relative(2),
  ...
}

You can also create your own stopping criterion, which is a function which receives a table:

train_func = trainable.train_holdout_validation{
  ...
  stopping_criterion = function(t)
    -- t contains these fields:
    --   * current_epoch
    --   * best_epoch
    --   * best_val_error
    --   * train_error
    --   * validation_error
    return true IF STOPPING CRITERIA(t) IS TRUE
  end,
  ...
}

Dataset iterator functions

The class trainable.supervised_trainer uses some generic dataset iterator functions, available to the user if needed. Two functions are available: trainable.dataset_pair_iterator and trainable.dataset_multiple_iterator. The first is a wrapper around the second. These iterators can perform different traversals, depending on the given parameters.

dataset_pair_iterator

Lua iterator = trainable.dataset_pair_iterator{ ... }

This iterator performs a synchronized traversal of two given datasets (normally an input/output pair). The function returns a Lua iterator which yields three values every time it is called: an input pattern (a token, usually a matrix instance), an output pattern (a token, usually a matrix instance), and a Lua table with the indexes of the patterns taken in the bunch.

The available traversal modes are:

  • Traverse all the patterns, in a sequential way: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • bunch_size the mini-batch size.
  • Traverse all the patterns, in a shuffled way: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • shuffle a random object instance, used to shuffle the patterns.
    • bunch_size the mini-batch size.
  • Traverse with replacement: the fields of the given table are:

    • input_dataset a dataset with the data for input of ANN components.
    • output_dataset a dataset with the data for target outputs of ANN components (supervision).
    • shuffle a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size.
  • Traverse with distribution: the fields of the given table are:

    • distribution is an array of tables, where each table contains:

      • input_dataset a dataset with the data for input of ANN components.
      • output_dataset a dataset with the data for target outputs of ANN components (supervision).
      • prob a number with the probability of taking a pattern from this data source.
    • shuffle a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size.
> ds_params = { input_dataset = my_input_ds, output_dataset = my_output_ds }
> for input,output,idxs in trainable.dataset_pair_iterator(ds_params) do
    -- you can give the input/output to an ANN and loss function
    print(input,output,idxs)
  end

dataset_multiple_iterator

Lua iterator = trainable.dataset_multiple_iterator{ ... }

This iterator performs a synchronized traversal of any number of given datasets. The function returns a Lua iterator which yields as many values as the number of given datasets plus one: one pattern for each dataset, plus a Lua table with the indexes of the patterns taken in the bunch.

The available traversal modes are:

  • Traverse all the patterns, in a sequential way: the fields of the given table are:

    • datasets: a Lua table with the list of datasets for traversal.
    • bunch_size the mini-batch size.
  • Traverse all the patterns, in a shuffled way: the fields of the given table are:

    • datasets: a Lua table with the list of datasets for traversal.
    • shuffle a random object instance, used to shuffle the patterns.
    • bunch_size the mini-batch size.
  • Traverse with replacement: the fields of the given table are:

    • datasets: a Lua table with the list of datasets for traversal.
    • shuffle a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size.
  • Traverse with distribution: the fields of the given table are:

    • distribution is an array of tables, where each table contains:

      • datasets: a Lua table with the list of datasets for traversal.
      • prob a number with the probability of taking a pattern from this data source.
    • shuffle a random object instance, used to shuffle the patterns.
    • replacement a number with the size of the replacement.
    • bunch_size the mini-batch size.
> ds_params = { datasets = { my_ds1, my_ds2, my_ds3 } }
> for token1,token2,token3,idxs in trainable.dataset_multiple_iterator(ds_params) do
    -- you can give the token1,token2,token3 to an ANN or a loss function
    print(token1,token2,token3,idxs)
  end