Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

In-place irfft/brfft #219

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/FFTW.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import AbstractFFTs: Plan, ScaledPlan,
rfft_output_size, brfft_output_size,
plan_inv, normalization

export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!
export dct, idct, dct!, idct!, plan_dct, plan_idct, plan_dct!, plan_idct!, brfft!, irfft!, plan_brfft!, plan_irfft!

include("providers.jl")

Expand Down
68 changes: 66 additions & 2 deletions src/fft.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# This file was formerly a part of Julia. License is MIT: https://julialang.org/license

import Base: show, *, convert, unsafe_convert, size, strides, ndims, pointer
import Base: show, *, convert, unsafe_convert, size, strides, ndims, pointer, copy, getindex
import LinearAlgebra: mul!



"""
r2r(A, kind [, dims])

Expand Down Expand Up @@ -108,6 +110,32 @@ const fftwSingle = Union{Float32,Complex{Float32}}
const fftwTypeDouble = Union{Type{Float64},Type{Complex{Float64}}}
const fftwTypeSingle = Union{Type{Float32},Type{Complex{Float32}}}

# padded array type to support RFFT in-place operations
struct PaddedRFFTArray{T, N, Ac, Ar} <: AbstractArray{T, N} where {Ac <: AbstractArray{Complex{T}, N}, Ar <: AbstractArray{T, N}}
c::Ac # complex view / underlying array
r::Ar # real view

# wrap existing complex array
function PaddedRFFTArray(a::StridedArray{Complex{T}, N}) where {T, N}
fsize = 2 * (size(a, 1) - 1)
r = view(reinterpret(T, a), Base.OneTo(fsize), ntuple(i -> Colon(), Val(ndims(a) - 1))...)
new{T, N, typeof(a), typeof(r)}(a, r)
end

# copy existing padded array
function PaddedRFFTArray(a::PaddedRFFTArray{T, N, Ac, Ar}) where {T, N, Ac, Ar}
c = copy(complex_view(a))
r = view(reinterpret(T, c), Base.OneTo(size(real_view(a), 1)), ntuple(i -> Colon(), Val(ndims(a) - 1))...)
new{T, N, Ac, Ar}(c, r)
end
end

@inline real_view(S::PaddedRFFTArray) = S.r
@inline complex_view(S::PaddedRFFTArray) = S.c
size(a::PaddedRFFTArray) = size(complex_view(a))
getindex(a::PaddedRFFTArray, args...) = getindex(complex_view(a), args...)
copy(a::PaddedRFFTArray) = PaddedRFFTArray(a)

# For ESTIMATE plans, FFTW allows one to pass NULL for the array pointer,
# since it is not written to. Hence, it is convenient to create an
# array-like type that carries a size and a stride like a "real" array
Expand Down Expand Up @@ -728,7 +756,18 @@ function *(p::cFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K}
return x
end

# rfft/brfft and planned variants. No in-place version for now.
# rfft/brfft and planned variants. No in-place version of rfft for now.

for f in (:brfft!, :irfft!)
pf = Symbol("plan_", f)
@eval begin
$f(x::AbstractArray, d::Integer) = $f(PaddedRFFTArray(x), d)
$f(x::AbstractArray, d::Integer, region) = $f(PaddedRFFTArray(x), d, region)
$f(x::PaddedRFFTArray, d::Integer) = $pf(x, d) * x
$f(x::PaddedRFFTArray, d::Integer, region) = $pf(x, d, region) * x
$pf(x::PaddedRFFTArray, d::Integer; kws...) = $pf(x, d, 1:ndims(x); kws...)
end
end

for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
# Note: use $FORWARD and $BACKWARD below because of issue #9775
Expand Down Expand Up @@ -763,6 +802,21 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)

function plan_brfft!(X::PaddedRFFTArray{$Tr,N}, d::Integer, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where N
@assert size(complex_view(X))[first(region)] == d>>1 + 1
return rFFTWPlan{$Tc,$BACKWARD,true,N}(complex_view(X), real_view(X), region, flags, timelimit)
end

plan_brfft!(X::PaddedRFFTArray{$Tr};kws...)=plan_brfft!(X,1:ndims(X);kws...)

function plan_irfft!(x::PaddedRFFTArray{$Tr,N}, d::Integer, region; kws...) where N
ScaledPlan(plan_brfft!(x, d, region; kws...), normalization($Tr, size(real_view(x)), region))
end

plan_irfft!(X::PaddedRFFTArray{$Tr};kws...)=plan_irfft!(X,1:ndims(X);kws...)

function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}) where N
X = Array{$Tr}(undef, p.sz)
Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tc}(p.osz) : Array{$Tc}(undef, p.osz)
Expand Down Expand Up @@ -812,9 +866,18 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
end
return y
end

*(p::rFFTWPlan{$Tc,$BACKWARD,true,N}, f::PaddedRFFTArray{$Tr,N}) where N =
(mul!(real_view(f), p, complex_view(f)); real_view(f))
end
end

*(p::ScaledPlan, f::PaddedRFFTArray) = begin
p.p * f
rmul!(real_view(f), p.scale)
real_view(f)
end

# FFTW r2r transforms (low-level interface)

for f in (:r2r, :r2r!)
Expand Down Expand Up @@ -890,3 +953,4 @@ function *(p::r2rFFTWPlan{T,K,true}, x::StridedArray{T}) where {T,K}
unsafe_execute!(p, x, x)
return x
end

30 changes: 30 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -528,3 +528,33 @@ end
@test occursin("dft-thr", string(p2))
end
end

begin
@testset "PaddedRFFTArray creation" begin
a = rand(Complex{Float64}, (8, 4, 4))
b = FFTW.PaddedRFFTArray(a)
c = copy(b)
@test a == FFTW.complex_view(b)
@test c == b
@test FFTW.real_view(c) == FFTW.real_view(b)
@test FFTW.complex_view(c) == FFTW.complex_view(b)
end

@testset "irfft!" begin
a = rand(Float64, (8, 4, 4))
c = rfft(a)
d = copy(c)
e = irfft!(d, size(a, 1))
@test a ≈ e
@test irfft(c, size(a, 1)) ≈ e
@test d === parent(parent(e))
@test d != c
@test irfft(c, size(a, 1), 1:2) ≈ irfft!(copy(c), size(a, 1), 1:2)
end

@testset "brfft!" begin
a = rand(Float64,(4,4))
b = rfft(a)
@test (brfft!(b, size(a, 1)) ./ 16) ≈ a
end
end