Skip to content

Commit

Permalink
feat: add forward mode batched enzyme jacobian
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 24, 2024
1 parent e5a6e7d commit 746d416
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 18 deletions.
22 changes: 19 additions & 3 deletions ext/LuxEnzymeExt/LuxEnzymeExt.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,30 @@
module LuxEnzymeExt

using ADTypes: AutoEnzyme
using Enzyme: Enzyme, Active, Const, Duplicated
using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode
using ArgCheck: @argcheck
using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated
using EnzymeCore: EnzymeCore
using Setfield: @set!
using Static: False, True
using Static: False, True, StaticBool

using Lux: Lux
using Lux.Training: TrainingBackendCache, TrainState

Lux.is_extension_loaded(::Val{:Enzyme}) = true

normalize_backend(::StaticBool, ad::AutoEnzyme) = ad
function normalize_backend(#=prefer_forward=#::True, ad::AutoEnzyme{Nothing, A}) where {A}
return AutoEnzyme(; mode=Enzyme.Forward, function_annotation=A)
end
function normalize_backend(#=prefer_forward=#::False, ad::AutoEnzyme{Nothing, A}) where {A}
return AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=A)
end

annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f
annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f)

include("training.jl")

include("batched_autodiff.jl")

end
50 changes: 50 additions & 0 deletions ext/LuxEnzymeExt/batched_autodiff.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
function Lux.AutoDiffInternalImpl.batched_jacobian_impl(
f::F, ad::AutoEnzyme, x::AbstractArray) where {F}
backend = normalize_backend(True(), ad)
return batched_enzyme_jacobian_impl(
annotate_function(ad, f), backend, ADTypes.mode(backend), x)
end

function batched_enzyme_jacobian_impl(
f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F}
# We need to run the function once to get the output type. Can we use ForwardWithPrimal?
y = f(x)

@argcheck y isa AbstractArray MethodError
if ndims(y) 1 || size(y, ndims(y)) != size(x, ndims(x))
throw(AssertionError("`batched_jacobian` only supports batched outputs \
(ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x))."))
end
B = size(y, ndims(y))

J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]),
prod(size(x)[1:(end - 1)]), B)

chunk_size = min(8, length(y) ÷ B)
partials = ntuple(_ -> zero(x), chunk_size)

for i in 1:chunk_size:(length(x) ÷ B)
idxs = i:min(i + chunk_size - 1, length(x) ÷ B)
partials′ = make_onehot!(partials, idxs)
J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′)))
for (idx, J_partial) in zip(idxs, J_partials)
copyto!(view(J, :, idx, :), reshape(J_partial, :, B))
end
end

return J
end

function batched_enzyme_jacobian_impl(
f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F}
error("reverse mode is not supported yet")
end

function make_onehot!(partials, idxs)
for (idx, partial) in zip(idxs, partials)
partial′ = reshape(partial, :, size(partial, ndims(partial)))
fill!(partial′, false)
partial′[idx, :] .= true
end
return partials[1:length(idxs)]
end
1 change: 1 addition & 0 deletions src/Lux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ const NAME_TYPE = Union{Nothing, String, Symbol}
const Optional{T} = Union{T, Nothing}

is_extension_loaded(::Val) = false
is_extension_loaded(::Val{:ForwardDiff}) = true

# Preferences
include("preferences.jl")
Expand Down
38 changes: 23 additions & 15 deletions src/autodiff/api.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,7 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F}
end

function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F}
if !is_extension_loaded(Val(:Zygote))
error("`Zygote.jl` must be loaded for `vector_jacobian_product` \
to work with `$(backend)`.")
end
assert_backend_loaded(:vector_jacobian_product, backend)
return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u)
end

Expand Down Expand Up @@ -89,10 +86,11 @@ the following properties for `y = f(x)`:
## Backends & AD Packages
| Supported Backends | Packages Needed |
|:------------------ |:--------------- |
| `AutoForwardDiff` | |
| `AutoZygote` | `Zygote.jl` |
| Supported Backends | Packages Needed | Note |
|:------------------ |:--------------- |:---------------------------------------------- |
| `AutoForwardDiff` | | |
| `AutoZygote` | `Zygote.jl` | |
| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD |
## Arguments
Expand All @@ -118,14 +116,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where
throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`."))
end

function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F}
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme)
@eval function batched_jacobian(
f::F, backend::$(implemented_backend), x::AbstractArray) where {F}
assert_backend_loaded(:batched_jacobian, backend)
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
end
end

function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F}
if !is_extension_loaded(Val(:Zygote))
error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \
`$(backend)`.")
function assert_backend_loaded(fname::Symbol, ad::AbstractADType)
return assert_backend_loaded(fname, ad, adtype_to_backend(ad))
end
function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B}
if !is_extension_loaded(backend)
error("$(fname) with `$(ad)` requires $(B).jl to be loaded.")
end
return AutoDiffInternalImpl.batched_jacobian(f, backend, x)
return
end

adtype_to_backend(::AutoEnzyme) = Val(:Enzyme)
adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff)
adtype_to_backend(::AutoZygote) = Val(:Zygote)

0 comments on commit 746d416

Please sign in to comment.