Skip to content

Commit

Permalink
tweak
Browse files Browse the repository at this point in the history
  • Loading branch information
mcabbott committed Nov 23, 2024
1 parent 50c0c92 commit 48dc717
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
6 changes: 5 additions & 1 deletion ext/FluxEnzymeExt/FluxEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import Optimisers
import Functors
import Enzyme
using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
using Enzyme: autodiff_thunk, ReverseSplitWithPrimal
using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
using ProgressLogging: @withprogress, @logprogress

EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
Expand Down Expand Up @@ -42,12 +42,16 @@ function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::B
_check_mutable(x)
end

# Take I, doesn't allow for aux at all.
# _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)

# Take II, using split mode.
forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
tape, result, shadow_result = forward(Const(f), args...)
reverse(Const(f), args..., _sensitivity(result), tape)

# Take III, it may be more efficient to have the function write the loss into Ref(0.0), seed Duplicated(that, Ref(1.0))?

(; val = result, grad = map(_grad_or_nothing, args))
end

Expand Down
2 changes: 2 additions & 0 deletions test/ext_enzyme/enzyme.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,4 +195,6 @@ end
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
@test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
# Duplicated
@test_throws Exception Flux.gradient((m,z) -> sum(m.bias)/z, m1, Duplicated(3f0, 0f0))
end

0 comments on commit 48dc717

Please sign in to comment.