diff --git a/Project.toml b/Project.toml index ab444f7179..fcaa5ed7c0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lux" uuid = "b2108857-7c20-44ae-9111-449ecde12c47" authors = ["Avik Pal and contributors"] -version = "1.0.6" +version = "1.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/LuxEnzymeExt/LuxEnzymeExt.jl b/ext/LuxEnzymeExt/LuxEnzymeExt.jl index 0174f39723..79f396bdfb 100644 --- a/ext/LuxEnzymeExt/LuxEnzymeExt.jl +++ b/ext/LuxEnzymeExt/LuxEnzymeExt.jl @@ -1,14 +1,30 @@ module LuxEnzymeExt -using ADTypes: AutoEnzyme -using Enzyme: Enzyme, Active, Const, Duplicated +using ADTypes: ADTypes, AutoEnzyme, ForwardMode, ReverseMode +using ArgCheck: @argcheck +using Enzyme: Enzyme, Active, Const, Duplicated, BatchDuplicated using EnzymeCore: EnzymeCore using Setfield: @set! -using Static: False, True +using Static: False, True, StaticBool using Lux: Lux using Lux.Training: TrainingBackendCache, TrainState +Lux.is_extension_loaded(::Val{:Enzyme}) = true + +normalize_backend(::StaticBool, ad::AutoEnzyme) = ad +function normalize_backend(#=prefer_forward=#::True, ad::AutoEnzyme{Nothing, A}) where {A} + return AutoEnzyme(; mode=Enzyme.Forward, function_annotation=A) +end +function normalize_backend(#=prefer_forward=#::False, ad::AutoEnzyme{Nothing, A}) where {A} + return AutoEnzyme(; mode=Enzyme.Reverse, function_annotation=A) +end + +annotate_function(::AutoEnzyme{<:Any, Nothing}, f::F) where {F} = f +annotate_function(::AutoEnzyme{<:Any, A}, f::F) where {F, A} = A(f) + include("training.jl") +include("batched_autodiff.jl") + end diff --git a/ext/LuxEnzymeExt/batched_autodiff.jl b/ext/LuxEnzymeExt/batched_autodiff.jl new file mode 100644 index 0000000000..4c3e2ada94 --- /dev/null +++ b/ext/LuxEnzymeExt/batched_autodiff.jl @@ -0,0 +1,50 @@ +function Lux.AutoDiffInternalImpl.batched_jacobian_impl( + f::F, ad::AutoEnzyme, x::AbstractArray) where {F} + backend = normalize_backend(True(), ad) + return batched_enzyme_jacobian_impl( + annotate_function(ad, f), backend, ADTypes.mode(backend), x) +end + +function batched_enzyme_jacobian_impl( + f::F, ad::AutoEnzyme, ::ForwardMode, x::AbstractArray) where {F} + # We need to run the function once to get the output type. Can we use ForwardWithPrimal? + y = f(x) + + @argcheck y isa AbstractArray MethodError + if ndims(y) ≤ 1 || size(y, ndims(y)) != size(x, ndims(x)) + throw(AssertionError("`batched_jacobian` only supports batched outputs \ + (ndims(y) > 1) && size(y, ndims(y)) == size(x, ndims(x)).")) + end + B = size(y, ndims(y)) + + J = similar(x, promote_type(eltype(y), eltype(x)), prod(size(y)[1:(end - 1)]), + prod(size(x)[1:(end - 1)]), B) + + chunk_size = min(8, length(y) ÷ B) + partials = ntuple(_ -> zero(x), chunk_size) + + for i in 1:chunk_size:(length(x) ÷ B) + idxs = i:min(i + chunk_size - 1, length(x) ÷ B) + partials′ = make_onehot!(partials, idxs) + J_partials = only(Enzyme.autodiff(ad.mode, f, BatchDuplicated(x, partials′))) + for (idx, J_partial) in zip(idxs, J_partials) + copyto!(view(J, :, idx, :), reshape(J_partial, :, B)) + end + end + + return J +end + +function batched_enzyme_jacobian_impl( + f::F, ad::AutoEnzyme, ::ReverseMode, x::AbstractArray) where {F} + error("reverse mode is not supported yet") +end + +function make_onehot!(partials, idxs) + for (idx, partial) in zip(idxs, partials) + partial′ = reshape(partial, :, size(partial, ndims(partial))) + fill!(partial′, false) + partial′[idx, :] .= true + end + return partials[1:length(idxs)] +end diff --git a/src/Lux.jl b/src/Lux.jl index 972b8aa419..16434a0a74 100644 --- a/src/Lux.jl +++ b/src/Lux.jl @@ -34,6 +34,7 @@ const NAME_TYPE = Union{Nothing, String, Symbol} const Optional{T} = Union{T, Nothing} is_extension_loaded(::Val) = false +is_extension_loaded(::Val{:ForwardDiff}) = true # Preferences include("preferences.jl") diff --git a/src/autodiff/api.jl b/src/autodiff/api.jl index 6db1e38da6..1077b6710a 100644 --- a/src/autodiff/api.jl +++ b/src/autodiff/api.jl @@ -33,10 +33,7 @@ function vector_jacobian_product(::F, backend::AbstractADType, _, __) where {F} end function vector_jacobian_product(f::F, backend::AutoZygote, x, u) where {F} - if !is_extension_loaded(Val(:Zygote)) - error("`Zygote.jl` must be loaded for `vector_jacobian_product` \ - to work with `$(backend)`.") - end + assert_backend_loaded(:vector_jacobian_product, backend) return AutoDiffInternalImpl.vector_jacobian_product(f, backend, x, u) end @@ -89,10 +86,11 @@ the following properties for `y = f(x)`: ## Backends & AD Packages -| Supported Backends | Packages Needed | -|:------------------ |:--------------- | -| `AutoForwardDiff` | | -| `AutoZygote` | `Zygote.jl` | +| Supported Backends | Packages Needed | Note | +|:------------------ |:--------------- |:---------------------------------------------- | +| `AutoForwardDiff` | | | +| `AutoZygote` | `Zygote.jl` | | +| `AutoEnzyme` | `Enzyme.jl` | Not compatible with ChainRules based Nested AD | ## Arguments @@ -118,14 +116,24 @@ function batched_jacobian(::F, backend::AbstractADType, x::AbstractArray) where throw(ArgumentError("`batched_jacobian` is not implemented for `$(backend)`.")) end -function batched_jacobian(f::F, backend::AutoForwardDiff, x::AbstractArray) where {F} - return AutoDiffInternalImpl.batched_jacobian(f, backend, x) +for implemented_backend in (AutoForwardDiff, AutoZygote, AutoEnzyme) + @eval function batched_jacobian( + f::F, backend::$(implemented_backend), x::AbstractArray) where {F} + assert_backend_loaded(:batched_jacobian, backend) + return AutoDiffInternalImpl.batched_jacobian(f, backend, x) + end end -function batched_jacobian(f::F, backend::AutoZygote, x::AbstractArray) where {F} - if !is_extension_loaded(Val(:Zygote)) - error("`Zygote.jl` must be loaded for `batched_jacobian` to work with \ - `$(backend)`.") +function assert_backend_loaded(fname::Symbol, ad::AbstractADType) + return assert_backend_loaded(fname, ad, adtype_to_backend(ad)) +end +function assert_backend_loaded(fname::Symbol, ad::AbstractADType, backend::Val{B}) where {B} + if !is_extension_loaded(backend) + error("$(fname) with `$(ad)` requires $(B).jl to be loaded.") end - return AutoDiffInternalImpl.batched_jacobian(f, backend, x) + return end + +adtype_to_backend(::AutoEnzyme) = Val(:Enzyme) +adtype_to_backend(::AutoForwardDiff) = Val(:ForwardDiff) +adtype_to_backend(::AutoZygote) = Val(:Zygote)