
Commit

giant post-rebase fixup after everything was moved around... all earlier commits are a mess now, probably
mcabbott committed Nov 23, 2024
1 parent 48dc717 commit 037d6ef
Showing 8 changed files with 46 additions and 192 deletions.
14 changes: 13 additions & 1 deletion docs/src/reference/training/enzyme.md
@@ -35,7 +35,8 @@ julia> grads_f = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const
-0.0014538406], σ = nothing), nothing),), nothing, nothing)
```

The gradient returned here is also stored within `dup_model`; it shares the same arrays.
The gradient returned here is also stored within `dup_model`.
Both share the same arrays -- what is returned is not a copy, just a view of the same memory (wrapped in `NamedTuple`s instead of `struct`s).
They will all be set to zero when you call `gradient` again, then replaced with the new values.
Alternatively, `gradient(f, args...; zero=false)` will add the new gradient to what's already stored.
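For example, a minimal sketch of this behaviour (here `x` and `y` are hypothetical input and target arrays matching the model above):

```julia
# `dup_model` is the Duplicated model from above; `x`, `y` are placeholder arrays.
g = Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x), Const(y))

# g[1] is built from the very same arrays held in dup_model.dval (not copies),
# so the next call to `gradient` zeroes both before writing the new values.

# To accumulate into the stored gradient instead of overwriting it:
Flux.gradient((m,x,y) -> sum(abs2, m(x) .- y), dup_model, Const(x), Const(y); zero=false)
```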

@@ -81,8 +82,19 @@ julia> Flux.train!((m,x,y) -> sum(abs2, m(x) .- y), dup_model, [(x1, y1)], opt_s

## Listing

Flux functions:

```@docs
Flux.gradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.withgradient(f, args::Union{Flux.EnzymeCore.Const, Flux.EnzymeCore.Duplicated}...)
Flux.train!(loss, model::Flux.EnzymeCore.Duplicated, data, opt)
```

EnzymeCore types:

```@docs
Flux.EnzymeCore.Duplicated
Flux.EnzymeCore.Const
```

Enzyme.jl has [its own extensive documentation](https://enzymead.github.io/Enzyme.jl/stable/).
4 changes: 1 addition & 3 deletions ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -2,9 +2,7 @@ module FluxEnzymeExt

using Flux
using Flux: _make_zero!

import Flux.Train: _enzyme_train!, _rule_to_state, _grad_or_nothing
# import Flux.Optimise
import Flux.Train: _enzyme_train!

import Optimisers
import Functors
49 changes: 7 additions & 42 deletions src/Flux.jl
@@ -9,6 +9,7 @@ using MacroTools: @forward

@reexport using NNlib
using MLUtils
using Adapt, Functors, OneHotArrays

using Optimisers: Optimisers, destructure, freeze!, thaw!, adjust!, trainables, update!
import Optimisers: trainable
@@ -60,43 +61,9 @@ export Chain, Dense, Embedding, EmbeddingBag,
destructure, freeze!, thaw!, adjust!, trainables, update!, trainable,
# from Functors.jl
functor, @functor, KeyPath, haskeypath, getkeypath,
# from Optimise/Train/Optimisers.jl
setup, update!, destructure, freeze!, adjust!, params, trainable, trainables
))

# Pirate error to catch a common mistake.
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")

include("layers/show.jl")
include("layers/macro.jl")

include("layers/stateless.jl")
include("layers/basic.jl")
include("layers/conv.jl")
include("layers/recurrent.jl")
include("layers/normalise.jl")
include("layers/upsample.jl")
include("layers/attention.jl")

include("loading.jl")

include("outputsize.jl")
export @autosize

include("deprecations.jl")

include("losses/Losses.jl")
using .Losses

include("devices.jl")
export get_device, gpu_backend!

# Distributed Training
include("distributed/backend.jl")
include("distributed/public_api.jl")
export MPIBackend, NCCLBackend, DistributedUtils

@compat(public, (
# from Train/Optimisers.jl
setup, update!, destructure, freeze!, adjust!, params, trainable, trainables,
withgradient,
# init
glorot_uniform,
glorot_normal,
@@ -128,15 +95,13 @@ export MPIBackend, NCCLBackend, DistributedUtils
tversky_loss,
))

include("gradient.jl")
export gradient

include("train.jl")
using .Train
using .Train: setup

include("gradient.jl")
export gradient
@compat(public, (withgradient,))

using Adapt, Functors, OneHotArrays
include("utils.jl")
include("functor.jl")

148 changes: 14 additions & 134 deletions src/deprecations.jl
@@ -18,125 +18,6 @@ GRUv3Cell(in::Integer, out::Integer; kw...) = GRUv3Cell(in => out; kw...)

#### v0.14 deprecations ###########################

<<<<<<< HEAD
=======
# Valid methods in Train, new explicit style, are:
train!(loss, model, data, opt) # preferred
train!(loss, model, data, opt::Optimisers.AbstractRule) # if you forget setup

# Provide friendly errors for what happens if you mix these up:
=#
import .Optimise: train!

train!(loss, ps::Params, data, opt; cb=nothing) = error(
"""can't mix implict Params with explict state!
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
But better to use the new explicit style, in which `m` itself is the 2nd argument.
""")

train!(loss, ps::Params, data, opt::Optimisers.AbstractRule; cb=nothing) = error(
"""can't mix implict Params with explict rule from Optimisers.jl
To use `Flux.params(m)` in `train!`, the 4th argument must be from the old `Flux.Optimise` sub-module.
But better to use the new explicit style, in which `m` itself is the 2nd argument.
""")

train!(loss, model, data, opt::Optimise.AbstractOptimiser; cb=nothing) =
train!(loss, model, data, __old_to_new(opt); cb)

# Next, to use the new `setup` with the still-exported old-style `Adam` etc:
import .Train: setup
setup(rule::Optimise.AbstractOptimiser, model) = setup(__old_to_new(rule), model)
# ... and allow accidental use of `Optimisers.setup` to do the same:
Optimisers.setup(rule::Optimise.AbstractOptimiser, model) = setup(__old_to_new(rule), model)


function __old_to_new(rule)
Base.depwarn("""Optimisers from Flux.Optimise module are deprecated.
Use optimisers from Optimisers.jl instead.""", :__old_to_new)
return _old_to_new(rule)
end

for T in [:Descent, :Adam, :Momentum, :Nesterov,
:AdaGrad, :AdaMax, :AdaDelta, :AMSGrad, :NAdam, :RAdam, :OAdam, :AdaBelief,
# :InvDecay, :ExpDecay,
:SignDecay,
]
@eval function _old_to_new(rule::Optimise.$T)
args = map(f -> getfield(rule, f), fieldnames(Optimisers.$T))
Optimisers.$T(args...)
end
end

_old_to_new(rule::Optimise.Optimiser) = Optimisers.OptimiserChain(map(_old_to_new, rule.os)...)
# const OptimiserChain = Optimise.Optimiser # lets you use new name with implicit params too.
const Optimiser = Optimisers.OptimiserChain
_old_to_new(rule::Optimise.WeightDecay) = Optimisers.WeightDecay(rule.wd) # called lambda now
_old_to_new(rule::Optimise.ClipNorm) = Optimisers.ClipNorm(rule.thresh) # called omega, and there are more fields
_old_to_new(rule::Optimise.ClipValue) = Optimisers.ClipGrad(rule.thresh) # called delta now, and struct name differs
# const ClipGrad = Optimise.ClipValue
const ClipValue = Optimisers.ClipGrad
_old_to_new(rule::Optimise.RMSProp) = Optimisers.RMSProp(rule.eta, rule.rho, rule.epsilon) # RMSProp has no field centred

_old_to_new(rule) = error("Flux.setup does not know how to translate this old-style implicit rule to a new-style Optimisers.jl explicit rule")

# This allows you to mix and match, like Flux.setup(OptimiserChain(Optimisers.SignDecay(), Flux.Descent()), [1,2,3.])
Optimisers.OptimiserChain(rules::Union{Optimisers.AbstractRule, Optimise.AbstractOptimiser}...) =
Optimisers.OptimiserChain(map(_old_to_new, rules))
_old_to_new(rule::Optimisers.AbstractRule) = rule

# Since `update!` should be called in a loop, it makes less sense to call `setup` for you if you forgot.
# But let's make sure that such uses give a helpful error:
import .Optimise: update!

function update!(opt::Optimise.AbstractOptimiser, model, grad)
# This error method requires narrowing the main worker method of Flux.Optimise
# to accept only arrays. Remove if this causes problems!
# update!(opt::Flux.Optimise.AbstractOptimiser, x::AbstractArray, x̄)
error("""Invalid input to `update!`.
* For the implicit style, this needs `update!(::AbstractOptimiser, ::Params, ::Grads)`
* For the explicit style, `update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end

# TODO this friendly error should go in Optimisers.jl.
# remove after https://github.com/FluxML/Optimisers.jl/pull/181
function update!(opt::Optimisers.AbstractRule, model, grad)
error("""Invalid input to `update!`.
`update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end
function update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple)
error("""Invalid input to `update!`.
`update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end

# An easy error to make is to pass result of explicit gradient(...), not gradient(...)[1]
# Can't catch every case, but can catch many simple Flux models:

function update!(opt, model::Chain, grads::Tuple)
# Zygote will make a NamedTuple{(:layers,)} for the gradient of Chain, Diffractor a Tangent
@warn """explicit `update!(opt, model, grad)` wants the gradient for the model alone,
not the whole tuple from `gradient(m -> loss(m, x, y), model)`. You probably want `grads[1]`."""
update!(opt, model, grads[1])
end

function update!(opt::Optimise.AbstractOptimiser, model::Chain, grads::Tuple) # ambiguity
update!(opt, model, grads[1]) # calls error case "Invalid input" just above
end

# One more easy error to catch is using explicit gradient with `params(m)`:

function update!(opt::Optimise.AbstractOptimiser, ::Params, grads::Union{Tuple, NamedTuple})
error("""can't mix implicit Params with explicit gradients!
* For the implicit style, this needs `update!(::AbstractOptimiser, ::Params, ::Grads)` with implicit gradient.
* For the explicit style, `update!(state, model, grad)` needs the model itself, and `state = Flux.setup(opt, model)`.
""")
end


# v0.14 deprecations
>>>>>>> 9576dba8 (add more Duplicated methods)
@deprecate default_rng_value() Random.default_rng()


@@ -184,21 +65,6 @@ const FluxMetalAdaptor = MetalDevice
######## v0.15 deprecations #########################

# Enable these when 0.16 is released, and delete const ClipGrad = Optimise.ClipValue etc:
function gradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated in Flux!
Please see the docs for new explicit form.""", :gradient; force=true)
Zygote.gradient(f, p)
end

function withgradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated in Flux!
Please see the docs for new explicit form.""", :withgradient; force=true)
Zygote.withgradient(f, p)
end

# v0.15 deprecations

# Enable these when 0.15 is released, and delete const ClipGrad = Optimise.ClipValue etc:
# Base.@deprecate_binding Optimiser OptimiserChain
# Base.@deprecate_binding ClipValue ClipGrad

@@ -255,8 +121,22 @@ function Optimisers.update!(opt::Optimisers.AbstractRule, model, grad)
`update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end

# This exists to solve an ambiguity between the method above & one in layers/basic.jl
function Optimisers.update!(opt::Optimisers.AbstractRule, model::Chain, grad::Tuple)
error("""Invalid input to `update!`.
`update!(state, model, grad)` needs `state = Flux.setup(opt, model)`.
""")
end

# From 0.15, Flux.gradient is not Zygote.gradient, but we can add a deprecation path:
function gradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `gradient(f, ::Params)` are deprecated in Flux!

Check warning on line 134 in src/deprecations.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecations.jl#L133-L134

Added lines #L133 - L134 were not covered by tests
Please see the docs for new explicit form.""", :gradient; force=true)
Zygote.gradient(f, p)
end
function withgradient(f, p::Zygote.Params)
Base.depwarn("""Implicit gradients such as `withgradient(f, ::Params)` are deprecated in Flux!

Check warning on line 139 in src/deprecations.jl

View check run for this annotation

Codecov / codecov/patch

src/deprecations.jl#L138-L139

Added lines #L138 - L139 were not covered by tests
Please see the docs for new explicit form.""", :withgradient; force=true)
Zygote.withgradient(f, p)
end
4 changes: 3 additions & 1 deletion src/losses/Losses.jl
@@ -4,7 +4,9 @@ using Statistics
using Zygote
using Zygote: @adjoint
using ChainRulesCore
using ..Flux: ofeltype, epseltype
# using ..Flux: ofeltype, epseltype
ofeltype(x, y) = convert(float(eltype(x)), y)
epseltype(x) = eps(float(eltype(x)))
using NNlib: logsoftmax, logσ, ctc_loss, ctc_alpha, ∇ctc_loss
import Base.Broadcast: broadcasted

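For reference, a small illustration of what these two local helpers compute (example values only):

```julia
# Same definitions as in the diff above:
ofeltype(x, y) = convert(float(eltype(x)), y)
epseltype(x) = eps(float(eltype(x)))

x = Float32[1, 2, 3]
ofeltype(x, 0.5)   # 0.5f0 -- converts 0.5 to Float32, the float eltype of x
epseltype(x)       # eps(Float32), i.e. 1.1920929f-7
```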
3 changes: 0 additions & 3 deletions src/optimise/train.jl
@@ -1,5 +1,3 @@
<<<<<<< HEAD
=======
using ProgressLogging: @progress, @withprogress, @logprogress
import Zygote: Params, gradient, withgradient

@@ -21,7 +19,6 @@ The gradient could be mutated as well.
This method for implicit `Params` (and `AbstractOptimiser`) will be removed from Flux 0.15.
The explicit method `update!(opt, model, grad)` from Optimisers.jl will remain.
"""
>>>>>>> 1466ba36 (let Flux own the function update! to avoid piracy)
function update!(opt::AbstractOptimiser, x::AbstractArray, x̄)
x̄r = copyto!(similar(x̄), x̄) # Flux.Optimise assumes it can mutate the gradient. This is not
# safe due to aliasing, nor guaranteed to be possible, e.g. Fill.
10 changes: 5 additions & 5 deletions src/train.jl
@@ -8,7 +8,7 @@ using ..Flux: Flux
using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote

import ..Flux.Optimise: train!, update!, Optimise # during 0.13, we add methods to the old functions
# import ..Flux.Optimise: train!, update!, Optimise # during 0.13, we add methods to the old functions

export setup, train!

@@ -163,10 +163,10 @@ train!(loss, model::Duplicated, data, opt; cb = nothing) = _enzyme_train!(loss,
# FluxEnzymeExt defines a more specific _enzyme_train!(loss, model::Duplicated, data, opt; cb)
_enzyme_train!(loss, model, data, opt; cb = nothing) = throw(ArgumentError("The method `train!(loss, Duplicated(model), data, opt_state)` is only available when Enzyme.jl is loaded"))

# Following src/deprecations.jl
function train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing)
train!(loss, model, data, _old_to_new(opt); cb)
end
# # Following src/deprecations.jl
# function train!(loss, model::Duplicated, data, opt::Optimise.AbstractOptimiser; cb=nothing)
# train!(loss, model, data, _old_to_new(opt); cb)
# end

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model::Duplicated, data, rule::Optimisers.AbstractRule; cb=nothing)
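A minimal usage sketch of this Enzyme-backed `train!` path (hypothetical model and data; assumes Enzyme.jl is loaded so the extension methods exist):

```julia
using Flux, Enzyme

model = Dense(2 => 1)
dup_model = Duplicated(model)   # wraps the model plus space for its gradient (needs Enzyme loaded)
data = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]
opt_state = Flux.setup(Adam(), model)

# One pass over `data`, taking gradients with Enzyme rather than Zygote:
Flux.train!((m, x, y) -> sum(abs2, m(x) .- y), dup_model, data, opt_state)
```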
6 changes: 3 additions & 3 deletions test/ext_enzyme/enzyme.jl
@@ -1,7 +1,7 @@
using Test
using Flux

using Enzyme: Enzyme, make_zero, Active, Duplicated, ReverseWithPrimal
using Enzyme: Enzyme, make_zero, Active, Duplicated, Const, ReverseWithPrimal

using Functors
using FiniteDifferences
@@ -112,8 +112,8 @@ end
]

for (model, x, name) in models_xs
@testset "check grad $name" begin
println("testing $name")
@testset "Enzyme grad check $name" begin
println("testing $name with Enzyme")
test_enzyme_grad(loss, model, x)
end
end
