Skip to content

Commit

Permalink
dct, dct! and r2r!
Browse files Browse the repository at this point in the history
  • Loading branch information
dlfivefifty committed Dec 19, 2024
1 parent c182bc6 commit d7b17e1
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 18 deletions.
35 changes: 30 additions & 5 deletions ext/FFTWForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,40 @@
module FFTWForwardDiffExt
using FFTW
using ForwardDiff
import FFTW: plan_r2r, r2r
import FFTW: plan_r2r, plan_r2r!, plan_dct, plan_dct!, plan_idct, plan_idct!, r2r, r2r!, dct, dct!, idct, idct!, fftwReal, REDFT10, REDFT01
import FFTW.AbstractFFTs: dualplan, dual2array
import ForwardDiff: Dual

plan_r2r(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims))
plan_r2r(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, plan_r2r(dual2array(x), FLAG, 1 .+ dims))

for plan in (:plan_r2r, :plan_r2r!)
@eval begin
$plan(x::AbstractArray{D}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, $plan(dual2array(x), FLAG, 1 .+ dims))
$plan(x::AbstractArray{<:Complex{D}}, FLAG, dims=1:ndims(x)) where D<:Dual = dualplan(D, $plan(dual2array(x), FLAG, 1 .+ dims))
end
end

r2r(x::AbstractArray{<:Dual}, kinds, region...) = plan_r2r(x, kinds, region...) * x
r2r(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = plan_r2r(x, kinds, region...) * x
for f in (:r2r, :r2r!)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray{<:Dual}, kinds, region...) = $pf(x, kinds, region...) * x
$f(x::AbstractArray{<:Complex{<:Dual}}, kinds, region...) = $pf(x, kinds, region...) * x
end
end


for f in (:dct, :dct!, :idct, :idct!)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray{<:Dual}) = $pf(x) * x
$f(x::AbstractArray{<:Dual}, region) = $pf(x, region) * x
end
end

for plan in (:plan_dct, :plan_dct!, :plan_idct, :plan_idct!)
@eval begin
$plan(x::AbstractArray{D}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, $plan(dual2array(x), 1 .+ dims; kwds...))
$plan(x::AbstractArray{<:Complex{D}}, dims=1:ndims(x); kwds...) where D<:Dual = dualplan(D, $plan(dual2array(x), 1 .+ dims; kwds...))
end
end

end #module
43 changes: 30 additions & 13 deletions test/fftwforwarddiff.jl
Original file line number Diff line number Diff line change
@@ -1,21 +1,38 @@
using FFTW, ForwardDiff, Test
using ForwardDiff: Dual, value, partials

@testset "r2r" begin
x1 = Dual.(1:4.0, 2:5, 3:6)
t = FFTW.r2r(x1, FFTW.R2HC)
@testset "ForwardDiff extension" begin
@testset "r2r" begin
x1 = Dual.(1:4.0, 2:5, 3:6)
t = FFTW.r2r(x1, FFTW.R2HC)

@test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC)
@test value.(t) == FFTW.r2r(value.(x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1, 2), FFTW.R2HC)

t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC)
@test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC)
t = FFTW.r2r(x1 + 2im*x1, FFTW.R2HC)
@test value.(t) == FFTW.r2r(value.(x1 + 2im*x1), FFTW.R2HC)
@test partials.(t, 1) == FFTW.r2r(partials.(x1 + 2im*x1, 1), FFTW.R2HC)
@test partials.(t, 2) == FFTW.r2r(partials.(x1 + 2im*x1, 2), FFTW.R2HC)

f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1]
@test ForwardDiff.derivative(f, 0.1) 1.0
f = ω -> FFTW.r2r([ω; zeros(9)], FFTW.R2HC)[1]
@test ForwardDiff.derivative(f, 0.1) 1.0

@test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC)
@test mul!(similar(x1), FFTW.plan_r2r(x1, FFTW.R2HC), x1) == FFTW.r2r(x1, FFTW.R2HC)

x = [Dual(1.0,2,3), Dual(4,5,6)]
a = FFTW.r2r(x, FFTW.REDFT00)
b = FFTW.r2r!(x, FFTW.REDFT00)
@test a == b == x
end

@testset "dct" begin
x = [Dual(1.0,2,3), Dual(4,5,6)]
a = dct(x)
b = dct!(x)
@test a == b == x

c = x -> dct([x; 0; 0])[1]
@test ForwardDiff.derivative(c,0.1) 1/sqrt(3)
end
end

0 comments on commit d7b17e1

Please sign in to comment.