diff --git a/test/Project.toml b/test/Project.toml index e16f39a6c..d3f8a43a7 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -3,6 +3,7 @@ AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" diff --git a/test/runtests.jl b/test/runtests.jl index caf43cb91..8ce4443d7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -17,6 +17,7 @@ using Zygote: Zygote using ForwardDiff: ForwardDiff using ReverseDiff: ReverseDiff using FiniteDifferences: FiniteDifferences +using Enzyme: Enzyme using Compat: only using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils diff --git a/test/test_utils.jl b/test/test_utils.jl index 8367fbd6d..744d3b6c2 100644 --- a/test/test_utils.jl +++ b/test/test_utils.jl @@ -47,8 +47,9 @@ gradient(f, s::Symbol, args) = gradient(f, Val(s), args) function gradient(f, ::Val{:Zygote}, args) g = only(Zygote.gradient(f, args)) if isnothing(g) + # To respect the same output as other ADs if args isa AbstractArray{<:Real} - return zeros(size(args)) # To respect the same output as other ADs + return zeros(size(args)) else return zeros.(size.(args)) end @@ -57,6 +58,19 @@ function gradient(f, ::Val{:Zygote}, args) end end +function gradient(f, ::Val{:EnzymeForward}, args) + return Enzyme.gradient(Enzyme.Forward, f, args) +end + +function gradient(f, ::Val{:EnzymeReverse}, args) + # shape = size(args) + # f_prime(flatargs) = f(reshape(flatargs, shape...)) + # return Enzyme.gradient(Enzyme.Reverse, f_prime, reshape(args, prod(shape))) + d_args = zero(args) + Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(args, d_args)) + return d_args +end + function gradient(f, ::Val{:ForwardDiff}, args) return ForwardDiff.gradient(f, args) end @@ -90,7 +104,10 @@ testdiagfunction(k::MOKernel, A) = sum(kernelmatrix_diag(k, A)) testdiagfunction(k::MOKernel, A, B) = sum(kernelmatrix_diag(k, A, B)) function test_ADs( - kernelfunction, args=nothing; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=[3, 3] + kernelfunction, + args=nothing; + ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse], + dims=[3, 3], ) test_fd = test_AD(:FiniteDiff, kernelfunction, args, dims) if !test_fd.anynonpass @@ -108,7 +125,9 @@ function check_zygote_type_stability(f, args...; ctx=Zygote.Context()) end function test_ADs( - k::MOKernel; ADs=[:Zygote, :ForwardDiff, :ReverseDiff], dims=(in=3, out=2, obs=3) + k::MOKernel; + ADs=[:Zygote, :ForwardDiff, :ReverseDiff, :EnzymeReverse], + dims=(in=3, out=2, obs=3), ) test_fd = test_FiniteDiff(k, dims) if !test_fd.anynonpass