Allow BatchNorm on CUDA with track_stats=false
Enables BatchNorm with track_stats=false on CUDA, in both training and test
modes. Unit tests are added to ensure the CUDA implementation matches the
CPU implementation.
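
For context, a minimal usage sketch of what this change enables; this is a hypothetical example (assuming a working Flux + CUDA setup), not code from the commit itself:

using Flux, CUDA

# Before this commit, calling the layer on the GPU raised
# "BatchNorm: only track_stats=true supported on gpu".
bn = BatchNorm(3; track_stats = false) |> gpu
x = CUDA.rand(Float32, 28, 28, 3, 4)

y_train = bn(x)   # training mode: normalizes with the batch's own statistics

testmode!(bn)     # test mode now works too; with no running stats tracked,
y_test = bn(x)    # normalization still uses the current batch's statistics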
paulnovo committed Apr 21, 2024
1 parent 8654721 commit 7bef09b
Showing 2 changed files with 14 additions and 5 deletions.
ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl (2 changes: 1 addition & 1 deletion)
@@ -24,11 +24,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
                               cache=nothing) where T<:Union{Float32, Float64}
 
   @assert BN.affine "BatchNorm: only affine=true supported on gpu"
-  @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
   @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
 
   return BN.λ.(NNlib.batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
                                cache=cache, alpha=1, beta=0, eps=BN.ϵ,
+                               track_stats=BN.track_stats,
                                training=Flux._isactive(BN, x)))
 end
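
For reference, a rough sketch of the normalization core that track_stats=false implies, i.e. what is computed from the current batch alone, before the affine γ/β scaling and the activation λ. This is not Flux's actual code and the helper name is made up:

using Statistics

# Hypothetical helper: normalize over every dimension except the channel
# dimension (ndims(x) - 1), using only the current batch's statistics.
function batchnorm_no_stats(x; eps = 1f-5)
    dims = Tuple(filter(!=(ndims(x) - 1), 1:ndims(x)))
    μ = mean(x; dims)
    σ² = var(x; dims, corrected = false)
    return (x .- μ) ./ sqrt.(σ² .+ eps)
end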

test/ext_cuda/layers.jl (17 changes: 13 additions & 4 deletions)
@@ -17,20 +17,25 @@ const ACTIVATIONS = [identity, relu, tanh,
                      sigmoid, exp, softplus,
                      elu, selu]
 
-function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true)
+function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true, test_mode = false)
   isnothing(x_cpu) && error("Missing input to test the layers against.")
   @testset "$name GPU grad tests" begin
     for layer in layers
       @testset "$layer Layer GPU grad test" begin
 
         # compute output and grad of parameters
         l_cpu = layer(args...)
+        l_gpu = l_cpu |> gpu
+        if test_mode
+          testmode!(l_cpu)
+          testmode!(l_gpu)
+        end
 
         ps_cpu = Flux.params(l_cpu)
         y_cpu, back_cpu = pullback(() -> sum(l_cpu(x_cpu)), ps_cpu)
         gs_cpu = back_cpu(1f0)
 
         x_gpu = gpu(x_cpu)
-        l_gpu = l_cpu |> gpu
         ps_gpu = Flux.params(l_gpu)
 
         if typeof(l_gpu) <: BROKEN_LAYERS
@@ -78,6 +83,7 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true, test_mode = false)
 end
 
 # Just to give testset in gpu_gradtest meaningful labels
+BatchNormNoTrackStats(args...) = BatchNorm(args...; track_stats = false)
 ConvNoBias(args...) = Conv(args...; bias = false)
 ConvTransposeNoBias(args...) = ConvTranspose(args...; bias = false)
 CrossCorNoBias(args...) = CrossCor(args...; bias = false)
@@ -96,9 +102,12 @@ for act in ACTIVATIONS
   groupedconv = [GroupedConv, GroupedConvTranspose]
   gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act, test_cpu = true)
 
-  batch_norm = [BatchNorm]
+  batch_norm = [BatchNorm, BatchNormNoTrackStats]
   gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, test_cpu = false) #TODO fix errors
-  gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = false)
+  gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true)
+
+  batch_norm = [BatchNormNoTrackStats]
+  gpu_gradtest("BatchNorm 3 with $act (test mode)", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true, test_mode = true)
 
   instancenorm = [InstanceNorm]
   gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act, test_cpu = false)
