@@ -92,6 +92,110 @@ function nn.Jacobian.forward(module, input, param, perturbation)
92
92
return jacobian
93
93
end
94
94
95
+ function nn .Jacobian .backwardDiagHessian (module , input , diagHessianParamName )
96
+ -- Compute the second derivatives (diagonal Hessian elements)
97
+ -- by backpropagation (using the code from hessian.lua).
98
+ --
99
+ -- This function computes the diagonal Hessian elements of the following function:
100
+ --
101
+ -- F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2,
102
+ --
103
+ -- where
104
+ -- x_1, ..., x_n are the input values and parameters of the given module,
105
+ -- y_1, ..., y_m are the output values of the given module.
106
+ --
107
+ -- All x_i and y_i values are scalars here. In other words,
108
+ -- x_1, ..., x_n denote the scalar elements of the module input tensor,
109
+ -- the scalar elements of module.weight,
110
+ -- and the scalar elements of module.bias;
111
+ -- y_1, ..., y_m are the scalar elements of the module output tensor.
112
+ --
113
+ -- The diagonal Hessian elements of F are computed with respect to
114
+ -- the module input values and parameters (x_1, .., x_n).
115
+ --
116
+ -- The function F is chosen for its convenient properties:
117
+ --
118
+ -- dF / dy_i = y_i,
119
+ -- d^2F / dy_i^2 = 1.
120
+ --
121
+ -- In other words, the diagonal Hessian elements of F with respect
122
+ -- to the module OUTPUT values (y_1, ... y_m) are equal to 1.
123
+ --
124
+ -- Because of that, computing the diagonal Hessian elements of F
125
+ -- with respect to the module INPUT values and PARAMETERS (x_1, ..., x_n)
126
+ -- can be done by calling updateDiagHessianInput() and accDiagHessianParameters()
127
+ -- using a tensor of ones as diagHessianOutput.
128
+
129
+ module :forward (input )
130
+ local diagHessianOutput = module .output .new ():resizeAs (module .output ):fill (1 )
131
+
132
+ module .diagHessianWeight :zero ()
133
+ module .diagHessianBias :zero ()
134
+ module :updateDiagHessianInput (input , diagHessianOutput )
135
+ module :accDiagHessianParameters (input , diagHessianOutput )
136
+
137
+ return module [diagHessianParamName ]
138
+ end
139
+
140
+ function nn .Jacobian .linearModuleDiagHessian (module , input , gradParamName )
141
+ -- Compute the second derivatives (diagonal Hessian elements)
142
+ -- from the first derivatives for the given module
143
+ -- (without using the code from hessian.lua).
144
+ --
145
+ -- The given module is assumed to be linear with respect to its inputs and weights
146
+ -- (like nn.Linear, nn.SpatialConvolution, etc.)
147
+ --
148
+ -- This function computes the diagonal Hessian elements of the following function:
149
+ --
150
+ -- F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2.
151
+ --
152
+ -- (See the the comment for nn.Jacobian.backwardDiagHessian() for explanation.)
153
+ --
154
+ -- The first derivatives of F with respect to
155
+ -- the module inputs and parameters (x_1, ..., x_n) are:
156
+ --
157
+ -- dF / dx_i = \sum_k (dF / dy_k) (dy_k / dx_i).
158
+ --
159
+ -- The second derivatives are:
160
+ --
161
+ -- d^2F / dx_i = \sum_k [(d^2F / dy_k^2) (dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)].
162
+ --
163
+ -- The second derivatives of F with respect to the module outputs (y_1, ..., y_m)
164
+ -- are equal to 1, so:
165
+ --
166
+ -- d^2F / dx_i = \sum_k [(dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)].
167
+ --
168
+ -- Assuming the linearity of module outputs (y_1, ..., y_m)
169
+ -- with respect to module inputs and parameters (x_1, ..., x_n),
170
+ -- we have (d^2y_k / dx_i^2) = 0,
171
+ -- and the expression finally becomes:
172
+ --
173
+ -- d^2F / dx_i = \sum_k (dy_k / dx_i)^2.
174
+ --
175
+ -- The first derivatives (dy_k / dx_i) are computed by normal backpropagation,
176
+ -- using updateGradInput() and accGradParameters().
177
+
178
+ local gradParam = module [gradParamName ]
179
+
180
+ local diagHessian = gradParam .new ():resize (gradParam :nElement ()):zero ()
181
+
182
+ module :forward (input )
183
+ local gradOutput = module .output .new ():resizeAs (module .output )
184
+ local gradOutput1D = gradOutput :view (gradOutput :nElement ())
185
+
186
+ for i = 1 ,gradOutput :nElement () do
187
+ gradOutput1D :zero ()
188
+ gradOutput1D [i ] = 1
189
+ module .gradWeight :zero ()
190
+ module .gradBias :zero ()
191
+ module :updateGradInput (input , gradOutput )
192
+ module :accGradParameters (input , gradOutput )
193
+ diagHessian :addcmul (gradParam , gradParam )
194
+ end
195
+
196
+ return diagHessian
197
+ end
198
+
95
199
function nn .Jacobian .forwardUpdate (module , input , param , perturbation )
96
200
-- perturbation amount
97
201
perturbation = perturbation or 1e-6
@@ -156,6 +260,33 @@ function nn.Jacobian.testJacobianUpdateParameters(module, input, param, minval,
156
260
return error :abs ():max ()
157
261
end
158
262
263
+ function nn .Jacobian .testDiagHessian (module , input , gradParamName , diagHessianParamName , minval , maxval )
264
+ -- Compute the diagonal Hessian elements for the same function in two different ways,
265
+ -- then compare the results and return the difference.
266
+
267
+ minval = minval or - 2
268
+ maxval = maxval or 2
269
+ local inrange = maxval - minval
270
+ input :copy (torch .rand (input :nElement ()):mul (inrange ):add (minval ))
271
+ module :initDiagHessianParameters ()
272
+ local h_bprop = nn .Jacobian .backwardDiagHessian (module , input , diagHessianParamName )
273
+ local h_linearmodule = nn .Jacobian .linearModuleDiagHessian (module , input , gradParamName )
274
+ local error = h_bprop - h_linearmodule
275
+ return error :abs ():max ()
276
+ end
277
+
278
+ function nn .Jacobian .testDiagHessianInput (module , input , minval , maxval )
279
+ return nn .Jacobian .testDiagHessian (module , input , ' gradInput' , ' diagHessianInput' , minval , maxval )
280
+ end
281
+
282
+ function nn .Jacobian .testDiagHessianWeight (module , input , minval , maxval )
283
+ return nn .Jacobian .testDiagHessian (module , input , ' gradWeight' , ' diagHessianWeight' , minval , maxval )
284
+ end
285
+
286
+ function nn .Jacobian .testDiagHessianBias (module , input , minval , maxval )
287
+ return nn .Jacobian .testDiagHessian (module , input , ' gradBias' , ' diagHessianBias' , minval , maxval )
288
+ end
289
+
159
290
function nn .Jacobian .testIO (module ,input , minval , maxval )
160
291
minval = minval or - 2
161
292
maxval = maxval or 2
0 commit comments