feat: enzyme autodiff helpers #954

Draft
wants to merge 6 commits into base: main
31 changes: 27 additions & 4 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
@@ -1,16 +1,39 @@
module LuxEnzymeExt

using ADTypes: AutoEnzyme
using Enzyme: Enzyme, Active, Const, Duplicated
using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
using EnzymeCore: EnzymeCore
using Functors: fmap
using Setfield: @set!
using Static: False, True
using Setfield: @set!, @set
using Static: False, True, StaticBool

using Lux: Lux, Utils
using Lux.Training: TrainingBackendCache, TrainState
using MLDataDevices: isleaf

Lux.is_extension_loaded(::Val{:Enzyme}) = true

normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward)
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse)

annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)

include("training.jl")

include("autodiff.jl")
include("batched_autodiff.jl")

@concrete struct OOPFunctionWrapper
f
end

function (f::OOPFunctionWrapper)(y, args...)
copyto!(y, f.f(args...))
return
end

end
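
The helpers above pick an Enzyme mode and activity annotation before any autodiff call: normalize_backend fills in Forward or Reverse only when the user passed a plain AutoEnzyme() with no mode, and annotate_function wraps the function in the activity type stored in the backend, if any. A minimal sketch of the intended behaviour (illustrative only; assumes the standard ADTypes keyword constructor for AutoEnzyme and some function f):

ad = AutoEnzyme()                                             # user did not pick a mode
normalize_backend(True(), ad)                                 # mode defaults to Enzyme.Forward
normalize_backend(False(), ad)                                # mode defaults to Enzyme.Reverse
normalize_backend(True(), AutoEnzyme(; mode=Enzyme.Reverse))  # explicit mode is kept as-is

annotate_function(AutoEnzyme(), f)                                    # no annotation: returns f
annotate_function(AutoEnzyme(; function_annotation=Enzyme.Const), f)  # returns Enzyme.Const(f)

OOPFunctionWrapper adapts an out-of-place function so reverse-mode calls can write the primal output into a preallocated buffer, which the VJP methods in autodiff.jl rely on.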
37 changes: 37 additions & 0 deletions ext/LuxEnzymeExt/autodiff.jl
@@ -0,0 +1,37 @@
function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl(
f::F, ad::AutoEnzyme, x, u, p) where {F}
ad = normalize_backend(True(), ad)
@assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode."
return only(
Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), Const(p))
)
end

function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl(
f::F, ad::AutoEnzyme, x, u) where {F}
ad = normalize_backend(True(), ad)
@assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode."
return only(Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u)))
end

function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl(
f::F, ad::AutoEnzyme, x, v, p) where {F}
ad = normalize_backend(False(), ad)
@assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode."
dx = zero(x)
# XXX: without the copy it overwrites the `v` with zeros
Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)),
Duplicated(similar(v), copy(v)), Duplicated(x, dx), Const(p))
return dx
end

function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl(
f::F, ad::AutoEnzyme, x, v) where {F}
ad = normalize_backend(False(), ad)
@assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode."
dx = zero(x)
# XXX: without the copy it overwrites the `v` with zeros
Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)),
Duplicated(similar(v), copy(v)), Duplicated(x, dx))
return dx
end
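
These four methods back Lux's public jacobian_vector_product and vector_jacobian_product when an AutoEnzyme backend is passed: JVPs are forced into forward mode and VJPs into reverse mode, filling in the mode if the user left it unset. A hedged usage sketch (the model, shapes, and random inputs below are assumptions for illustration, not taken from this diff):

using Lux, Enzyme, Random
using ADTypes: AutoEnzyme

model = Dense(4 => 3)
ps, st = Lux.setup(Random.default_rng(), model)
smodel = StatefulLuxLayer{true}(model, ps, st)

x = randn(Float32, 4, 2)
u = randn(Float32, 4, 2)   # tangent, same shape as x
v = randn(Float32, 3, 2)   # cotangent, same shape as the output

jvp = jacobian_vector_product(smodel, AutoEnzyme(), x, u)   # forward mode, per the assert above
vjp = vector_jacobian_product(smodel, AutoEnzyme(), x, v)   # reverse mode, per the assert above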
93 changes: 93 additions & 0 deletions ext/LuxEnzymeExt/batched_autodiff.jl
@@ -0,0 +1,93 @@
function Lux.AutoDiffInternalImpl.batched_jacobian_internal(
f::F, ad::AutoEnzyme, x::AbstractArray, args...) where {F}
backend = normalize_backend(True(), ad)
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x, args...)
end

function batched_enzyme_jacobian_impl(
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray, args...) where {G}
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
y = f_orig(x)
f = annotate_function(ad, f_orig)

@argcheck y isa AbstractArray MethodError
if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
throw(AssertionError("`batched_jacobian` only supports batched outputs \
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
end
B = size(y, ndims(y))

J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
prod(size(x)[1:(end - 1)]), B)

chunk_size = Utils.max_enzyme_batched_chunk_size(y)
partials = ntuple(_ -> zero(x), chunk_size)

for i in 1:chunk_size:(length(x) ÷ B)
idxs = i:min(i + chunk_size - 1, length(x) ÷ B)
partials′ = make_onehot!(partials, idxs)
J_partials = only(Enzyme.autodiff(
ad.mode, f, make_batch_duplicated(x, partials′), Const.(args)...
))
for (idx, J_partial) in zip(idxs, J_partials)
J[:, :, idx] .= reshape(J_partial, :, B)
end
end

return J
end

function batched_enzyme_jacobian_impl(
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray, args...) where {G}
# We need to run the function once to get the output type. Can we use ReverseWithPrimal?
y = f_orig(x)

@argcheck y isa AbstractArray MethodError
if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
throw(AssertionError("`batched_jacobian` only supports batched outputs \
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
end
B = size(y, ndims(y))

J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
prod(size(x)[1:(end - 1)]), B)

chunk_size = Utils.max_enzyme_batched_chunk_size(y)
partials = ntuple(_ -> zero(y), chunk_size)
J_partials = ntuple(_ -> zero(x), chunk_size)

fn = annotate_function(ad, OOPFunctionWrapper(f_orig))
for i in 1:chunk_size:(length(y) ÷ B)
idxs = i:min(i + chunk_size - 1, length(y) ÷ B)
partials′ = make_onehot!(partials, idxs)
J_partials′ = make_zero!(J_partials, idxs)
Enzyme.autodiff(
ad.mode, fn, make_batch_duplicated(y, partials′),
make_batch_duplicated(x, J_partials′), Const.(args)...
)
for (idx, J_partial) in zip(idxs, J_partials)
J[idx, :, :] .= reshape(J_partial, :, B)
end
end

return J
end

function make_onehot!(partials, idxs)
for (idx, partial) in zip(idxs, partials)
partial .= false
partial′ = reshape(partial, :, size(partial, ndims(partial)))
partial′[idx, :] .= true
end
return partials[1:length(idxs)]
end

function make_zero!(partials, idxs)
for partial in partials
partial .= false
end
return partials[1:length(idxs)]
end

make_batch_duplicated(x, dxs) = BatchDuplicated(x, dxs)
make_batch_duplicated(x, dx::Tuple{X}) where {X} = Duplicated(x, only(dx))
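
The batched Jacobian is built chunk by chunk: make_onehot! seeds up to chunk_size one-hot tangents per Enzyme call (as BatchDuplicated shadows), and each call fills the matching slices of J, whose shape is (per-sample output length, per-sample input length, batch size). Continuing the illustrative setup from the previous note (shapes assumed, not executed):

x = randn(Float32, 4, 16)                      # 16 samples in the batch
J = batched_jacobian(smodel, AutoEnzyme(), x)  # forward mode is chosen by default above
size(J)                                        # expected (3, 4, 16) for a Dense(4 => 3)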
3 changes: 3 additions & 0 deletions ext/LuxReactantExt/patches.jl
@@ -1,5 +1,8 @@
Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))

# XXX: remove once EnzymeJAX supports batched AD
Utils.max_enzyme_batched_chunk_size(x::AnyTracedRArray) = 1
Review thread on the line above:

Contributor: in theory this should be in now btw [thanks ofc to @jumerckx]

Member Author: we are probably missing this on the Julia end then? I got the "width must be 1" error

Contributor: Yeah

# XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g

1 change: 1 addition & 0 deletions src/Lux.jl
@@ -49,6 +49,7 @@ function __init__()
end

is_extension_loaded(::Val) = false
is_extension_loaded(::Val{:ForwardDiff}) = true

# Preferences
include("preferences.jl")
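
is_extension_loaded defaults to false, and each package extension overrides it for its own Val tag (LuxEnzymeExt above adds the Val{:Enzyme} method), while ForwardDiff is treated as always available here. Illustrative behaviour (not executed):

Lux.is_extension_loaded(Val(:Enzyme))        # false until Enzyme.jl is loaded
using Enzyme                                 # activates LuxEnzymeExt
Lux.is_extension_loaded(Val(:Enzyme))        # true
Lux.is_extension_loaded(Val(:ForwardDiff))   # true, defined directly in src/Lux.jl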
64 changes: 40 additions & 24 deletions src/autodiff/api.jl
@@ -7,9 +7,10 @@ products efficiently using mixed-mode AD.

## Backends & AD Packages

| Supported Backends | Packages Needed |
| :----------------- | :-------------- |
| `AutoZygote` | `Zygote.jl` |
| Supported Backends | Packages Needed | Notes |
| :----------------- | :-------------- | :--------------------------------------------- |
| `AutoZygote` | `Zygote.jl` | |
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |

!!! warning

@@ -32,12 +33,12 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F}
throw(ArgumentError("`vector_jacobian_product` is not implemented for `$(backend)`."))
end

function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F}
if !is_extension_loaded(Val(:Zygote))
error("`Zygote.jl` must be loaded for `vector_jacobian_product` \
to work with `$(backend)`.")
for implemented_backend in (:AutoZygote, :AutoEnzyme)
@eval function vector_jacobian_product(
f::F, backend::$implemented_backend, x, u) where {F}
assert_backend_loaded(:vector_jacobian_product, backend)
return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
end
return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
end

@doc doc"""
@@ -49,9 +50,10 @@ products efficiently using mixed-mode AD.

## Backends & AD Packages

| Supported Backends | Packages Needed |
| :----------------- | :--------------- |
| `AutoForwardDiff` | |
| Supported Backends | Packages Needed | Notes |
| :----------------- | :-------------- | :--------------------------------------------- |
| `AutoForwardDiff` | | |
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |

!!! warning

@@ -74,8 +76,11 @@ function jacobian_vector_product(::F, backend::AbstractADType, _, __) where {F}
throw(ArgumentError("`jacobian_vector_product` is not implemented for `$(backend)`."))
end

function jacobian_vector_product(f::F, backend::AutoForwardDiff, x, u) where {F}
return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u)
for implemented_backend in (:AutoEnzyme, :AutoForwardDiff)
@eval function jacobian_vector_product(
f::F, backend::$(implemented_backend), x, u) where {F}
return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u)
end
end

"""
@@ -89,10 +94,11 @@ the following properties for `y = f(x)`:

## Backends & AD Packages

| Supported Backends | Packages Needed |
|:------------------ |:--------------- |
| `AutoForwardDiff` | |
| `AutoZygote` | `Zygote.jl` |
| Supported Backends | Packages Needed | Notes |
|:------------------ |:--------------- |:---------------------------------------------- |
| `AutoForwardDiff` | | |
| `AutoZygote` | `Zygote.jl` | |
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |

## Arguments

@@ -118,14 +124,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
end

function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme)
@eval function batched_jacobian(
f::F, backend::$(implemented_backend), x::AbstractArray) where {F}
assert_backend_loaded(:batched_jacobian, backend)
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end
end

function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
if !is_extension_loaded(Val(:Zygote))
error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
`$(backend)`.")
function assert_backend_loaded(fname::Symbol, ad::AbstractADType)
return assert_backend_loaded(fname, ad, adtype_to_backend(ad))
end
function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B}
if !is_extension_loaded(backend)
error("$(fname) with `$(ad)` requires $(B).jl to be loaded.")
end
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
return
end

adtype_to_backend(::AutoEnzyme) = Val(:Enzyme)
adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff)
adtype_to_backend(::AutoZygote) = Val(:Zygote)
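
assert_backend_loaded maps the ADType to a Val tag via adtype_to_backend and raises a readable error when the matching package extension has not been activated. A hedged sketch of the failure path (error text paraphrased from the format string above):

# In a session that has not run `using Enzyme`:
batched_jacobian(smodel, AutoEnzyme(), x)
# ERROR: batched_jacobian with `AutoEnzyme(...)` requires Enzyme.jl to be loaded.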
11 changes: 11 additions & 0 deletions src/autodiff/jac_products.jl
@@ -3,6 +3,17 @@ function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F}
return vector_jacobian_product_impl(f, backend, x, u)
end

for fType in AD_CONVERTIBLE_FUNCTIONS
@eval function vector_jacobian_product(f::$(fType), backend::AbstractADType, x, u)
f̂, y = rewrite_autodiff_call(f)
return vector_jacobian_product_impl(f̂, backend, x, u, y)
end
end

function vector_jacobian_product_impl(f::F, backend::AbstractADType, x, u, y) where {F}
return vector_jacobian_product_impl(Base.Fix2(f, y), backend, x, u)
end

# JVP Implementation
function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F}
return jacobian_vector_product_impl(f, backend, x, u)
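
The methods generated for AD_CONVERTIBLE_FUNCTIONS first split the callable into an explicit-parameter form via rewrite_autodiff_call and then curry the extracted parameters back in with Base.Fix2, so the existing two-argument implementations are reused unchanged. Conceptually, for a stateful layer, the generated method behaves like this sketch (not the literal code):

function vector_jacobian_product(f::StatefulLuxLayer, backend::AbstractADType, x, u)
    f̂, ps = rewrite_autodiff_call(f)   # for a StatefulLuxLayer: f̂ === f, ps === f.ps
    return vector_jacobian_product_impl(Base.Fix2(f̂, ps), backend, x, u)   # x -> f̂(x, ps)
end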
10 changes: 6 additions & 4 deletions src/autodiff/nested_autodiff.jl
@@ -1,10 +1,10 @@
## Written like this to avoid dynamic dispatch from Zygote
# Input Gradient / Jacobian
function rewrite_autodiff_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F}
(f, f.inner.ps)
return f, f.inner.ps
end
function rewrite_autodiff_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F}
(@closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps)
return @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps
end
rewrite_autodiff_call(f::StatefulLuxLayer) = f, f.ps

@@ -22,10 +22,12 @@ function rewrite_autodiff_call(f::Base.Fix1{<:StatefulLuxLayer})
end

## Break ambiguity
for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
for op in [
ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer},
ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}},
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}]
ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}
]
@eval function rewrite_autodiff_call(::$op)
error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer \
layers")
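
rewrite_autodiff_call only knows how to pull the parameters out of a single stateful layer, so composing two StatefulLuxLayers is rejected outright: it would be ambiguous which parameter set the nested AD call should track. Illustrative (smodel and smodel2 assumed to be StatefulLuxLayers):

f_ok  = sum ∘ smodel       # supported: behaves as the explicit form (x, ps) -> sum(smodel(x, ps))
f_bad = smodel2 ∘ smodel   # rewrite_autodiff_call errors: inner and outer both carry parameters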
2 changes: 2 additions & 0 deletions src/utils.jl
@@ -228,6 +228,8 @@ end
calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2)))
calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4

max_enzyme_batched_chunk_size(x::AbstractArray) = min(8, length(x) ÷ Base.size(x, ndims(x)))

end

using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix,
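
max_enzyme_batched_chunk_size caps how many one-hot tangents a single BatchDuplicated call carries: at most 8, or the per-sample length if that is smaller (and the Reactant patch above pins it to 1 until EnzymeJAX supports batched AD). A worked example (module path assumed; not executed):

x = randn(Float32, 4, 16)
length(x) ÷ size(x, ndims(x))                 # 64 ÷ 16 = 4 elements per sample
Lux.Utils.max_enzyme_batched_chunk_size(x)    # min(8, 4) = 4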