# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
# LICENSE is in incl_licenses directory.

import torch
from torch import nn, sin, pow
from torch.nn import Parameter


class Snake(nn.Module):
    '''
    Implementation of a sine-based periodic activation function
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter
    References:
        - This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = Snake(256)
        >>> x = torch.randn(2, 256, 100)  # (B, C, T)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha: trainable parameter
            alpha is initialized to 1 by default, higher values = higher-frequency.
            alpha will be trained along with the rest of your model.
        '''
        super(Snake, self).__init__()
        self.in_features = in_features

        # initialize alpha
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        Snake := x + 1/a * sin^2(xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
        x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x

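# Usage sketch (an illustrative addition, not part of the adapted source):
# Snake acts per channel, so in_features must equal the channel dimension C
# of the (B, C, T) input. The dimensions below are example assumptions.
#
#     act = Snake(in_features=256, alpha_logscale=True)
#     y = act(torch.randn(4, 256, 100))   # y.shape == (4, 256, 100)

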
class SnakeBeta(nn.Module):
    '''
    A modified Snake function which uses separate parameters for the magnitude of the periodic components
    Shape:
        - Input: (B, C, T)
        - Output: (B, C, T), same shape as the input
    Parameters:
        - alpha - trainable parameter that controls frequency
        - beta - trainable parameter that controls magnitude
    References:
        - This activation function is a modified version based on this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
        https://arxiv.org/abs/2006.08195
    Examples:
        >>> a1 = SnakeBeta(256)
        >>> x = torch.randn(2, 256, 100)  # (B, C, T)
        >>> x = a1(x)
    '''
    def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
        '''
        Initialization.
        INPUT:
            - in_features: shape of the input
            - alpha - trainable parameter that controls frequency
            - beta - trainable parameter that controls magnitude
            alpha is initialized to 1 by default, higher values = higher-frequency.
            beta is initialized to 1 by default, higher values = higher-magnitude.
            alpha and beta will be trained along with the rest of your model.
        '''
        super(SnakeBeta, self).__init__()
        self.in_features = in_features

        # initialize alpha and beta
        self.alpha_logscale = alpha_logscale
        if self.alpha_logscale:  # log scale alphas initialized to zeros
            self.alpha = Parameter(torch.zeros(in_features) * alpha)
            self.beta = Parameter(torch.zeros(in_features) * alpha)
        else:  # linear scale alphas initialized to ones
            self.alpha = Parameter(torch.ones(in_features) * alpha)
            self.beta = Parameter(torch.ones(in_features) * alpha)

        self.alpha.requires_grad = alpha_trainable
        self.beta.requires_grad = alpha_trainable

        self.no_div_by_zero = 0.000000001

    def forward(self, x):
        '''
        Forward pass of the function.
        Applies the function to the input elementwise.
        SnakeBeta := x + 1/b * sin^2(xa)
        '''
        alpha = self.alpha.unsqueeze(0).unsqueeze(-1)  # line up with x to [B, C, T]
        beta = self.beta.unsqueeze(0).unsqueeze(-1)
        if self.alpha_logscale:
            alpha = torch.exp(alpha)
            beta = torch.exp(beta)
        x = x + (1.0 / (beta + self.no_div_by_zero)) * pow(sin(x * alpha), 2)

        return x
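
# Minimal smoke test (an added sketch, not part of the adapted source): checks
# that both activations preserve the (B, C, T) shape, using assumed example
# dimensions B=4, C=256, T=100.
if __name__ == "__main__":
    x = torch.randn(4, 256, 100)  # (B, C, T), with C == in_features
    for act in (Snake(256), SnakeBeta(256, alpha_logscale=True)):
        y = act(x)
        assert y.shape == x.shape, f"{type(act).__name__} changed the shape"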