Skip to content

Commit

Permalink
feat: add reverse mode batched enzyme jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Nov 25, 2024
1 parent 000b691 commit 91491c9
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 10 deletions.
5 changes: 3 additions & 2 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ module LuxEnzymeExt

using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
using ArgCheck: @argcheck
using ConcreteStructs: @concrete
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
using EnzymeCore: EnzymeCore
using Functors: fmap
Expand All @@ -15,8 +16,8 @@ using MLDataDevices: isleaf
Lux.is_extension_loaded(::Val{:Enzyme}) = true

normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Forward)
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode = Enzyme.Reverse)
normalize_backend(::True, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Forward)
normalize_backend(::False, ad::AutoEnzyme{Nothing}) = @set(ad.mode=Enzyme.Reverse)

annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)
Expand Down
58 changes: 52 additions & 6 deletions ext/LuxEnzymeExt/batched_autodiff.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
backend = normalize_backend(True(), ad)
return batched_enzyme_jacobian_impl(
annotate_function(ad, f), backend, ADTypes.mode(backend), x)
return batched_enzyme_jacobian_impl(f, backend, ADTypes.mode(backend), x)
end

function batched_enzyme_jacobian_impl(
f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F}
f_orig::G, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {G}
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
y = f(x)
y = f_orig(x)
f = annotate_function(ad, f_orig)

@argcheck y isa AbstractArray MethodError
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
Expand Down Expand Up @@ -36,8 +36,38 @@ function batched_enzyme_jacobian_impl(
end

function batched_enzyme_jacobian_impl(
f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F}
error("reverse mode is not supported yet")
f_orig::G, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {G}
# We need to run the function once to get the output type. Can we use ReverseWithPrimal?
y = f_orig(x)

@argcheck y isa AbstractArray MethodError
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
throw(AssertionError("`batched_jacobian` only supports batched outputs \
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
end
B = size(y, ndims(y))

J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
prod(size(x)[1:(end - 1)]), B)

chunk_size = min(8, length(x) ÷ B)
partials = ntuple(_ -> zero(y), chunk_size)
J_partials = ntuple(_ -> zero(x), chunk_size)

fn = annotate_function(ad, OOPFunctionWrapper(f_orig))
for i in 1:chunk_size:(length(y) ÷ B)
idxs = i:min(i + chunk_size - 1, length(y) ÷ B)
partials′ = make_onehot!(partials, idxs)
J_partials′ = make_zero!(J_partials, idxs)
Enzyme.autodiff(
ad.mode, fn, BatchDuplicated(y, partials′), BatchDuplicated(x, J_partials′)
)
for (idx, J_partial) in zip(idxs, J_partials)
copyto!(view(J, idx, :, :), reshape(J_partial, :, B))
end
end

return J
end

function make_onehot!(partials, idxs)
Expand All @@ -48,3 +78,19 @@ function make_onehot!(partials, idxs)
end
return partials[1:length(idxs)]
end

function make_zero!(partials, idxs)
for partial in partials
fill!(partial, false)
end
return partials[1:length(idxs)]
end

@concrete struct OOPFunctionWrapper
f
end

function (f::OOPFunctionWrapper)(y, x)
copyto!(y, f.f(x))
return
end
2 changes: 1 addition & 1 deletion src/autodiff/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ the following properties for `y = f(x)`:
## Backends & AD Packages
| Supported Backends | Packages Needed | Note |
| Supported Backends | Packages Needed | Notes |
|:------------------ |:--------------- |:---------------------------------------------- |
| `AutoForwardDiff` | | |
| `AutoZygote` | `Zygote.jl` | |
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ using Lux
@test_throws ErrorException vector_jacobian_product(
x -> x, AutoZygote(), rand(2), rand(2))

@test_throws ArgumentError batched_jacobian(x -> x, AutoEnzyme(), rand(2, 2))
@test_throws ArgumentError batched_jacobian(x -> x, AutoTracker(), rand(2, 2))
@test_throws ErrorException batched_jacobian(x -> x, AutoZygote(), rand(2, 2))
end

Expand Down

0 comments on commit 91491c9

Please sign in to comment.