Commit a53cb3c
Add unit tests for hessian.lua; fix bugs detected by the tests.

* Fix initialization of diagHessianBias for nn.SpatialConvolution.
* Fix computing diagHessianBias for nn.SpatialFullConvolution.
* Call module:forward() with the proper input before calling accGradParameters(). Without that, accDiagHessianParameters() produces incorrect results for some convolution classes.
* Move duplicated code from Module.getParameters() into Module.flatten(), which is now used both by the original Module.getParameters() in Module.lua and by the replacement Module.getParameters() in hessian.lua.
1 parent b7aa53d commit a53cb3c

File tree: 4 files changed, +290 -128 lines changed
Jacobian.lua (+131)
@@ -92,6 +92,110 @@ function nn.Jacobian.forward(module, input, param, perturbation)
    return jacobian
 end
 
+function nn.Jacobian.backwardDiagHessian(module, input, diagHessianParamName)
+   -- Compute the second derivatives (diagonal Hessian elements)
+   -- by backpropagation (using the code from hessian.lua).
+   --
+   -- This function computes the diagonal Hessian elements of the following function:
+   --
+   --    F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2,
+   --
+   -- where
+   -- x_1, ..., x_n are the input values and parameters of the given module,
+   -- y_1, ..., y_m are the output values of the given module.
+   --
+   -- All x_i and y_i values are scalars here. In other words,
+   -- x_1, ..., x_n denote the scalar elements of the module input tensor,
+   -- the scalar elements of module.weight,
+   -- and the scalar elements of module.bias;
+   -- y_1, ..., y_m are the scalar elements of the module output tensor.
+   --
+   -- The diagonal Hessian elements of F are computed with respect to
+   -- the module input values and parameters (x_1, ..., x_n).
+   --
+   -- The function F is chosen for its convenient properties:
+   --
+   --    dF / dy_i = y_i,
+   --    d^2F / dy_i^2 = 1.
+   --
+   -- In other words, the diagonal Hessian elements of F with respect
+   -- to the module OUTPUT values (y_1, ..., y_m) are equal to 1.
+   --
+   -- Because of that, computing the diagonal Hessian elements of F
+   -- with respect to the module INPUT values and PARAMETERS (x_1, ..., x_n)
+   -- can be done by calling updateDiagHessianInput() and accDiagHessianParameters()
+   -- using a tensor of ones as diagHessianOutput.
+
+   module:forward(input)
+   local diagHessianOutput = module.output.new():resizeAs(module.output):fill(1)
+
+   module.diagHessianWeight:zero()
+   module.diagHessianBias:zero()
+   module:updateDiagHessianInput(input, diagHessianOutput)
+   module:accDiagHessianParameters(input, diagHessianOutput)
+
+   return module[diagHessianParamName]
+end
+
+function nn.Jacobian.linearModuleDiagHessian(module, input, gradParamName)
+   -- Compute the second derivatives (diagonal Hessian elements)
+   -- from the first derivatives for the given module
+   -- (without using the code from hessian.lua).
+   --
+   -- The given module is assumed to be linear with respect to its inputs and weights
+   -- (like nn.Linear, nn.SpatialConvolution, etc.)
+   --
+   -- This function computes the diagonal Hessian elements of the following function:
+   --
+   --    F(x_1, x_2, ..., x_n) = y_1^2/2 + y_2^2/2 + ... + y_m^2/2.
+   --
+   -- (See the comment for nn.Jacobian.backwardDiagHessian() for an explanation.)
+   --
+   -- The first derivatives of F with respect to
+   -- the module inputs and parameters (x_1, ..., x_n) are:
+   --
+   --    dF / dx_i = \sum_k (dF / dy_k) (dy_k / dx_i).
+   --
+   -- The second derivatives are:
+   --
+   --    d^2F / dx_i^2 = \sum_k [(d^2F / dy_k^2) (dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)].
+   --
+   -- The second derivatives of F with respect to the module outputs (y_1, ..., y_m)
+   -- are equal to 1, so:
+   --
+   --    d^2F / dx_i^2 = \sum_k [(dy_k / dx_i)^2 + (dF / dy_k) (d^2y_k / dx_i^2)].
+   --
+   -- Assuming the linearity of the module outputs (y_1, ..., y_m)
+   -- with respect to the module inputs and parameters (x_1, ..., x_n),
+   -- we have (d^2y_k / dx_i^2) = 0,
+   -- and the expression finally becomes:
+   --
+   --    d^2F / dx_i^2 = \sum_k (dy_k / dx_i)^2.
+   --
+   -- The first derivatives (dy_k / dx_i) are computed by normal backpropagation,
+   -- using updateGradInput() and accGradParameters().
+
+   local gradParam = module[gradParamName]
+
+   local diagHessian = gradParam.new():resize(gradParam:nElement()):zero()
+
+   module:forward(input)
+   local gradOutput = module.output.new():resizeAs(module.output)
+   local gradOutput1D = gradOutput:view(gradOutput:nElement())
+
+   for i = 1, gradOutput:nElement() do
+      gradOutput1D:zero()
+      gradOutput1D[i] = 1
+      module.gradWeight:zero()
+      module.gradBias:zero()
+      module:updateGradInput(input, gradOutput)
+      module:accGradParameters(input, gradOutput)
+      diagHessian:addcmul(gradParam, gradParam)
+   end
+
+   return diagHessian
+end
+
 function nn.Jacobian.forwardUpdate(module, input, param, perturbation)
    -- perturbation amount
    perturbation = perturbation or 1e-6
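As a concrete instance of the derivation in linearModuleDiagHessian(): for nn.Linear, y_k = \sum_i w_ki x_i + b_k, so dy_k / dw_ki = x_i and therefore d^2F / dw_ki^2 = x_i^2 for every output row k. A quick sketch checking that closed form against the helper (sizes and tolerance are illustrative, not from the commit):

local module = nn.Linear(4, 3)
local input = torch.randn(4)

local h = nn.Jacobian.linearModuleDiagHessian(module, input, 'gradWeight')

-- closed form: each of the 3 rows of the weight Hessian equals x_i^2
local expected = torch.ger(torch.ones(3), torch.pow(input, 2))
assert((h:view(3, 4) - expected):abs():max() < 1e-6)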
@@ -156,6 +260,33 @@ function nn.Jacobian.testJacobianUpdateParameters(module, input, param, minval,
    return error:abs():max()
 end
 
+function nn.Jacobian.testDiagHessian(module, input, gradParamName, diagHessianParamName, minval, maxval)
+   -- Compute the diagonal Hessian elements for the same function in two different ways,
+   -- then compare the results and return the difference.
+
+   minval = minval or -2
+   maxval = maxval or 2
+   local inrange = maxval - minval
+   input:copy(torch.rand(input:nElement()):mul(inrange):add(minval))
+   module:initDiagHessianParameters()
+   local h_bprop = nn.Jacobian.backwardDiagHessian(module, input, diagHessianParamName)
+   local h_linearmodule = nn.Jacobian.linearModuleDiagHessian(module, input, gradParamName)
+   local error = h_bprop - h_linearmodule
+   return error:abs():max()
+end
+
+function nn.Jacobian.testDiagHessianInput(module, input, minval, maxval)
+   return nn.Jacobian.testDiagHessian(module, input, 'gradInput', 'diagHessianInput', minval, maxval)
+end
+
+function nn.Jacobian.testDiagHessianWeight(module, input, minval, maxval)
+   return nn.Jacobian.testDiagHessian(module, input, 'gradWeight', 'diagHessianWeight', minval, maxval)
+end
+
+function nn.Jacobian.testDiagHessianBias(module, input, minval, maxval)
+   return nn.Jacobian.testDiagHessian(module, input, 'gradBias', 'diagHessianBias', minval, maxval)
+end
+
 function nn.Jacobian.testIO(module,input, minval, maxval)
    minval = minval or -2
    maxval = maxval or 2
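These three wrappers are what the new unit tests call. For instance, a test exercising the nn.SpatialFullConvolution bias path fixed by this commit might look like the following sketch (sizes illustrative; assumes nn.hessian.enable() has already been called so that the diag-Hessian methods exist on the module):

local module = nn.SpatialFullConvolution(3, 2, 5, 5)
local input = torch.Tensor(3, 6, 6)

-- each call returns the max abs difference between the two Hessian computations
print(nn.Jacobian.testDiagHessianInput(module, input))
print(nn.Jacobian.testDiagHessianWeight(module, input))
print(nn.Jacobian.testDiagHessianBias(module, input))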

Module.lua (+64 -65)

@@ -137,92 +137,91 @@ end

The local flatten() helper moves out of Module:getParameters() and becomes the standalone Module.flatten(); its body is de-indented one level but otherwise unchanged. The code after the change:

function Module:reset()
end

-- this function flattens arbitrary lists of parameters,
-- even complex shared ones
function Module.flatten(parameters)
   local function storageInSet(set, storage)
      local storageAndOffset = set[torch.pointer(storage)]
      if storageAndOffset == nil then
         return nil
      end
      local _, offset = table.unpack(storageAndOffset)
      return offset
   end

   if not parameters or #parameters == 0 then
      return torch.Tensor()
   end
   local Tensor = parameters[1].new
   local dtype = parameters[1]:type()

   local storages = {}
   local nParameters = 0
   for k = 1,#parameters do
      if parameters[k]:type() ~= dtype then
         error("Inconsistent parameter types. " .. parameters[k]:type() ..
               " ~= " .. dtype)
      end
      local storage = parameters[k]:storage()
      if not storageInSet(storages, storage) then
         storages[torch.pointer(storage)] = {storage, nParameters}
         nParameters = nParameters + storage:size()
      end
   end

   local flatParameters = Tensor(nParameters):fill(1)
   local flatStorage = flatParameters:storage()

   for k = 1,#parameters do
      local storageOffset = storageInSet(storages, parameters[k]:storage())
      parameters[k]:set(flatStorage,
                        storageOffset + parameters[k]:storageOffset(),
                        parameters[k]:size(),
                        parameters[k]:stride())
      parameters[k]:zero()
   end

   local maskParameters = flatParameters:float():clone()
   local cumSumOfHoles = flatParameters:float():cumsum(1)
   local nUsedParameters = nParameters - cumSumOfHoles[#cumSumOfHoles]
   local flatUsedParameters = Tensor(nUsedParameters)
   local flatUsedStorage = flatUsedParameters:storage()

   for k = 1,#parameters do
      local offset = cumSumOfHoles[parameters[k]:storageOffset()]
      parameters[k]:set(flatUsedStorage,
                        parameters[k]:storageOffset() - offset,
                        parameters[k]:size(),
                        parameters[k]:stride())
   end

   for _, storageAndOffset in pairs(storages) do
      local k, v = table.unpack(storageAndOffset)
      flatParameters[{{v+1,v+k:size()}}]:copy(Tensor():set(k))
   end

   if cumSumOfHoles:sum() == 0 then
      flatUsedParameters:copy(flatParameters)
   else
      local counter = 0
      for k = 1,flatParameters:nElement() do
         if maskParameters[k] == 0 then
            counter = counter + 1
            flatUsedParameters[counter] = flatParameters[counter+cumSumOfHoles[k]]
         end
      end
      assert (counter == nUsedParameters)
   end
   return flatUsedParameters
end

function Module:getParameters()
   -- get parameters
   local parameters,gradParameters = self:parameters()
   -- flatten parameters and gradients
   local flatParameters = Module.flatten(parameters)
   collectgarbage()
   local flatGradParameters = Module.flatten(gradParameters)
   collectgarbage()

   -- return new flat vector that contains all discrete parameters
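Since flatten() is now exposed on Module, it can be exercised directly. A small sketch of the sharing behavior it preserves (names and sizes are illustrative):

-- two tensors viewing one storage are flattened without duplication
local a = torch.Tensor(4):fill(1)
local b = torch.Tensor():set(a:storage(), 1, torch.LongStorage{4})

local flat = nn.Module.flatten({a, b})
print(flat:nElement())   -- 4, not 8: the shared storage is counted once

flat:fill(7)
print(a[1], b[1])        -- both 7: a and b now view the flat storage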

0 commit comments