From 48dc71730afacd92259f82d64ca72ca64cf311bd Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 22 Nov 2024 20:02:42 -0500
Subject: [PATCH] tweak

---
 ext/FluxEnzymeExt/FluxEnzymeExt.jl | 6 +++++-
 test/ext_enzyme/enzyme.jl          | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/ext/FluxEnzymeExt/FluxEnzymeExt.jl b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
index e7ce025cd3..cf1bb9da3d 100644
--- a/ext/FluxEnzymeExt/FluxEnzymeExt.jl
+++ b/ext/FluxEnzymeExt/FluxEnzymeExt.jl
@@ -10,7 +10,7 @@ import Optimisers
 import Functors
 import Enzyme
 using Enzyme: EnzymeRules, Active, Const, Duplicated, autodiff, ReverseWithPrimal
-using Enzyme: autodiff_thunk, ReverseSplitWithPrimal
+using Enzyme: autodiff_thunk, Reverse, ReverseSplitWithPrimal
 using ProgressLogging: @withprogress, @logprogress
 
 EnzymeRules.inactive(::typeof(Flux.Losses._check_sizes), args...) = true
@@ -42,12 +42,16 @@ function Flux._enzyme_withgradient(f, args::Union{Const, Duplicated}...; zero::B
     _check_mutable(x)
   end
 
+  # Take I, doesn't allow for aux at all.
   # _, val = Enzyme.autodiff(ReverseWithPrimal, f, Active, args...)
 
+  # Take II, using split mode.
   forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Active, map(typeof, args)...)
   tape, result, shadow_result = forward(Const(f), args...)
   reverse(Const(f), args..., _sensitivity(result), tape)
 
+  # Take III, it may be more efficient to have the function write the loss into Ref(0.0), seed Duplicated(that, Ref(1.0))?
+
   (; val = result, grad = map(_grad_or_nothing, args))
 end
 
diff --git a/test/ext_enzyme/enzyme.jl b/test/ext_enzyme/enzyme.jl
index 7c3d88e1f5..6d95fb299d 100644
--- a/test/ext_enzyme/enzyme.jl
+++ b/test/ext_enzyme/enzyme.jl
@@ -195,4 +195,6 @@ end
     @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1, Active(3f0))
     @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, m1.val, Active(3f0))
     @test_throws ArgumentError Flux.gradient((m,z) -> sum(m.bias)/z, Const(m1.val), Active(3f0))
+    # Duplicated
+    @test_throws Exception Flux.gradient((m,z) -> sum(m.bias)/z, m1, Duplicated(3f0, 0f0))
 end
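
For reference, a minimal standalone sketch of the split-mode pattern that "Take II" relies on, written against plain Enzyme with no Flux involved. The toy loss and the argument x are invented for illustration, and the 1.0 seed passed to the reverse thunk stands in for Flux's internal _sensitivity(result) helper:

using Enzyme
using Enzyme: autodiff_thunk, ReverseSplitWithPrimal, Const, Duplicated, Active

loss(x) = sum(abs2, x)                       # toy scalar loss, illustration only
x = Duplicated([1.0, 2.0, 3.0], zeros(3))    # value plus shadow that will receive the gradient

# Split reverse mode: compile forward and reverse thunks for this call signature.
forward, reverse = autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(loss)}, Active, typeof(x))

# Forward pass records a tape and (with ...WithPrimal) returns the primal value.
tape, result, shadow_result = forward(Const(loss), x)

# Reverse pass, seeded with 1.0 for the scalar loss; the gradient accumulates into x.dval.
reverse(Const(loss), x, one(result), tape)

result   # 14.0
x.dval   # [2.0, 4.0, 6.0]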