From 5f63dbff5d638cb6a61dcf007d793335f66ccbd5 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 5 Jan 2023 14:29:01 -0500
Subject: [PATCH] Add `within_gradient` (#434)

* add within_gradient

* add ForwardDiff method

* docs

* use in softmax too
---
 Project.toml          |  3 ++-
 docs/src/reference.md |  1 +
 src/NNlib.jl          |  6 ++++++
 src/softmax.jl        |  5 +----
 src/utils.jl          | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 test/runtests.jl      |  1 +
 test/utils.jl         |  6 ++++++
 7 files changed, 69 insertions(+), 5 deletions(-)

diff --git a/Project.toml b/Project.toml
index 5cca81f44..75a2bb504 100644
--- a/Project.toml
+++ b/Project.toml
@@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
 ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,4 +32,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [targets]
-test = ["ChainRulesTestUtils", "CUDA", "Documenter", "FiniteDifferences", "Logging", "NNlibCUDA", "Random", "StableRNGs", "Test", "UnicodePlots", "Zygote"]
+test = ["ChainRulesTestUtils", "CUDA", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "NNlibCUDA", "Random", "StableRNGs", "Test", "UnicodePlots", "Zygote"]
diff --git a/docs/src/reference.md b/docs/src/reference.md
index a034c92c5..f4ae43370 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -132,4 +132,5 @@ ctc_loss
 ```@docs
 logsumexp
 NNlib.glu
+NNlib.within_gradient
 ```
diff --git a/src/NNlib.jl b/src/NNlib.jl
index acca75299..6c897fa20 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -87,6 +87,12 @@ export upsample_nearest, ∇upsample_nearest,
 include("gather.jl")
 include("scatter.jl")
 include("utils.jl")
+@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
+    using .ForwardDiff
+    within_gradient(x::ForwardDiff.Dual) = true
+    within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
+end
+
 include("sampling.jl")
 include("functions.jl")
 
diff --git a/src/softmax.jl b/src/softmax.jl
index f815ac976..182f2fb93 100644
--- a/src/softmax.jl
+++ b/src/softmax.jl
@@ -69,7 +69,7 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
 end
 
 function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
-    dx = if within_grad()
+    dx = if within_gradient(y)
         tmp = dy .* y
         tmp .- y .* sum(tmp; dims)
     else
@@ -88,9 +88,6 @@ function rrule(::typeof(softmax), x; dims = 1)
     return y, softmax_pullback
 end
 
-within_grad() = false
-rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)
-
 fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))
 
 """
diff --git a/src/utils.jl b/src/utils.jl
index cd1b9f03b..6d16b1cb1 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,3 +1,55 @@
+"""
+    within_gradient(x) --> Bool
+
+Returns `false` except when used inside a `gradient` call, when it returns `true`.
+Useful for Flux regularisation layers which behave differently during training and inference.
+
+This should work with any ChainRules-based differentiation package, in which case `x` is ignored.
+But Tracker.jl overloads `within_gradient(x::TrackedArray)`, so for widest use you should
+pass it an array whose gradient is of interest.
+There is also an overload for ForwardDiff.jl's `Dual` types (and arrays of them).
+
+# Examples
+```
+julia> using ForwardDiff, Zygote, NNlib
+
+julia> f_good(x) = if NNlib.within_gradient(x)
+                     @show 10x
+                   else
+                     x
+                   end;
+
+julia> Zygote.withgradient(f_good, 1.0)
+10x = 10.0
+(val = 10.0, grad = (10.0,))
+
+julia> ForwardDiff.derivative(f_good, 1.0)
+10x = Dual{ForwardDiff.Tag{typeof(f_good), Float64}}(10.0,10.0)
+10.0
+
+julia> f_bad(x, y) = if any(NNlib.within_gradient, (x, y))
+                       @show x * y
+                     else
+                       x / y
+                     end;
+
+julia> Zygote.withgradient(f_bad, 2.0, 3.0)
+(val = 0.6666666666666666, grad = (0.3333333333333333, -0.2222222222222222))
+
+julia> ForwardDiff.derivative(x -> f_bad(x, 3.0), 2.0)
+x * y = Dual{ForwardDiff.Tag{var"#9#10", Float64}}(6.0,3.0)
+3.0
+```
+
+What goes wrong in `f_bad` is that Zygote knows `any` to be non-differentiable,
+and thus completely ignores its contents. This is not a perfect mechanism,
+and the only recommended style is precisely that of `f_good` above.
+"""
+within_gradient(x) = false
+
+ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent())
+
+
 """
     safe_div(x, y)
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 16084b4d2..357b96f4a 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -2,6 +2,7 @@ using NNlib, Test, Statistics, Random
 using ChainRulesCore, ChainRulesTestUtils
 using Base.Broadcast: broadcasted
 import FiniteDifferences
+import ForwardDiff
 import Zygote
 using Zygote: gradient
 using StableRNGs
diff --git a/test/utils.jl b/test/utils.jl
index a5264dc5a..131cec07f 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -1,3 +1,9 @@
+@testset "within_gradient" begin
+    @test NNlib.within_gradient([1.0]) === false
+    @test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)
+    @test NNlib.within_gradient([ForwardDiff.Dual(1.0, 2)]) === true
+end
+
 @testset "maximum_dims" begin
     ind1 = [1,2,3,4,5,6]
     @test NNlib.maximum_dims(ind1) == (6,)
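
Not part of the patch: a minimal sketch of the downstream use this enables, following the `f_good` style from the docstring. The `ToyDropout` layer, its field names, and the Zygote-based check are illustrative assumptions, not NNlib or Flux API; the only NNlib function used is `within_gradient`.

```
using NNlib, Zygote

# Hypothetical layer, for illustration only: drops activations while a
# gradient is being taken, and is the identity at inference time.
struct ToyDropout{T}
    p::T   # drop probability
end

function (d::ToyDropout)(x::AbstractArray)
    if NNlib.within_gradient(x)                 # true inside gradient calls
        mask = rand(eltype(x), size(x)) .> d.p  # random keep-mask
        return x .* mask ./ (1 - d.p)           # inverted-dropout scaling
    else
        return x                                # inference: no-op
    end
end

layer = ToyDropout(0.5)
x = ones(Float32, 3)
layer(x)                                # plain call: returns x unchanged
Zygote.gradient(x -> sum(layer(x)), x)  # mask is applied inside the gradient
```

Passing the array `x` itself, rather than querying a zero-argument flag as the removed `within_grad()` did, is what lets the ForwardDiff overloads added here (and Tracker's, per the docstring) dispatch as well.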