Skip to content

Commit

Permalink
fixup
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Sep 17, 2024
1 parent b6f6178 commit 5bd4dbf
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions ext/QuadGKEnzymeExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,16 +89,19 @@ struct MixedClosureVector{F}
end

function Base.:+(a::CV, b::CV) where {CV <: MixedClosureVector}
Enzyme.Compiler.recursive_accumulate(a, b, identity)::CV
res = deepcopy(a)::CV
Enzyme.Compiler.recursive_accumulate(res, b, identity)::CV

Check warning on line 93 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L91-L93

Added lines #L91 - L93 were not covered by tests
end

function Base.:-(a::CV, b::CV) where {CV <: MixedClosureVector}
Enzyme.Compiler.recursive_accumulate(a, b, x->-x)::CV
res = deepcopy(a)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->-x)::CV

Check warning on line 98 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L96-L98

Added lines #L96 - L98 were not covered by tests
end

function Base.:*(a::Number, b::CV) where {CV <: MixedClosureVector}

Check warning on line 101 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L101

Added line #L101 was not covered by tests
# b + (a-1) * b = a * b
Enzyme.Compiler.recursive_accumulate(b, b, x->(a-1)*x)::CV
res = deepcopy(b)::CV
Enzyme.Compiler.recursive_accumulate(res, b, x->(a-1)*x)::CV

Check warning on line 104 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L103-L104

Added lines #L103 - L104 were not covered by tests
end

function Base.:*(a::MixedClosureVector, b::Number)
Expand Down Expand Up @@ -126,7 +129,8 @@ function Enzyme.EnzymeRules.reverse(config, ofunc::Const{typeof(quadgk)}, dres::
drev = rev(Const(call), f, Const(x), dres.val[1], tape)
return MixedClosureVector(fshadow)

Check warning on line 130 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L123-L130

Added lines #L123 - L130 were not covered by tests
end
_df.f[]
Enzyme.Compiler.recursive_accumulate(f.dval, _df.f)
nothing

Check warning on line 133 in ext/QuadGKEnzymeExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/QuadGKEnzymeExt.jl#L132-L133

Added lines #L132 - L133 were not covered by tests
end
dsegs1 = segs[1] isa Const ? nothing : -LinearAlgebra.dot(f.val(segs[1].val), dres.val[1])
dsegsn = segs[end] isa Const ? nothing : LinearAlgebra.dot(f.val(segs[end].val), dres.val[1])
Expand Down

0 comments on commit 5bd4dbf

Please sign in to comment.