From 1c57aaa00e2c83f0c01ecfd6a398e50cecb1df0e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sun, 22 Sep 2024 20:35:12 -0400 Subject: [PATCH 1/6] feat: add forward mode batched enzyme jacobian --- ext/LuxEnzymeExt/LuxEnzymeExt.jl | 20 ++++++++--- ext/LuxEnzymeExt/batched_autodiff.jl | 50 ++++++++++++++++++++++++++++ src/Lux.jl | 1 + src/autodiff/api.jl | 38 ++++++++++++--------- 4 files changed, 90 insertions(+), 19 deletions(-) create mode 100644 ext/LuxEnzymeExt/batched_autodiff.jl diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 1c6b5cc88..512c31740 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -1,16 +1,28 @@ module LuxEnzymeExt -using ADTypes: AutoEnzyme -using Enzyme: Enzyme, Active, Const, Duplicated +using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode +using ArgCheck: @argcheck +using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated using EnzymeCore: EnzymeCore using Functors: fmap -using Setfield: @set! -using Static: False, True +using Setfield: @set!, @set +using Static: False, True, StaticBool using Lux: Lux, Utils using Lux.Training: TrainingBackendCache, TrainState using MLDataDevices: isleaf +Lux.is_extension_loaded(::Val{:Enzyme}) = true + +normalize_backend(::StaticBool, ad::AutoEnzyme) = ad +normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward) +normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse) + +annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f +annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) + include("training.jl") +include("batched_autodiff.jl") + end diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl new file mode 100644 index 000000000..5d1d1a0c0 --- /dev/null +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -0,0 +1,50 @@ +function Lux.AutoDiffInternalImpl.batched_jacobian_impl( + f::F, ad::AutoEnzyme, x::AbstractArray) where {F} + backend = normalize_backend(True(), ad) + return batched_enzyme_jacobian_impl( + annotate_function(ad, f), backend, ADTypes.mode(backend), x) +end + +function batched_enzyme_jacobian_impl( + f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F} + # We need to run the function once to get the output type. Can we use ForwardWithPrimal? + y = f(x) + + @argcheck y isa AbstractArray MethodError + if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) + throw(AssertionError("`batched_jacobian` only supports batched outputs \ + (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x)).")) + end + B = size(y, ndims(y)) + + J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), + prod(size(x)[1:(end - 1)]), B) + + chunk_size = min(8, length(y) ÷ B) + partials = ntuple(_ -> zero(x), chunk_size) + + for i in 1:chunk_size:(length(x) ÷ B) + idxs = i:min(i + chunk_size - 1, length(x) ÷ B) + partials′ = make_onehot!(partials, idxs) + J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′))) + for (idx, J_partial) in zip(idxs, J_partials) + copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) + end + end + + return J +end + +function batched_enzyme_jacobian_impl( + f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F} + error("reverse mode is not supported yet") +end + +function make_onehot!(partials, idxs) + for (idx, partial) in zip(idxs, partials) + partial′ = reshape(partial, :, size(partial, ndims(partial))) + fill!(partial′, false) + fill!(view(partial′, idx, :), true) + end + return partials[1:length(idxs)] +end diff --git a/src/Lux.jl b/src/Lux.jl index ceda3df25..942367902 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -49,6 +49,7 @@ function __init__() end is_extension_loaded(::Val) = false +is_extension_loaded(::Val{:ForwardDiff}) = true # Preferences include("preferences.jl") diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 6db1e38da..1077b6710 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -33,10 +33,7 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F} end function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F} - if !is_extension_loaded(Val(:Zygote)) - error("`Zygote.jl` must be loaded for `vector_jacobian_product` \ - to work with `$(backend)`.") - end + assert_backend_loaded(:vector_jacobian_product, backend) return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) end @@ -89,10 +86,11 @@ the following properties for `y = f(x)`: ## Backends & AD Packages -| Supported Backends | Packages Needed | -|:------------------ |:--------------- | -| `AutoForwardDiff` | | -| `AutoZygote` | `Zygote.jl` | +| Supported Backends | Packages Needed | Note | +|:------------------ |:--------------- |:---------------------------------------------- | +| `AutoForwardDiff` | | | +| `AutoZygote` | `Zygote.jl` | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | ## Arguments @@ -118,14 +116,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`.")) end -function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F} - return AutoDiffInternalImpl.batched_jacobian(f, backend, x) +for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme) + @eval function batched_jacobian( + f::F, backend::$(implemented_backend), x::AbstractArray) where {F} + assert_backend_loaded(:batched_jacobian, backend) + return AutoDiffInternalImpl.batched_jacobian(f, backend, x) + end end -function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F} - if !is_extension_loaded(Val(:Zygote)) - error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \ - `$(backend)`.") +function assert_backend_loaded(fname::Symbol, ad::AbstractADType) + return assert_backend_loaded(fname, ad, adtype_to_backend(ad)) +end +function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B} + if !is_extension_loaded(backend) + error("$(fname) with `$(ad)` requires $(B).jl to be loaded.") end - return AutoDiffInternalImpl.batched_jacobian(f, backend, x) + return end + +adtype_to_backend(::AutoEnzyme) = Val(:Enzyme) +adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff) +adtype_to_backend(::AutoZygote) = Val(:Zygote) From 5985015e71d34ceaef28afde5b08c4906590c0f0 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Nov 2024 00:00:29 -0500 Subject: [PATCH 2/6] feat: add reverse mode batched enzyme jacobian --- ext/LuxEnzymeExt/LuxEnzymeExt.jl | 5 ++- ext/LuxEnzymeExt/batched_autodiff.jl | 58 +++++++++++++++++++++++++--- src/autodiff/api.jl | 2 +- test/runtests.jl | 2 +- 4 files changed, 57 insertions(+), 10 deletions(-) diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 512c31740..01ffe068c 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -2,6 +2,7 @@ module LuxEnzymeExt using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode using ArgCheck: @argcheck +using ConcreteStructs: @concrete using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated using EnzymeCore: EnzymeCore using Functors: fmap @@ -15,8 +16,8 @@ using MLDataDevices: isleaf Lux.is_extension_loaded(::Val{:Enzyme}) = true normalize_backend(::StaticBool, ad::AutoEnzyme) = ad -normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward) -normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse) +normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward) +normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse) annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index 5d1d1a0c0..b116b396b 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -1,14 +1,14 @@ function Lux.AutoDiffInternalImpl.batched_jacobian_impl( f::F, ad::AutoEnzyme, x::AbstractArray) where {F} backend = normalize_backend(True(), ad) - return batched_enzyme_jacobian_impl( - annotate_function(ad, f), backend, ADTypes.mode(backend), x) + return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x) end function batched_enzyme_jacobian_impl( - f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F} + f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G} # We need to run the function once to get the output type. Can we use ForwardWithPrimal? - y = f(x) + y = f_orig(x) + f = annotate_function(ad, f_orig) @argcheck y isa AbstractArray MethodError if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) @@ -36,8 +36,38 @@ function batched_enzyme_jacobian_impl( end function batched_enzyme_jacobian_impl( - f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F} - error("reverse mode is not supported yet") + f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G} + # We need to run the function once to get the output type. Can we use ReverseWithPrimal? + y = f_orig(x) + + @argcheck y isa AbstractArray MethodError + if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) + throw(AssertionError("`batched_jacobian` only supports batched outputs \ + (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x)).")) + end + B = size(y, ndims(y)) + + J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), + prod(size(x)[1:(end - 1)]), B) + + chunk_size = min(8, length(x) ÷ B) + partials = ntuple(_ -> zero(y), chunk_size) + J_partials = ntuple(_ -> zero(x), chunk_size) + + fn = annotate_function(ad, OOPFunctionWrapper(f_orig)) + for i in 1:chunk_size:(length(y) ÷ B) + idxs = i:min(i + chunk_size - 1, length(y) ÷ B) + partials′ = make_onehot!(partials, idxs) + J_partials′ = make_zero!(J_partials, idxs) + Enzyme.autodiff( + ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′) + ) + for (idx, J_partial) in zip(idxs, J_partials) + copyto!(view(J, idx, :, :), reshape(J_partial, :, B)) + end + end + + return J end function make_onehot!(partials, idxs) @@ -48,3 +78,19 @@ function make_onehot!(partials, idxs) end return partials[1:length(idxs)] end + +function make_zero!(partials, idxs) + for partial in partials + fill!(partial, false) + end + return partials[1:length(idxs)] +end + +@concrete struct OOPFunctionWrapper + f +end + +function (f::OOPFunctionWrapper)(y, x) + copyto!(y, f.f(x)) + return +end diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 1077b6710..95625f853 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -86,7 +86,7 @@ the following properties for `y = f(x)`: ## Backends & AD Packages -| Supported Backends | Packages Needed | Note | +| Supported Backends | Packages Needed | Notes | |:------------------ |:--------------- |:---------------------------------------------- | | `AutoForwardDiff` | | | | `AutoZygote` | `Zygote.jl` | | diff --git a/test/runtests.jl b/test/runtests.jl index 83c6eaca0..f2213282e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -74,7 +74,7 @@ using Lux @test_throws ErrorException vector_jacobian_product( x -> x, AutoZygote(), rand(2), rand(2)) - @test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2)) + @test_throws ArgumentError batched_jacobian(x -> x, AutoTracker(), rand(2, 2)) @test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2)) end From 0c1770e40dfdaa15f07f2fb63ac0767e0f59b5f3 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Nov 2024 14:09:20 -0500 Subject: [PATCH 3/6] feat: add vjp and jvp for Enzyme --- ext/LuxEnzymeExt/LuxEnzymeExt.jl | 10 ++++++++ ext/LuxEnzymeExt/autodiff.jl | 37 ++++++++++++++++++++++++++++ ext/LuxEnzymeExt/batched_autodiff.jl | 9 ------- src/autodiff/api.jl | 30 +++++++++++++--------- src/autodiff/jac_products.jl | 11 +++++++++ 5 files changed, 77 insertions(+), 20 deletions(-) create mode 100644 ext/LuxEnzymeExt/autodiff.jl diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 01ffe068c..71969be97 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -24,6 +24,16 @@ annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) include("training.jl") +include("autodiff.jl") include("batched_autodiff.jl") +@concrete struct OOPFunctionWrapper + f +end + +function (f::OOPFunctionWrapper)(y, args...) + copyto!(y, f.f(args...)) + return +end + end diff --git a/ext/LuxEnzymeExt/autodiff.jl b/ext/LuxEnzymeExt/autodiff.jl new file mode 100644 index 000000000..c3ff2413c --- /dev/null +++ b/ext/LuxEnzymeExt/autodiff.jl @@ -0,0 +1,37 @@ +function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl( + f::F, ad::AutoEnzyme, x, u, p) where {F} + ad = normalize_backend(True(), ad) + @assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode." + return only( + Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u), Const(p)) + ) +end + +function Lux.AutoDiffInternalImpl.jacobian_vector_product_impl( + f::F, ad::AutoEnzyme, x, u) where {F} + ad = normalize_backend(True(), ad) + @assert ADTypes.mode(ad) isa ForwardMode "JVPs are only supported in forward mode." + return only(Enzyme.autodiff(ad.mode, annotate_function(ad, f), Duplicated(x, u))) +end + +function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl( + f::F, ad::AutoEnzyme, x, v, p) where {F} + ad = normalize_backend(False(), ad) + @assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode." + dx = zero(x) + # XXX: without the copy it overwrites the `v` with zeros + Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)), + Duplicated(similar(v), copy(v)), Duplicated(x, dx), Const(p)) + return dx +end + +function Lux.AutoDiffInternalImpl.vector_jacobian_product_impl( + f::F, ad::AutoEnzyme, x, v) where {F} + ad = normalize_backend(False(), ad) + @assert ADTypes.mode(ad) isa ReverseMode "VJPs are only supported in reverse mode." + dx = zero(x) + # XXX: without the copy it overwrites the `v` with zeros + Enzyme.autodiff(ad.mode, annotate_function(ad, OOPFunctionWrapper(f)), + Duplicated(similar(v), copy(v)), Duplicated(x, dx)) + return dx +end diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index b116b396b..32cf8c2c3 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -85,12 +85,3 @@ function make_zero!(partials, idxs) end return partials[1:length(idxs)] end - -@concrete struct OOPFunctionWrapper - f -end - -function (f::OOPFunctionWrapper)(y, x) - copyto!(y, f.f(x)) - return -end diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 95625f853..3bc1a907f 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -7,9 +7,10 @@ products efficiently using mixed-mode AD. ## Backends & AD Packages -| Supported Backends | Packages Needed | -| :----------------- | :-------------- | -| `AutoZygote` | `Zygote.jl` | +| Supported Backends | Packages Needed | Notes | +| :----------------- | :-------------- | :--------------------------------------------- | +| `AutoZygote` | `Zygote.jl` | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | !!! warning @@ -32,9 +33,12 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F} throw(ArgumentError("`vector_jacobian_product` is not implemented for `$(backend)`.")) end -function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F} - assert_backend_loaded(:vector_jacobian_product, backend) - return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) +for implemented_backend in (:AutoZygote, :AutoEnzyme) + @eval function vector_jacobian_product( + f::F, backend::$implemented_backend, x, u) where {F} + assert_backend_loaded(:vector_jacobian_product, backend) + return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) + end end @doc doc""" @@ -46,9 +50,10 @@ products efficiently using mixed-mode AD. ## Backends & AD Packages -| Supported Backends | Packages Needed | -| :----------------- | :--------------- | -| `AutoForwardDiff` | | +| Supported Backends | Packages Needed | Notes | +| :----------------- | :-------------- | :--------------------------------------------- | +| `AutoForwardDiff` | | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | !!! warning @@ -71,8 +76,11 @@ function jacobian_vector_product(::F, backend::AbstractADType, _, __) where {F} throw(ArgumentError("`jacobian_vector_product` is not implemented for `$(backend)`.")) end -function jacobian_vector_product(f::F, backend::AutoForwardDiff, x, u) where {F} - return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u) +for implemented_backend in (:AutoEnzyme, :AutoForwardDiff) + @eval function jacobian_vector_product( + f::F, backend::$(implemented_backend), x, u) where {F} + return AutoDiffInternalImpl.jacobian_vector_product(f, backend, x, u) + end end """ diff --git a/src/autodiff/jac_products.jl b/src/autodiff/jac_products.jl index 97785f144..8c5d0544b 100644 --- a/src/autodiff/jac_products.jl +++ b/src/autodiff/jac_products.jl @@ -3,6 +3,17 @@ function vector_jacobian_product(f::F, backend::AbstractADType, x, u) where {F} return vector_jacobian_product_impl(f, backend, x, u) end +for fType in AD_CONVERTIBLE_FUNCTIONS + @eval function vector_jacobian_product(f::$(fType), backend::AbstractADType, x, u) + f̂, y = rewrite_autodiff_call(f) + return vector_jacobian_product_impl(f̂, backend, x, u, y) + end +end + +function vector_jacobian_product_impl(f::F, backend::AbstractADType, x, u, y) where {F} + return vector_jacobian_product_impl(Base.Fix2(f, y), backend, x, u) +end + # JVP Implementation function jacobian_vector_product(f::F, backend::AbstractADType, x, u) where {F} return jacobian_vector_product_impl(f, backend, x, u) From c15ccb24110dde4f5b72f173f7910587991d2c82 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Nov 2024 14:28:05 -0500 Subject: [PATCH 4/6] fix: avoid closures in batched_jacobian --- ext/LuxEnzymeExt/batched_autodiff.jl | 16 +++++++++------- src/autodiff/nested_autodiff.jl | 10 ++++++---- 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index 32cf8c2c3..e73dcad2a 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -1,11 +1,11 @@ -function Lux.AutoDiffInternalImpl.batched_jacobian_impl( - f::F, ad::AutoEnzyme, x::AbstractArray) where {F} +function Lux.AutoDiffInternalImpl.batched_jacobian_internal( + f::F, ad::AutoEnzyme, x::AbstractArray, args...) where {F} backend = normalize_backend(True(), ad) - return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x) + return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x, args...) end function batched_enzyme_jacobian_impl( - f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G} + f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray, args...) where {G} # We need to run the function once to get the output type. Can we use ForwardWithPrimal? y = f_orig(x) f = annotate_function(ad, f_orig) @@ -26,7 +26,8 @@ function batched_enzyme_jacobian_impl( for i in 1:chunk_size:(length(x) ÷ B) idxs = i:min(i + chunk_size - 1, length(x) ÷ B) partials′ = make_onehot!(partials, idxs) - J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′))) + J_partials = only(Enzyme.autodiff( + ad.mode, f, BatchDuplicated(x, partials′), Const.(args)...)) for (idx, J_partial) in zip(idxs, J_partials) copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) end @@ -36,7 +37,7 @@ function batched_enzyme_jacobian_impl( end function batched_enzyme_jacobian_impl( - f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G} + f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray, args...) where {G} # We need to run the function once to get the output type. Can we use ReverseWithPrimal? y = f_orig(x) @@ -60,7 +61,8 @@ function batched_enzyme_jacobian_impl( partials′ = make_onehot!(partials, idxs) J_partials′ = make_zero!(J_partials, idxs) Enzyme.autodiff( - ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′) + ad.mode, fn, BatchDuplicated(y, partials′), + BatchDuplicated(x, J_partials′), Const.(args)... ) for (idx, J_partial) in zip(idxs, J_partials) copyto!(view(J, idx, :, :), reshape(J_partial, :, B)) diff --git a/src/autodiff/nested_autodiff.jl b/src/autodiff/nested_autodiff.jl index dfc94ad6f..6467fd1f3 100644 --- a/src/autodiff/nested_autodiff.jl +++ b/src/autodiff/nested_autodiff.jl @@ -1,10 +1,10 @@ ## Written like this to avoid dynamic dispatch from Zygote # Input Gradient / Jacobian function rewrite_autodiff_call(f::ComposedFunction{F, <:StatefulLuxLayer}) where {F} - (f, f.inner.ps) + return f, f.inner.ps end function rewrite_autodiff_call(f::ComposedFunction{<:StatefulLuxLayer, F}) where {F} - (@closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps) + return @closure((x, ps)->f.outer(f.inner(x), ps)), f.outer.ps end rewrite_autodiff_call(f::StatefulLuxLayer) = f, f.ps @@ -22,10 +22,12 @@ function rewrite_autodiff_call(f::Base.Fix1{<:StatefulLuxLayer}) end ## Break ambiguity -for op in [ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer}, +for op in [ + ComposedFunction{<:StatefulLuxLayer, <:StatefulLuxLayer}, ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:StatefulLuxLayer}, ComposedFunction{<:StatefulLuxLayer, <:Base.Fix1{<:StatefulLuxLayer}}, - ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}}] + ComposedFunction{<:Base.Fix1{<:StatefulLuxLayer}, <:Base.Fix1{<:StatefulLuxLayer}} +] @eval function rewrite_autodiff_call(::$op) error("Cannot rewrite ComposedFunction with StatefulLuxLayer as inner and outer \ layers") From d045b233bf83a28e75e14469b9b4ff6138908ff1 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Mon, 25 Nov 2024 15:35:26 -0500 Subject: [PATCH 5/6] test: add batched jacobian tests for enzyme --- test/autodiff/batched_autodiff_tests.jl | 45 +++++++++++++++++++++---- test/autodiff/nested_autodiff_tests.jl | 9 ++--- 2 files changed, 41 insertions(+), 13 deletions(-) diff --git a/test/autodiff/batched_autodiff_tests.jl b/test/autodiff/batched_autodiff_tests.jl index 633725142..326aa155a 100644 --- a/test/autodiff/batched_autodiff_tests.jl +++ b/test/autodiff/batched_autodiff_tests.jl @@ -1,16 +1,20 @@ @testitem "Batched Jacobian" setup=[SharedTestSetup] tags=[:autodiff] begin - using ComponentArrays, ForwardDiff, Zygote + using ComponentArrays, ForwardDiff, Zygote, ADTypes rng = StableRNG(12345) @testset "$mode" for (mode, aType, dev, ongpu) in MODES models = ( - Chain(Conv((3, 3), 2 => 4, gelu; pad=SamePad()), - Conv((3, 3), 4 => 2, gelu; pad=SamePad()), FlattenLayer(), Dense(18 => 2)), - Chain(Dense(2, 4, gelu), Dense(4, 2))) + Chain( + Conv((3, 3), 2 => 4, gelu; pad=SamePad()), + Conv((3, 3), 4 => 2, gelu; pad=SamePad()), + FlattenLayer(), Dense(18 => 2) + ), + Chain(Dense(2, 4, gelu), Dense(4, 2)) + ) Xs = (aType(randn(rng, Float32, 3, 3, 2, 4)), aType(randn(rng, Float32, 2, 4))) - for (model, X) in zip(models, Xs) + for (i, (model, X)) in enumerate(zip(models, Xs)) ps, st = Lux.setup(rng, model) |> dev smodel = StatefulLuxLayer{true}(model, ps, st) @@ -18,7 +22,20 @@ ForwardDiff.jacobian(smodel, X) end - @testset "$(backend)" for backend in (AutoZygote(), AutoForwardDiff()) + @testset for backend in ( + AutoZygote(), AutoForwardDiff(), + AutoEnzyme(; + mode=Enzyme.Forward, function_annotation=Enzyme.Const + ), + AutoEnzyme(; + mode=Enzyme.set_runtime_activity(Enzyme.Reverse), + function_annotation=Enzyme.Const + ) + ) + # Forward rules for Enzyme is currently not implemented for several Ops + i == 1 && backend isa AutoEnzyme && + ADTypes.mode(backend) isa ADTypes.ForwardMode && continue + J2 = allow_unstable() do batched_jacobian(smodel, backend, X) end @@ -40,7 +57,14 @@ end @testset "Issue #636 Chunksize Specialization" begin - for N in (2, 4, 8, 11, 12, 50, 51), backend in (AutoZygote(), AutoForwardDiff()) + for N in (2, 4, 8, 11, 12, 50, 51), + backend in ( + AutoZygote(), AutoForwardDiff(), AutoEnzyme(), + AutoEnzyme(; mode=Enzyme.Reverse) + ) + + ongpu && backend isa AutoEnzyme && continue + model = @compact(; potential=Dense(N => N, gelu), backend=backend) do x @return allow_unstable() do batched_jacobian(potential, backend, x) @@ -78,6 +102,13 @@ end @test Jx_zygote ≈ Jx_true + if !ongpu + Jx_enzyme = allow_unstable() do + batched_jacobian(ftest, AutoEnzyme(), x) + end + @test Jx_enzyme ≈ Jx_true + end + fincorrect(x) = x[:, 1] x = reshape(Float32.(1:6), 2, 3) |> dev diff --git a/test/autodiff/nested_autodiff_tests.jl b/test/autodiff/nested_autodiff_tests.jl index 6ebe189c4..e328df283 100644 --- a/test/autodiff/nested_autodiff_tests.jl +++ b/test/autodiff/nested_autodiff_tests.jl @@ -267,10 +267,8 @@ end end @test_gradients(__f, x, - ps; - atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + ps; atol=1.0f-3, + rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end end end @@ -409,6 +407,5 @@ end end @test_gradients(__f, x, ps; atol=1.0f-3, - rtol=1.0f-3, - skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) + rtol=1.0f-3, skip_backends=[AutoReverseDiff(), AutoTracker(), AutoEnzyme()]) end From 3ed672d8cfa19e234f3dd14e0447107e6b881b46 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 11 Jan 2025 18:56:25 -0500 Subject: [PATCH 6/6] feat: initial support for reactant --- ext/LuxEnzymeExt/batched_autodiff.jl | 24 ++++++++++++++---------- ext/LuxReactantExt/patches.jl | 3 +++ src/utils.jl | 2 ++ 3 files changed, 19 insertions(+), 10 deletions(-) diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl index e73dcad2a..7000d35ee 100644 --- a/ext/LuxEnzymeExt/batched_autodiff.jl +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -20,16 +20,17 @@ function batched_enzyme_jacobian_impl( J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), prod(size(x)[1:(end - 1)]), B) - chunk_size = min(8, length(y) ÷ B) + chunk_size = Utils.max_enzyme_batched_chunk_size(y) partials = ntuple(_ -> zero(x), chunk_size) for i in 1:chunk_size:(length(x) ÷ B) idxs = i:min(i + chunk_size - 1, length(x) ÷ B) partials′ = make_onehot!(partials, idxs) J_partials = only(Enzyme.autodiff( - ad.mode, f, BatchDuplicated(x, partials′), Const.(args)...)) + ad.mode, f, make_batch_duplicated(x, partials′), Const.(args)... + )) for (idx, J_partial) in zip(idxs, J_partials) - copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) + J[:, :, idx] .= reshape(J_partial, :, B) end end @@ -51,7 +52,7 @@ function batched_enzyme_jacobian_impl( J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), prod(size(x)[1:(end - 1)]), B) - chunk_size = min(8, length(x) ÷ B) + chunk_size = Utils.max_enzyme_batched_chunk_size(y) partials = ntuple(_ -> zero(y), chunk_size) J_partials = ntuple(_ -> zero(x), chunk_size) @@ -61,11 +62,11 @@ function batched_enzyme_jacobian_impl( partials′ = make_onehot!(partials, idxs) J_partials′ = make_zero!(J_partials, idxs) Enzyme.autodiff( - ad.mode, fn, BatchDuplicated(y, partials′), - BatchDuplicated(x, J_partials′), Const.(args)... + ad.mode, fn, make_batch_duplicated(y, partials′), + make_batch_duplicated(x, J_partials′), Const.(args)... ) for (idx, J_partial) in zip(idxs, J_partials) - copyto!(view(J, idx, :, :), reshape(J_partial, :, B)) + J[idx, :, :] .= reshape(J_partial, :, B) end end @@ -74,16 +75,19 @@ end function make_onehot!(partials, idxs) for (idx, partial) in zip(idxs, partials) + partial .= false partial′ = reshape(partial, :, size(partial, ndims(partial))) - fill!(partial′, false) - fill!(view(partial′, idx, :), true) + partial′[idx, :] .= true end return partials[1:length(idxs)] end function make_zero!(partials, idxs) for partial in partials - fill!(partial, false) + partial .= false end return partials[1:length(idxs)] end + +make_batch_duplicated(x, dxs) = BatchDuplicated(x, dxs) +make_batch_duplicated(x, dx::Tuple{X}) where {X} = Duplicated(x, only(dx)) diff --git a/ext/LuxReactantExt/patches.jl b/ext/LuxReactantExt/patches.jl index 6d79f2b60..8f6d4f511 100644 --- a/ext/LuxReactantExt/patches.jl +++ b/ext/LuxReactantExt/patches.jl @@ -1,5 +1,8 @@ Utils.vec(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_array(vec(x)) +# XXX: remove once EnzymeJAX supports batched AD +Utils.max_enzyme_batched_chunk_size(x::AnyTracedRArray) = 1 + # XXX: Use PoolDims once EnzymeJAX supports stablehlo.reduce_window adjoint Lux.calculate_pool_dims(g::Lux.GlobalPoolMode, ::TracedRArray) = g diff --git a/src/utils.jl b/src/utils.jl index 99429f5c1..7ccdcc6cb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -228,6 +228,8 @@ end calculate_gain(::typeof(NNlib.leakyrelu), x) = typeof(x)(√(2 / (1 + x^2))) calculate_gain(::typeof(NNlib.selu), _) = 3.0f0 / 4 +max_enzyme_batched_chunk_size(x::AbstractArray) = min(8, length(x) ÷ Base.size(x, ndims(x))) + end using .Utils: Utils, BoolType, IntegerType, SymbolType, make_abstract_matrix,