From 7bef09b6832a368a4dc422bfecc427830b56550c Mon Sep 17 00:00:00 2001
From: Paul Novotny
Date: Sun, 21 Apr 2024 13:28:53 +0000
Subject: [PATCH] Allow BatchNorm on CUDA with track_stats=false

Enables BatchNorm with track_stats=false on CUDA, in both training and
test modes. Unit tests are added to ensure the CUDA implementation
matches the CPU implementation.
---
 ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl |  2 +-
 test/ext_cuda/layers.jl                  | 17 +++++++++++++----
 2 files changed, 14 insertions(+), 5 deletions(-)

diff --git a/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl b/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl
index 1f808709c2..b354e50b5d 100644
--- a/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl
+++ b/ext/FluxCUDAcuDNNExt/FluxCUDAcuDNNExt.jl
@@ -24,11 +24,11 @@ function (BN::Flux.BatchNorm)(x::Union{CuArray{T,2},CuArray{T,4},CuArray{T,5}},
                               cache=nothing) where T<:Union{Float32, Float64}
 
   @assert BN.affine "BatchNorm: only affine=true supported on gpu"
-  @assert BN.track_stats "BatchNorm: only track_stats=true supported on gpu"
   @assert length(BN.β) == size(x, ndims(x)-1) "BatchNorm: input has wrong number of channels"
   return BN.λ.(NNlib.batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
                        cache=cache, alpha=1, beta=0, eps=BN.ϵ,
+                       track_stats=BN.track_stats,
                        training=Flux._isactive(BN, x)))
 end
 
diff --git a/test/ext_cuda/layers.jl b/test/ext_cuda/layers.jl
index e59ff35aa4..63bcc8b526 100644
--- a/test/ext_cuda/layers.jl
+++ b/test/ext_cuda/layers.jl
@@ -17,7 +17,7 @@ const ACTIVATIONS = [identity, relu, tanh,
                      sigmoid, exp, softplus,
                      elu, selu]
 
-function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true)
+function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; test_cpu = true, test_mode = false)
   isnothing(x_cpu) && error("Missing input to test the layers against.")
   @testset "$name GPU grad tests" begin
     for layer in layers
@@ -25,12 +25,17 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
 
         # compute output and grad of parameters
         l_cpu = layer(args...)
+        l_gpu = l_cpu |> gpu
+        if test_mode
+          testmode!(l_cpu)
+          testmode!(l_gpu)
+        end
+
         ps_cpu = Flux.params(l_cpu)
         y_cpu, back_cpu = pullback(() -> sum(l_cpu(x_cpu)), ps_cpu)
         gs_cpu = back_cpu(1f0)
 
         x_gpu = gpu(x_cpu)
-        l_gpu = l_cpu |> gpu
         ps_gpu = Flux.params(l_gpu)
 
         if typeof(l_gpu) <: BROKEN_LAYERS
@@ -78,6 +83,7 @@ function gpu_gradtest(name::String, layers::Vector, x_cpu = nothing, args...; te
 end
 
 # Just to give testset in gpu_gradtest meaningful labels
+BatchNormNoTrackStats(args...) = BatchNorm(args...; track_stats = false)
 ConvNoBias(args...) = Conv(args...; bias = false)
 ConvTransposeNoBias(args...) = ConvTranspose(args...; bias = false)
 CrossCorNoBias(args...) = CrossCor(args...; bias = false)
@@ -96,9 +102,12 @@ for act in ACTIVATIONS
   groupedconv = [GroupedConv, GroupedConvTranspose]
   gpu_gradtest("GroupedConvolution with $act", groupedconv, rand(Float32, 28, 28, 100, 2), (3,3), 100 => 25, act, test_cpu = true)
 
-  batch_norm = [BatchNorm]
+  batch_norm = [BatchNorm, BatchNormNoTrackStats]
   gpu_gradtest("BatchNorm 1 with $act", batch_norm, rand(Float32, 28,28,3,4), 3, act, test_cpu = false) #TODO fix errors
-  gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = false)
+  gpu_gradtest("BatchNorm 2 with $act", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true)
+
+  batch_norm = [BatchNormNoTrackStats]
+  gpu_gradtest("BatchNorm 3 with $act (test mode)", batch_norm, rand(Float32, 5,4), 5, act, test_cpu = true, test_mode = true)
 
   instancenorm = [InstanceNorm]
   gpu_gradtest("InstanceNorm with $act", instancenorm, r, 1, act, test_cpu = false)
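
A usage sketch (not part of the patch): the snippet below is one way to exercise the code
path this change enables, i.e. a BatchNorm layer with track_stats=false run on a CuArray
in both training and test mode and compared against the CPU result. It assumes Flux with
the CUDA/cuDNN extension loaded and a working GPU; the input shape and the 1f-4 tolerance
are illustrative choices, not values taken from the test suite.

using Flux, CUDA, cuDNN

bn_cpu = BatchNorm(3; track_stats = false)  # previously rejected by the GPU assert
bn_gpu = bn_cpu |> gpu

x_cpu = rand(Float32, 28, 28, 3, 4)         # WHCN input with 3 channels
x_gpu = x_cpu |> gpu

# Training mode: both devices normalise with the current batch statistics.
Flux.trainmode!(bn_cpu); Flux.trainmode!(bn_gpu)
@assert isapprox(bn_cpu(x_cpu), Array(bn_gpu(x_gpu)); atol = 1f-4)

# Test mode: with track_stats=false there are no running statistics, so the
# layer keeps normalising with batch statistics instead of stored μ/σ².
Flux.testmode!(bn_cpu); Flux.testmode!(bn_gpu)
@assert isapprox(bn_cpu(x_cpu), Array(bn_gpu(x_gpu)); atol = 1f-4)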