forked from devray/mbICA
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathnonlinearities.h
98 lines (87 loc) · 2.12 KB
/
nonlinearities.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
/**
* @file
* @author Pawel Zubrycki <[email protected]>
* @author Stanisław Janikowski <[email protected]>
*
* @section DESCRIPTION
*
* File with preimplemented nonlinearities used in FastICA algorithm.
*
* @note New nonlinearities don't have to inherit from the Nonlinearity class;
* they just have to provide G() and dG() methods.
**/
#ifndef NONLINEARITIES_H
#define NONLINEARITIES_H
#include<armadillo>
namespace mbica {
/// namespace for available nonlinearities
namespace nonlinearities {
/**
 * Base class for preimplemented nonlinearities.
 *
 * Derived classes evaluate the function and its derivative element-wise in
 * their constructors and store the results in g_ and dg_; this class only
 * stores and exposes those results.
 */
class Nonlinearity {
public:
    /// Returns (a copy of) the value of the function computed by the derived class.
    arma::mat G() const { return g_; }
    /// Returns (a copy of) the value of the derivative computed by the derived class.
    arma::mat dG() const { return dg_; }
protected:
    /// Protected constructor to prevent creating plain Nonlinearity objects.
    Nonlinearity(){}
protected:
    /// Element-wise value of the function; filled in by derived constructors.
    arma::mat g_;
    /// Element-wise value of the derivative; filled in by derived constructors.
    arma::mat dg_;
};
/**
* Class implementing power function on matrix elements.
*
* @param a Exponent.
**/
template<int a=3>
class Pow: public Nonlinearity{
public:
Pow(arma::mat X){
g_ = arma::pow(X, a);
dg_ = a*arma::pow(X, a-1);
}
};
/**
* Class implementing tanh nonlinearity.
*
* @note Class uses fraction a/b as a parameter.
*
* @param a Numerator of parameter.
* @param b Denominator of parameter.
**/
template<int a=1, int b=1>
class TanH: public Nonlinearity{
public:
TanH(arma::mat X){
double c = double(a)/b;
g_ = arma::tanh(c * X);
dg_ = c * (1 - pow(g_, 2));
}
};
/**
* Class implementing gauss nonlinearity.
*
* @note Class uses fraction a/b as a parameter.
*
* @param a Numerator of parameter.
* @param b Denominator of parameter.
**/
template<int a=1, int b=1>
class Gauss: public Nonlinearity{
public:
Gauss(arma::mat X){
double c = double(a)/b;
arma::mat ex = arma::exp(-0.5 * c * arma::pow(X, 2));
g_ = X % ex;
dg_ = (1 - c * arma::pow(X, 2)) % ex;
}
};
/// Alias for Pow<2>, i.e. element-wise g(x) = x^2 and g'(x) = 2x.
/// Provided for compatibility with the Octave FastICA package
/// (presumably mirroring its "skew" nonlinearity name).
typedef Pow<2> Skew;
}
}
#endif // NONLINEARITIES_H