diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl
index 1c6b5cc88..71969be97 100644
--- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl
+++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl
@@ -1,16 +1,39 @@
 module LuxEnzymeExt
 
-using ADTypes: AutoEnzyme
-using Enzyme: Enzyme, Active, Const, Duplicated
+using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
+using ArgCheck: @argcheck
+using ConcreteStructs: @concrete
+using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
 using EnzymeCore: EnzymeCore
 using Functors: fmap
-using Setfield: @set!
-using Static: False, True
+using Setfield: @set!, @set
+using Static: False, True, StaticBool
 
 using Lux: Lux, Utils
 using Lux.Training: TrainingBackendCache, TrainState
 using MLDataDevices: isleaf
 
+Lux.is_extension_loaded(::Val{:Enzyme}) = true
+
+normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
+normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward)
+normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse)
+
+annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
+annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)
+
 include("training.jl")
+include("autodiff.jl")
+include("batched_autodiff.jl")
+
+@concrete struct OOPFunctionWrapper
+    f
+end
+
+function (f::OOPFunctionWrapper)(y, args...)
+    copyto!(y, f.f(args...))
+    return
+end
 
 end
diff --git a/ext/LuxEnzymeExt/autodiff.jl b/ext/LuxEnzymeExt/autodiff.jl
new file mode 100644
index 000000000..c3ff2413c
--- /dev/null
+++ b/ext/LuxEnzymeExt/autodiff.jl
@@ -0,0 +1,37 @@
+function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl(
+        f::F, ad::AutoEnzyme, x, u, p) where {F}
+    ad = normalize_backend(True(), ad)
+    @assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode."
+    return only(
+        Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), Const(p))
+    )
+end
+
+function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl(
+        f::F, ad::AutoEnzyme, x, u) where {F}
+    ad = normalize_backend(True(), ad)
+    @assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode."
+    return only(Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u)))
+end
+
+function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl(
+        f::F, ad::AutoEnzyme, x, v, p) where {F}
+    ad = normalize_backend(False(), ad)
+    @assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode."
+    dx = zero(x)
+    # XXX: without the copy, Enzyme overwrites `v` with zeros
+    Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)),
+        Duplicated(similar(v), copy(v)), Duplicated(x, dx), Const(p))
+    return dx
+end
+
+function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl(
+        f::F, ad::AutoEnzyme, x, v) where {F}
+    ad = normalize_backend(False(), ad)
+    @assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode."
+    dx = zero(x)
+    # XXX: without the copy, Enzyme overwrites `v` with zeros
+    Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)),
+        Duplicated(similar(v), copy(v)), Duplicated(x, dx))
+    return dx
+end
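Usage sketch for the two entry points this file implements. The function `f` and the arrays `x` and `v` are illustrative, not part of the patch, and Enzyme.jl must be loaded so the extension activates:

```julia
using ADTypes, Enzyme, Lux

f(x) = abs2.(x)        # toy map; its Jacobian is Diagonal(2 .* x)
x = randn(Float32, 4)
v = ones(Float32, 4)

# A mode-less AutoEnzyme() is normalized to Enzyme.Forward for JVPs:
jvp = jacobian_vector_product(f, AutoEnzyme(), x, v)  # J(x) * v == 2 .* x .* v

# ...and to Enzyme.Reverse for VJPs, routed through OOPFunctionWrapper:
vjp = vector_jacobian_product(f, AutoEnzyme(), x, v)  # J(x)' * v, same here
```

Because the Jacobian is diagonal in this toy case, the two products coincide; for a general `f` they differ.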
diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl
new file mode 100644
index 000000000..7000d35ee
--- /dev/null
+++ b/ext/LuxEnzymeExt/batched_autodiff.jl
@@ -0,0 +1,93 @@
+function Lux.AutoDiffInternalImpl.batched_jacobian_internal(
+        f::F, ad::AutoEnzyme, x::AbstractArray, args...) where {F}
+    backend = normalize_backend(True(), ad)
+    return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x, args...)
+end
+
+function batched_enzyme_jacobian_impl(
+        f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray, args...) where {G}
+    # We need to run the function once to get the output type. Can we use ForwardWithPrimal?
+    y = f_orig(x)
+    f = annotate_function(ad, f_orig)
+
+    @argcheck y isa AbstractArray MethodError
+    if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
+        throw(AssertionError("`batched_jacobian` only supports batched outputs \
+                              (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
+    end
+    B = size(y, ndims(y))
+
+    J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
+        prod(size(x)[1:(end - 1)]), B)
+
+    chunk_size = Utils.max_enzyme_batched_chunk_size(y)
+    partials = ntuple(_ -> zero(x), chunk_size)
+
+    for i in 1:chunk_size:(length(x) ÷ B)
+        idxs = i:min(i + chunk_size - 1, length(x) ÷ B)
+        partials′ = make_onehot!(partials, idxs)
+        J_partials = only(Enzyme.autodiff(
+            ad.mode, f, make_batch_duplicated(x, partials′), Const.(args)...
+        ))
+        for (idx, J_partial) in zip(idxs, J_partials)
+            J[:, idx, :] .= reshape(J_partial, :, B)
+        end
+    end
+
+    return J
+end
+
+function batched_enzyme_jacobian_impl(
+        f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray, args...) where {G}
+    # We need to run the function once to get the output type. Can we use ReverseWithPrimal?
+    y = f_orig(x)
+
+    @argcheck y isa AbstractArray MethodError
+    if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x))
+        throw(AssertionError("`batched_jacobian` only supports batched outputs \
+                              (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
+    end
+    B = size(y, ndims(y))
+
+    J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
+        prod(size(x)[1:(end - 1)]), B)
+
+    chunk_size = Utils.max_enzyme_batched_chunk_size(y)
+    partials = ntuple(_ -> zero(y), chunk_size)
+    J_partials = ntuple(_ -> zero(x), chunk_size)
+
+    fn = annotate_function(ad, OOPFunctionWrapper(f_orig))
+    for i in 1:chunk_size:(length(y) ÷ B)
+        idxs = i:min(i + chunk_size - 1, length(y) ÷ B)
+        partials′ = make_onehot!(partials, idxs)
+        J_partials′ = make_zero!(J_partials, idxs)
+        Enzyme.autodiff(
+            ad.mode, fn, make_batch_duplicated(y, partials′),
+            make_batch_duplicated(x, J_partials′), Const.(args)...
+        )
+        for (idx, J_partial) in zip(idxs, J_partials)
+            J[idx, :, :] .= reshape(J_partial, :, B)
+        end
+    end
+
+    return J
+end
+
+function make_onehot!(partials, idxs)
+    for (idx, partial) in zip(idxs, partials)
+        partial .= false
+        partial′ = reshape(partial, :, size(partial, ndims(partial)))
+        partial′[idx, :] .= true
+    end
+    return partials[1:length(idxs)]
+end
+
+function make_zero!(partials, idxs)
+    for partial in partials
+        partial .= false
+    end
+    return partials[1:length(idxs)]
+end
+
+make_batch_duplicated(x, dxs) = BatchDuplicated(x, dxs)
+make_batch_duplicated(x, dx::Tuple{X}) where {X} = Duplicated(x, only(dx))
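To make the seeding scheme concrete: each loop iteration above pushes up to `chunk_size` one-hot tangents through a single `Enzyme.autodiff` call via `BatchDuplicated`, and every returned tangent is one slice of the Jacobian. A toy sketch of that mechanism (illustrative values, not part of the patch; the exact container type of the returned batch can vary across Enzyme releases):

```julia
using Enzyme

f(x) = x .^ 2                 # Jacobian is Diagonal(2 .* x)
x = Float32[1, 2, 3]

# Two one-hot seeds in one call, as `make_onehot!` + `make_batch_duplicated`
# set up; each tangent that comes back is one column of the Jacobian.
seeds = (Float32[1, 0, 0], Float32[0, 1, 0])
cols = only(Enzyme.autodiff(Enzyme.Forward, f, BatchDuplicated(x, seeds)))
cols[1]  # ≈ Float32[2, 0, 0]
cols[2]  # ≈ Float32[0, 4, 0]
```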
diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl
index 6d79f2b60..8f6d4f511 100644
--- a/ext/LuxReactantExt/patches.jl
+++ b/ext/LuxReactantExt/patches.jl
@@ -1,5 +1,8 @@
 Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x))
 
+# XXX: remove once EnzymeJAX supports batched AD
+Utils.max_enzyme_batched_chunk_size(x::AnyTracedRArray) = 1
+
 # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint
 Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g
 
diff --git a/src/Lux.jl b/src/Lux.jl
index ceda3df25..942367902 100644
--- a/src/Lux.jl
+++ b/src/Lux.jl
@@ -49,6 +49,7 @@ function __init__()
 end
 
 is_extension_loaded(::Val) = false
+is_extension_loaded(::Val{:ForwardDiff}) = true
 
 # Preferences
 include("preferences.jl")
diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl
index 6db1e38da..3bc1a907f 100644
--- a/src/autodiff/api.jl
+++ b/src/autodiff/api.jl
@@ -7,9 +7,10 @@ products efficiently using mixed-mode AD.
 
 ## Backends & AD Packages
 
-| Supported Backends | Packages Needed |
-| :----------------- | :-------------- |
-| `AutoZygote`       | `Zygote.jl`     |
+| Supported Backends | Packages Needed | Notes                                          |
+| :----------------- | :-------------- | :--------------------------------------------- |
+| `AutoZygote`       | `Zygote.jl`     |                                                 |
+| `AutoEnzyme`       | `Enzyme.jl`     | Not compatible with ChainRules-based Nested AD  |
 
 !!! warning
 
@@ -32,12 +33,12 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F}
     throw(ArgumentError("`vector_jacobian_product` is not implemented for `$(backend)`."))
 end
 
-function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F}
-    if !is_extension_loaded(Val(:Zygote))
-        error("`Zygote.jl` must be loaded for `vector_jacobian_product` \
-               to work with `$(backend)`.")
+for implemented_backend in (:AutoZygote, :AutoEnzyme)
+    @eval function vector_jacobian_product(
+            f::F, backend::$implemented_backend, x, u) where {F}
+        assert_backend_loaded(:vector_jacobian_product, backend)
+        return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
     end
-    return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
 end
 
 @doc doc"""
@@ -49,9 +50,10 @@ products efficiently using mixed-mode AD.
 
 ## Backends & AD Packages
 
-| Supported Backends | Packages Needed |
-| :----------------- | :--------------- |
-| `AutoForwardDiff`  |                  |
+| Supported Backends | Packages Needed | Notes                                          |
+| :----------------- | :-------------- | :--------------------------------------------- |
+| `AutoForwardDiff`  |                 |                                                 |
+| `AutoEnzyme`       | `Enzyme.jl`     | Not compatible with ChainRules-based Nested AD  |
 
 !!! warning
 
@@ -74,8 +76,11 @@ function jacobian_vector_product(::F, backend::AbstractADType, _, __) where {F}
     throw(ArgumentError("`jacobian_vector_product` is not implemented for `$(backend)`."))
 end
 
-function jacobian_vector_product(f::F, backend::AutoForwardDiff, x, u) where {F}
-    return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u)
+for implemented_backend in (:AutoEnzyme, :AutoForwardDiff)
+    @eval function jacobian_vector_product(
+            f::F, backend::$(implemented_backend), x, u) where {F}
+        return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u)
+    end
 end
 
 """
@@ -89,10 +94,11 @@ the following properties for `y = f(x)`:
 
 ## Backends & AD Packages
 
-| Supported Backends | Packages Needed |
-|:------------------ |:--------------- |
-| `AutoForwardDiff`  |                 |
-| `AutoZygote`       | `Zygote.jl`     |
+| Supported Backends | Packages Needed | Notes                                          |
+|:------------------ |:--------------- |:----------------------------------------------- |
+| `AutoForwardDiff`  |                 |                                                  |
+| `AutoZygote`       | `Zygote.jl`     |                                                  |
+| `AutoEnzyme`       | `Enzyme.jl`     | Not compatible with ChainRules-based Nested AD   |
 
 ## Arguments
 
@@ -118,14 +124,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
     throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
 end
 
-function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
-    return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
+for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme)
+    @eval function batched_jacobian(
+            f::F, backend::$(implemented_backend), x::AbstractArray) where {F}
+        assert_backend_loaded(:batched_jacobian, backend)
+        return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
+    end
 end
 
-function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
-    if !is_extension_loaded(Val(:Zygote))
-        error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
-               `$(backend)`.")
+function assert_backend_loaded(fname::Symbol, ad::AbstractADType)
+    return assert_backend_loaded(fname, ad, adtype_to_backend(ad))
+end
+function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B}
+    if !is_extension_loaded(backend)
+        error("$(fname) with `$(ad)` requires $(B).jl to be loaded.")
     end
-    return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
+    return
 end
+
+adtype_to_backend(::AutoEnzyme) = Val(:Enzyme)
+adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff)
+adtype_to_backend(::AutoZygote) = Val(:Zygote)
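For reference, the `@eval` loop stamps out one guarded method per backend; for `AutoEnzyme` it expands to roughly the following (an illustrative expansion, not literal source):

```julia
function batched_jacobian(f::F, backend::AutoEnzyme, x::AbstractArray) where {F}
    # adtype_to_backend(backend) == Val(:Enzyme), and is_extension_loaded(Val(:Enzyme))
    # only returns true once LuxEnzymeExt is loaded, so without Enzyme.jl this raises
    # "batched_jacobian with `AutoEnzyme()` requires Enzyme.jl to be loaded."
    assert_backend_loaded(:batched_jacobian, backend)
    return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end
```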
diff --git a/src/autodiff/jac_products.jl b/src/autodiff/jac_products.jl
index 97785f144..8c5d0544b 100644
--- a/src/autodiff/jac_products.jl
+++ b/src/autodiff/jac_products.jl
@@ -3,6 +3,17 @@ function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F}
     return vector_jacobian_product_impl(f, backend, x, u)
 end
 
+for fType in AD_CONVERTIBLE_FUNCTIONS
+    @eval function vector_jacobian_product(f::$(fType), backend::AbstractADType, x, u)
+        f̂, y = rewrite_autodiff_call(f)
+        return vector_jacobian_product_impl(f̂, backend, x, u, y)
+    end
+end
+
+function vector_jacobian_product_impl(f::F, backend::AbstractADType, x, u, y) where {F}
+    return vector_jacobian_product_impl(Base.Fix2(f, y), backend, x, u)
+end
+
 # JVP Implementation
 function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F}
     return jacobian_vector_product_impl(f, backend, x, u)
diff --git a/src/autodiff/nested_autodiff.jl b/src/autodiff/nested_autodiff.jl
index dfc94ad6f..6467fd1f3 100644
--- a/src/autodiff/nested_autodiff.jl
+++ b/src/autodiff/nested_autodiff.jl
@@ -1,10 +1,10 @@
 ## Written like this to avoid dynamic dispatch from Zygote
 # Input Gradient / Jacobian
 function rewrite_autodiff_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F}
-    (f, f.inner.ps)
+    return f, f.inner.ps
 end
 function rewrite_autodiff_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F}
-    (@closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps)
+    return @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps
 end
 rewrite_autodiff_call(f::StatefulLuxLayer) = f, f.ps
 
@@ -22,10 +22,12 @@ function rewrite_autodiff_call(f::Base.Fix1{<:StatefulLuxLayer})
 end
 
 ## Break ambiguity
-for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
+for op in [
+    ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer},
     ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer},
     ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}},
-    ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}]
+    ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}
+]
     @eval function rewrite_autodiff_call(::$op)
         error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer \
                layers")
diff --git a/src/utils.jl b/src/utils.jl
index 99429f5c1..7ccdcc6cb 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -228,6 +228,8 @@ end
 calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2)))
 calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4
 
+max_enzyme_batched_chunk_size(x::AbstractArray) = min(8, length(x) ÷ Base.size(x, ndims(x)))
+
 end
 
 using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix,
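A quick arithmetic check of the new chunk-size helper (shapes chosen for illustration): it caps the number of simultaneous tangents at 8, or at the per-sample length when that is smaller, and the Reactant patch above pins it to 1 until EnzymeJAX supports batched AD.

```julia
x = randn(Float32, 3, 4)                # 3 features per sample, batch of 4
length(x) ÷ size(x, ndims(x))           # 12 ÷ 4 == 3 Jacobian columns per sample
min(8, 3)                               # chunk size 3: one autodiff call suffices

y = randn(Float32, 32, 16)              # 32 features per sample, batch of 16
min(8, length(y) ÷ size(y, ndims(y)))   # capped at 8, so 32/8 = 4 chunked calls
```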
diff --git a/test/autodiff/batched_autodiff_tests.jl b/test/autodiff/batched_autodiff_tests.jl
index 633725142..326aa155a 100644
--- a/test/autodiff/batched_autodiff_tests.jl
+++ b/test/autodiff/batched_autodiff_tests.jl
@@ -1,16 +1,20 @@
 @testitem "Batched Jacobian" setup=[SharedTestSetup] tags=[:autodiff] begin
-    using ComponentArrays, ForwardDiff, Zygote
+    using ComponentArrays, ForwardDiff, Zygote, ADTypes
 
     rng = StableRNG(12345)
 
     @testset "$mode" for (mode, aType, dev, ongpu) in MODES
         models = (
-            Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()),
-                Conv((3, 3), 4 => 2, gelu; pad=SamePad()), FlattenLayer(), Dense(18 => 2)),
-            Chain(Dense(2, 4, gelu), Dense(4, 2)))
+            Chain(
+                Conv((3, 3), 2 => 4, gelu; pad=SamePad()),
+                Conv((3, 3), 4 => 2, gelu; pad=SamePad()),
+                FlattenLayer(), Dense(18 => 2)
+            ),
+            Chain(Dense(2, 4, gelu), Dense(4, 2))
+        )
         Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4)))
 
-        for (model, X) in zip(models, Xs)
+        for (i, (model, X)) in enumerate(zip(models, Xs))
             ps, st = Lux.setup(rng, model) |> dev
             smodel = StatefulLuxLayer{true}(model, ps, st)
 
@@ -18,7 +22,20 @@
                 ForwardDiff.jacobian(smodel, X)
             end
 
-            @testset "$(backend)" for backend in (AutoZygote(), AutoForwardDiff())
+            @testset for backend in (
+                AutoZygote(), AutoForwardDiff(),
+                AutoEnzyme(;
+                    mode=Enzyme.Forward, function_annotation=Enzyme.Const
+                ),
+                AutoEnzyme(;
+                    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
+                    function_annotation=Enzyme.Const
+                )
+            )
+                # Forward rules for Enzyme are currently not implemented for several ops
+                i == 1 && backend isa AutoEnzyme &&
+                    ADTypes.mode(backend) isa ADTypes.ForwardMode && continue
+
                 J2 = allow_unstable() do
                     batched_jacobian(smodel, backend, X)
                 end
@@ -40,7 +57,14 @@ end
 
     @testset "Issue #636 Chunksize Specialization" begin
-        for N in (2, 4, 8, 11, 12, 50, 51), backend in (AutoZygote(), AutoForwardDiff())
+        for N in (2, 4, 8, 11, 12, 50, 51),
+            backend in (
+                AutoZygote(), AutoForwardDiff(), AutoEnzyme(),
+                AutoEnzyme(; mode=Enzyme.Reverse)
+            )
+
+            ongpu && backend isa AutoEnzyme && continue
+
             model = @compact(; potential=Dense(N => N, gelu), backend=backend) do x
                 @return allow_unstable() do
                     batched_jacobian(potential, backend, x)
@@ -78,6 +102,13 @@
 
         @test Jx_zygote ≈ Jx_true
 
+        if !ongpu
+            Jx_enzyme = allow_unstable() do
+                batched_jacobian(ftest, AutoEnzyme(), x)
+            end
+            @test Jx_enzyme ≈ Jx_true
+        end
+
         fincorrect(x) = x[:, 1]
         x = reshape(Float32.(1:6), 2, 3) |> dev
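The two Enzyme configurations exercised above map onto the extension's helpers as follows (a sketch; the runtime-activity remark reflects my reading of Enzyme's documentation rather than anything this patch asserts):

```julia
ad = AutoEnzyme(; mode=Enzyme.Forward, function_annotation=Enzyme.Const)
# annotate_function(ad, f) returns Enzyme.Const(f): the model closure is treated
# as non-differentiable data, so tangents propagate only through `x`.

ad = AutoEnzyme(;
    mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
    function_annotation=Enzyme.Const)
# set_runtime_activity defers activity analysis to run time, which tolerates
# layers whose memory activity Enzyme cannot prove statically.
```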
diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl
index 6ebe189c4..e328df283 100644
--- a/test/autodiff/nested_autodiff_tests.jl
+++ b/test/autodiff/nested_autodiff_tests.jl
@@ -267,10 +267,8 @@ end
             end
 
             @test_gradients(__f, x,
-                ps;
-                atol=1.0f-3,
-                rtol=1.0f-3,
-                skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
+                ps; atol=1.0f-3,
+                rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
         end
     end
 end
@@ -409,6 +407,5 @@ end
     end
 
     @test_gradients(__f, x, ps; atol=1.0f-3,
-        rtol=1.0f-3,
-        skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
+        rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()])
 end
diff --git a/test/runtests.jl b/test/runtests.jl
index 83c6eaca0..f2213282e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -74,7 +74,7 @@ using Lux
 
     @test_throws ErrorException vector_jacobian_product(
         x -> x, AutoZygote(), rand(2), rand(2))
 
-    @test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2))
+    @test_throws ArgumentError batched_jacobian(x -> x, AutoTracker(), rand(2, 2))
     @test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2))
 end
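The swap in the last hunk follows from the new dispatches: `AutoEnzyme` now has a `batched_jacobian` method, so only a genuinely unimplemented backend such as `AutoTracker` still exercises the generic `ArgumentError` path. A sketch of the two failure modes in a session without Enzyme.jl loaded (hypothetical REPL output):

```julia
using Lux, ADTypes

batched_jacobian(x -> x, AutoTracker(), rand(2, 2))
# ERROR: ArgumentError: `batched_jacobian` is not implemented for `AutoTracker()`.

batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2))
# ERROR: batched_jacobian with `AutoEnzyme()` requires Enzyme.jl to be loaded.
```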