diff --git a/src/FFTW.jl b/src/FFTW.jl index 4366ee7..0493cbc 100644 --- a/src/FFTW.jl +++ b/src/FFTW.jl @@ -12,7 +12,7 @@ import AbstractFFTs: Plan, ScaledPlan, rfft_output_size, brfft_output_size, plan_inv, normalization -export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct! +export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!, brfft!, irfft!, plan_brfft!, plan_irfft! include("providers.jl") diff --git a/src/fft.jl b/src/fft.jl index 4063ea7..80f03e2 100644 --- a/src/fft.jl +++ b/src/fft.jl @@ -1,8 +1,10 @@ # This file was formerly a part of Julia. License is MIT: https://julialang.org/license -import Base: show, *, convert, unsafe_convert, size, strides, ndims, pointer +import Base: show, *, convert, unsafe_convert, size, strides, ndims, pointer, copy, getindex import LinearAlgebra: mul! + + """ r2r(A, kind [, dims]) @@ -108,6 +110,32 @@ const fftwSingle = Union{Float32,Complex{Float32}} const fftwTypeDouble = Union{Type{Float64},Type{Complex{Float64}}} const fftwTypeSingle = Union{Type{Float32},Type{Complex{Float32}}} +# padded array type to support RFFT in-place operations +struct PaddedRFFTArray{T, N, Ac, Ar} <: AbstractArray{T, N} where {Ac <: AbstractArray{Complex{T}, N}, Ar <: AbstractArray{T, N}} + c::Ac # complex view / underlying array + r::Ar # real view + + # wrap existing complex array + function PaddedRFFTArray(a::StridedArray{Complex{T}, N}) where {T, N} + fsize = 2 * (size(a, 1) - 1) + r = view(reinterpret(T, a), Base.OneTo(fsize), ntuple(i -> Colon(), Val(ndims(a) - 1))...) + new{T, N, typeof(a), typeof(r)}(a, r) + end + + # copy existing padded array + function PaddedRFFTArray(a::PaddedRFFTArray{T, N, Ac, Ar}) where {T, N, Ac, Ar} + c = copy(complex_view(a)) + r = view(reinterpret(T, c), Base.OneTo(size(real_view(a), 1)), ntuple(i -> Colon(), Val(ndims(a) - 1))...) + new{T, N, Ac, Ar}(c, r) + end +end + +@inline real_view(S::PaddedRFFTArray) = S.r +@inline complex_view(S::PaddedRFFTArray) = S.c +size(a::PaddedRFFTArray) = size(complex_view(a)) +getindex(a::PaddedRFFTArray, args...) = getindex(complex_view(a), args...) +copy(a::PaddedRFFTArray) = PaddedRFFTArray(a) + # For ESTIMATE plans, FFTW allows one to pass NULL for the array pointer, # since it is not written to. Hence, it is convenient to create an # array-like type that carries a size and a stride like a "real" array @@ -728,7 +756,18 @@ function *(p::cFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K} return x end -# rfft/brfft and planned variants. No in-place version for now. +# rfft/brfft and planned variants. No in-place version of rfft for now. + +for f in (:brfft!, :irfft!) + pf = Symbol("plan_", f) + @eval begin + $f(x::AbstractArray, d::Integer) = $f(PaddedRFFTArray(x), d) + $f(x::AbstractArray, d::Integer, region) = $f(PaddedRFFTArray(x), d, region) + $f(x::PaddedRFFTArray, d::Integer) = $pf(x, d) * x + $f(x::PaddedRFFTArray, d::Integer, region) = $pf(x, d, region) * x + $pf(x::PaddedRFFTArray, d::Integer; kws...) = $pf(x, d, 1:ndims(x); kws...) + end +end for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) # Note: use $FORWARD and $BACKWARD below because of issue #9775 @@ -763,6 +802,21 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...) plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...) + function plan_brfft!(X::PaddedRFFTArray{$Tr,N}, d::Integer, region; + flags::Integer=ESTIMATE, + timelimit::Real=NO_TIMELIMIT) where N + @assert size(complex_view(X))[first(region)] == d>>1 + 1 + return rFFTWPlan{$Tc,$BACKWARD,true,N}(complex_view(X), real_view(X), region, flags, timelimit) + end + + plan_brfft!(X::PaddedRFFTArray{$Tr};kws...)=plan_brfft!(X,1:ndims(X);kws...) + + function plan_irfft!(x::PaddedRFFTArray{$Tr,N}, d::Integer, region; kws...) where N + ScaledPlan(plan_brfft!(x, d, region; kws...), normalization($Tr, size(real_view(x)), region)) + end + + plan_irfft!(X::PaddedRFFTArray{$Tr};kws...)=plan_irfft!(X,1:ndims(X);kws...) + function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}) where N X = Array{$Tr}(undef, p.sz) Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tc}(p.osz) : Array{$Tc}(undef, p.osz) @@ -812,9 +866,18 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64}))) end return y end + + *(p::rFFTWPlan{$Tc,$BACKWARD,true,N}, f::PaddedRFFTArray{$Tr,N}) where N = + (mul!(real_view(f), p, complex_view(f)); real_view(f)) end end +*(p::ScaledPlan, f::PaddedRFFTArray) = begin + p.p * f + rmul!(real_view(f), p.scale) + real_view(f) +end + # FFTW r2r transforms (low-level interface) for f in (:r2r, :r2r!) @@ -890,3 +953,4 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K} unsafe_execute!(p, x, x) return x end + diff --git a/test/runtests.jl b/test/runtests.jl index 2291003..8c5b95f 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -528,3 +528,33 @@ end @test occursin("dft-thr", string(p2)) end end + +begin + @testset "PaddedRFFTArray creation" begin + a = rand(Complex{Float64}, (8, 4, 4)) + b = FFTW.PaddedRFFTArray(a) + c = copy(b) + @test a == FFTW.complex_view(b) + @test c == b + @test FFTW.real_view(c) == FFTW.real_view(b) + @test FFTW.complex_view(c) == FFTW.complex_view(b) + end + + @testset "irfft!" begin + a = rand(Float64, (8, 4, 4)) + c = rfft(a) + d = copy(c) + e = irfft!(d, size(a, 1)) + @test a ≈ e + @test irfft(c, size(a, 1)) ≈ e + @test d === parent(parent(e)) + @test d != c + @test irfft(c, size(a, 1), 1:2) ≈ irfft!(copy(c), size(a, 1), 1:2) + end + + @testset "brfft!" begin + a = rand(Float64,(4,4)) + b = rfft(a) + @test (brfft!(b, size(a, 1)) ./ 16) ≈ a + end +end