Add within_gradient (#434)
* add within_gradient

* add ForwardDiff method

* docs

* use in softmax too
mcabbott authored Jan 5, 2023
1 parent 2bae421 commit 5f63dbf
Showing 7 changed files with 69 additions and 5 deletions.
3 changes: 2 additions & 1 deletion Project.toml
@@ -22,6 +22,7 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -31,4 +32,4 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["ChainRulesTestUtils", "CUDA", "Documenter", "FiniteDifferences", "Logging", "NNlibCUDA", "Random", "StableRNGs", "Test", "UnicodePlots", "Zygote"]
test = ["ChainRulesTestUtils", "CUDA", "Documenter", "FiniteDifferences", "ForwardDiff", "Logging", "NNlibCUDA", "Random", "StableRNGs", "Test", "UnicodePlots", "Zygote"]
1 change: 1 addition & 0 deletions docs/src/reference.md
@@ -132,4 +132,5 @@ ctc_loss
```@docs
logsumexp
NNlib.glu
NNlib.within_gradient
```
6 changes: 6 additions & 0 deletions src/NNlib.jl
@@ -87,6 +87,12 @@ export upsample_nearest, ∇upsample_nearest,
include("gather.jl")
include("scatter.jl")
include("utils.jl")
@init @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin
using .ForwardDiff
within_gradient(x::ForwardDiff.Dual) = true
within_gradient(x::AbstractArray{<:ForwardDiff.Dual}) = true
end

include("sampling.jl")
include("functions.jl")

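The `@init @require` block above uses Requires.jl, so the ForwardDiff methods are only defined once the user loads ForwardDiff. A minimal sketch of the resulting behaviour (assuming both packages are installed):

```julia
using NNlib          # generic fallback: within_gradient(x) == false
using ForwardDiff    # loading it triggers the @require block above

NNlib.within_gradient(1.0)                            # false: plain number
NNlib.within_gradient(ForwardDiff.Dual(1.0, 1.0))     # true: dual number
NNlib.within_gradient([ForwardDiff.Dual(1.0, 1.0)])   # true: array of duals
```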
5 changes: 1 addition & 4 deletions src/softmax.jl
@@ -69,7 +69,7 @@ function softmax!(out::AbstractArray{T}, x::AbstractArray; dims = 1) where {T}
end

function ∇softmax_data(dy::AbstractArray{T}, y::AbstractArray{S}; dims = 1) where {T,S}
dx = if within_grad()
dx = if within_gradient(y)
tmp = dy .* y
tmp .- y .* sum(tmp; dims)
else
@@ -88,9 +88,6 @@ function rrule(::typeof(softmax), x; dims = 1)
return y, softmax_pullback
end

within_grad() = false
rrule(::typeof(within_grad)) = true, _ -> (NoTangent(),)

fast_maximum(x::AbstractArray{T}; dims) where {T} = @fastmath reduce(max, x; dims, init = float(T)(-Inf))

"""
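For context: `within_gradient(y)` lets `∇softmax_data` choose the branch shown above, which is safe to differentiate a second time, whenever the pullback itself is being run under an AD tool (via the ChainRules rrule, or via ForwardDiff `Dual`s in `y`). A rough sketch of the higher-order use this supports, assuming Zygote is available:

```julia
using NNlib, Zygote

f(x) = sum(abs2, softmax(x))

x = randn(4)
Zygote.gradient(f, x)   # first derivative

# Zygote.hessian is forward-over-reverse: inside the pullback, y = softmax(x)
# holds ForwardDiff Duals, so within_gradient(y) is true and the branch shown
# in the hunk above is taken.
Zygote.hessian(f, x)
```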
52 changes: 52 additions & 0 deletions src/utils.jl
@@ -1,3 +1,55 @@
"""
    within_gradient(x) --> Bool

Returns `false` except when used inside a `gradient` call, when it returns `true`.
Useful for Flux regularisation layers which behave differently during training and inference.

This should work with any ChainRules-based differentiation package, in which case `x` is ignored.
But Tracker.jl overloads `within_gradient(x::TrackedArray)`, so for the widest compatibility you should
pass it an array whose gradient is of interest.

There is also an overload for ForwardDiff.jl's `Dual` types (and arrays of them).
# Examples
```
julia> using ForwardDiff, Zygote, NNlib

julia> f_good(x) = if NNlib.within_gradient(x)
                     @show 10x
                   else
                     x
                   end;

julia> Zygote.withgradient(f_good, 1.0)
10x = 10.0
(val = 10.0, grad = (10.0,))

julia> ForwardDiff.derivative(f_good, 1.0)
10x = Dual{ForwardDiff.Tag{typeof(f_good), Float64}}(10.0,10.0)
10.0

julia> f_bad(x, y) = if any(NNlib.within_gradient, (x, y))
                       @show x * y
                     else
                       x / y
                     end;

julia> Zygote.withgradient(f_bad, 2.0, 3.0)
(val = 0.6666666666666666, grad = (0.3333333333333333, -0.2222222222222222))

julia> ForwardDiff.derivative(x -> f_bad(x, 3.0), 2.0)
x * y = Dual{ForwardDiff.Tag{var"#9#10", Float64}}(6.0,3.0)
3.0
```
What goes wrong in `f_bad` is that Zygote knows `any` to be non-differentiable,
and thus completely ignores its contents. This is not a perfect mechanism,
and the only style recommended is precisely that of `f_good` above.
"""
within_gradient(x) = false

ChainRulesCore.rrule(::typeof(within_gradient), x) = true, _ -> (NoTangent(), NoTangent())


"""
safe_div(x, y)
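As the docstring says, the intended consumers are layers that behave differently during training and inference. A hypothetical sketch of such a layer (the `NoisyIdentity` name and behaviour are illustrative only, not part of NNlib or this commit):

```julia
using NNlib

# Hypothetical regularisation layer: adds noise only while a gradient is
# being taken, and is the identity at inference time.
struct NoisyIdentity{T<:Real}
    scale::T
end

function (m::NoisyIdentity)(x::AbstractArray)
    if NNlib.within_gradient(x)   # true under Zygote (via the rrule), Tracker, or ForwardDiff Duals
        return x .+ m.scale .* randn(size(x)...)
    else
        return x
    end
end

layer = NoisyIdentity(0.1)
layer(ones(3))   # inference call: within_gradient is false, input passes through unchanged
```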
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -2,6 +2,7 @@ using NNlib, Test, Statistics, Random
using ChainRulesCore, ChainRulesTestUtils
using Base.Broadcast: broadcasted
import FiniteDifferences
import ForwardDiff
import Zygote
using Zygote: gradient
using StableRNGs
6 changes: 6 additions & 0 deletions test/utils.jl
@@ -1,3 +1,9 @@
@testset "within_gradient" begin
@test NNlib.within_gradient([1.0]) === false
@test gradient(x -> NNlib.within_gradient(x) * x, 2.0) == (1.0,)
@test NNlib.within_gradient([ForwardDiff.Dual(1.0, 2)]) === true
end

@testset "maximum_dims" begin
ind1 = [1,2,3,4,5,6]
@test NNlib.maximum_dims(ind1) == (6,)
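A possible further check, not part of this commit, would cover the scalar `Dual` method and the ForwardDiff path end to end (assuming the same imports as `runtests.jl`, i.e. `NNlib`, `Test`, and `ForwardDiff`):

```julia
@testset "within_gradient, ForwardDiff scalar" begin
    @test NNlib.within_gradient(ForwardDiff.Dual(1.0, 2)) === true
    # true * x == x, so the derivative through the branch is simply 1
    @test ForwardDiff.derivative(x -> NNlib.within_gradient(x) * x, 2.0) == 1.0
end
```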

2 comments on commit 5f63dbf

@mcabbott (Member Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/75189

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.8.14 -m "<description of version>" 5f63dbff5d638cb6a61dcf007d793335f66ccbd5
git push origin v0.8.14
