diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 62fdcae..b62607d 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -22,7 +22,7 @@ jobs: fail-fast: false matrix: version: - - '1.2' + - '1.9' - '1' # - 'nightly' os: diff --git a/Project.toml b/Project.toml index 2c44fbd..1b6a6eb 100644 --- a/Project.toml +++ b/Project.toml @@ -6,6 +6,12 @@ version = "2.9.4" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +[weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" + +[extensions] +QuadGKEnzymeExt = "Enzyme" + [compat] DataStructures = "0.11, 0.12, 0.13, 0.14, 0.15, 0.16, 0.17, 0.18, 0.19" julia = "1.2" diff --git a/ext/QuadGKEnzymeExt.jl b/ext/QuadGKEnzymeExt.jl new file mode 100644 index 0000000..d10d21a --- /dev/null +++ b/ext/QuadGKEnzymeExt.jl @@ -0,0 +1,132 @@ + +module QuadGKEnzymeExt + +using QuadGK, Enzyme, LinearAlgebra + +function Enzyme.EnzymeRules.augmented_primal(config, ofunc::Const{typeof(quadgk)}, ::Type{RT}, f, segs::Annotation{T}...; kws...) where {RT, T} + prims = map(x->x.val, segs) + + retres, segbuf = if f isa Const + if EnzymeRules.needs_primal(config) + quadgk(f.val, prims...; kws...), nothing + else + nothing + end + else + I, E, segbuf = quadgk_segbuf(f.val, prims...; kws...) + if EnzymeRules.needs_primal(config) + (I, E), segbuf + else + nothing, segbuf + end + end + + dres = if !Enzyme.EnzymeRules.needs_shadow(config) + nothing + elseif EnzymeRules.width(config) == 1 + zero.(res...) + else + ntuple(Val(EnzymeRules.width(config))) do i + Base.@_inline_meta + zero.(res...) + end + end + + cache = if RT <: Duplicated || RT <: DuplicatedNoNeed || RT <: BatchDuplicated || RT <: BatchDuplicatedNoNeed + dres + else + nothing + end + cache2 = segbuf, cache + + return Enzyme.EnzymeRules.AugmentedReturn{ + Enzyme.EnzymeRules.needs_primal(config) ? eltype(RT) : Nothing, + Enzyme.EnzymeRules.needs_shadow(config) ? (Enzyme.EnzymeRules.width(config) == 1 ? eltype(RT) : NTuple{Enzyme.EnzymeRules.width(config), eltype(RT)}) : Nothing, + typeof(cache2) + }(retres, dres, cache2) +end + +function call(f, x) + f(x) +end + +# Wrapper around a function f that allows it to act as a vector space, and hence be usable as +# an integrand, where the vector operations act on the closed-over parameters of f that are +# begin differentiated with respect to. In particular, if we have a closure f = x -> g(x, p), and we want +# to differentiate with respect to p, then our reverse (vJp) rule needs an integrand given by the +# Jacobian-vector product (pullback) vᵀ∂g/∂p. But Enzyme wraps this in a closure so that it is the +# same "shape" as f, whereas to integrate it we need to be able to treat it as a vector space. +# ClosureVector calls Enzyme.Compiler.recursive_add, which is an internal function that "unwraps" +# the closure to access the internal state, which can then be added/subtracted/scaled. +struct ClosureVector{F} + f::F +end + +@inline function guaranteed_nonactive(::Type{T}) where T + rt = Enzyme.Compiler.active_reg_inner(T, (), nothing) + return rt == Enzyme.Compiler.AnyState || rt == Enzyme.Compiler.DupState +end + +function Base.:+(a::CV, b::CV) where {CV <: ClosureVector} + Enzyme.Compiler.recursive_add(a, b, identity, guaranteed_nonactive)::CV +end + +function Base.:-(a::CV, b::CV) where {CV <: ClosureVector} + Enzyme.Compiler.recursive_add(a, b, x->-x, guaranteed_nonactive)::CV +end + +function Base.:*(a::Number, b::CV) where {CV <: ClosureVector} + # b + (a-1) * b = a * b + Enzyme.Compiler.recursive_add(b, b, x->(a-1)*x, guaranteed_nonactive)::CV +end + +function Base.:*(a::ClosureVector, b::Number) + return b*a +end + +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Active, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} + df = if f isa Const + nothing + else + segbuf = cache[1] + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x + tape, prim, shad = fwd(Const(call), f, Const(x)) + drev = rev(Const(call), f, Const(x), dres.val[1], tape) + return ClosureVector(drev[1][1]) + end + _df.f + end + dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1]) + dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1]) + return (df, # f + dsegs1, + ntuple(i -> nothing, Val(length(segs)-2))..., + dsegsn) +end + +function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::Type{<:Union{Duplicated, BatchDuplicated}}, cache, f::Union{Const, Active}, segs::Annotation{T}...; kws...) where {T} + dres = cache[2] + df = if f isa Const + nothing + else + segbuf = cache[1] + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitNoPrimal, Const{typeof(call)}, Active, typeof(f), Const{T}) + _df, _ = quadgk(map(x->x.val, segs)...; kws..., eval_segbuf=segbuf, maxevals=0, norm=f->0) do x + tape, prim, shad = fwd(Const(call), f, Const(x)) + shad .= dres + drev = rev(Const(call), f, Const(x), tape) + return ClosureVector(drev[1][1]) + end + _df.f + end + dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres) + dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres) + Enzyme.make_zero!(dres) + return (df, # f + dsegs1, + ntuple(i -> nothing, Val(length(segs)-2))..., + dsegsn) +end + +end # module diff --git a/src/api.jl b/src/api.jl index 6087bbc..b375260 100644 --- a/src/api.jl +++ b/src/api.jl @@ -132,6 +132,15 @@ function quadgk!(f!, result, a::T,b::T,c::T...; atol=nothing, rtol=nothing, maxe return quadgk(f, a, b, c...; atol=atol, rtol=rtol, maxevals=maxevals, order=order, norm=norm, segbuf=segbuf, eval_segbuf=eval_segbuf) end +struct Counter{F} + f::F + count::Base.RefValue{Int} +end +function (c::Counter{F})(args...) where F + c.count[] += 1 + c.f(args...) +end + """ quadgk_count(f, args...; kws...) @@ -146,12 +155,9 @@ it may be possible to mathematically transform the problem in some way to improve the convergence rate. """ function quadgk_count(f, args...; kws...) - count = 0 - i = quadgk(args...; kws...) do x - count += 1 - f(x) - end - return (i..., count) + counter = Counter(f, Ref(0)) + i = quadgk(counter, args...; kws...) + return (i..., counter.count[]) end """ diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 0000000..481c28e --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,4 @@ +[deps] +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" diff --git a/test/runtests.jl b/test/runtests.jl index 4e156f2..dae5b9c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -426,3 +426,34 @@ quadgk_segbuf_printnull(args...; kws...) = quadgk_segbuf_print(devnull, args...; @inferred QuadGK.to_segbuf([0,1]) @inferred QuadGK.to_segbuf([(0,1+3im)]) end + +# Extension package only supported in 1.9+ +@static if VERSION >= v"1.9" + using Enzyme + f1(x) = quadgk(cos, 0., x)[1] + f2(x) = quadgk(cos, x, 1)[1] + f3(x) = quadgk(y->cos(x * y), 0., 1.)[1] + + f1_count(x) = quadgk_count(cos, 0., x)[1] + f2_count(x) = quadgk_count(cos, x, 1)[1] + f3_count(x) = quadgk_count(y->cos(x * y), 0., 1.)[1] + + f_vec(x) = sum(quadgk(y->[cos(x[1] * y), cos(x[2] * y)], 0., 1.)[1]) + + @testset "Enzyme" begin + @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1, Active(0.3))[1][1] + @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2, Active(0.3))[1][1] + @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3, Active(0.3))[1][1] + + @test cos(0.3) ≈ Enzyme.autodiff(Reverse, f1_count, Active(0.3))[1][1] + @test -cos(0.3) ≈ Enzyme.autodiff(Reverse, f2_count, Active(0.3))[1][1] + @test (0.3 * cos(0.3) - sin(0.3))/(0.3*0.3) ≈ Enzyme.autodiff(Reverse, f3_count, Active(0.3))[1][1] + + x = [0.3, 0.7] + dx = [0.0, 0.0] + f_vec(x) + # TODO custom rule with mixed vector returns not yet supported x/ref https://github.com/EnzymeAD/Enzyme.jl/issues/1692 + @test_throws Enzyme.Compiler.EnzymeRuntimeException autodiff(Reverse, f_vec, Duplicated(x, dx)) + # @test dx ≈ [(0.3 * cos(0.3) - sin(0.3))/(0.3*0.3), (0.7 * cos(0.7) - sin(0.7))/(0.7*0.7)] + end +end