Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added rfft! and irfft! functionality through PaddedRFFTArray type. #54

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
35 changes: 12 additions & 23 deletions src/rfft!.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,27 @@ export PaddedRFFTArray, plan_rfft!, rfft!, plan_irfft!, plan_brfft!, brfft!, irf
# custom getindex and setindex! below. Hopefully, once the performance issues with ReinterpretArray
# are solved we can just index the reinterpret array directly.

struct PaddedRFFTArray{T<:fftwReal,N,L} <: DenseArray{Complex{T},N}
struct PaddedRFFTArray{T<:fftwReal,N,Nm1,L} <: DenseArray{Complex{T},N}
data::Array{T,N}
r::SubArray{T,N,Array{T,N},NTuple{N,UnitRange{Int}},L} # Real view skipping padding
r::SubArray{T,N,Array{T,N},Tuple{UnitRange{Int},Vararg{Base.Slice{Base.OneTo{Int}},Nm1}},L} # Real view skipping padding
c::Base.ReinterpretArray{Complex{T},N,T,Array{T,N}}

function PaddedRFFTArray{T,N}(rr::Array{T,N},nx::Int) where {T<:fftwReal,N}
rrsize = size(rr)
fsize = rrsize[1]
function PaddedRFFTArray{T,N,Nm1,L}(rr::Array{T,N},nx::Int) where {T<:fftwReal,N,Nm1,L}
fsize = size(rr)[1]
iseven(fsize) || throw(
ArgumentError("First dimension of allocated array must have even number of elements"))
(nx == fsize-2 || nx == fsize-1) || throw(
ArgumentError("Number of elements on the first dimension of array must be either 1 or 2 less than the number of elements on the first dimension of the allocated array"))
fsize = fsize÷2
csize = (fsize, rrsize[2:end]...)
c = reinterpret(Complex{T}, rr)
rsize = (nx,rrsize[2:end]...)
r = view(rr,(1:l for l in rsize)...)
return new{T, N, N === 1 ? true : false}(rr,r,c)
r = view(rr, 1:nx, ntuple(i->Colon(),Nm1)...)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can pass Val{Nm1}() here to get a compiler-unrolled ntuple on 0.7. On 0.6 you need Val{Nm1}.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because Val(n) is a pure function now I can just put Val(Nm1) instead of Val{Nm1}(), right?

return new{T, N, Nm1, L}(rr,r,c)
end # function
end # struct

@generated function PaddedRFFTArray{T,N}(rr::Array{T,N},nx::Int) where {T<:fftwReal,N}
:(PaddedRFFTArray{T,N,$(N-1),$(N === 1 ? true : false)}(rr,nx))
end

@inline real(S::PaddedRFFTArray) = S.r

@inline complex_view(S::PaddedRFFTArray) = S.c
Expand Down Expand Up @@ -175,13 +175,7 @@ function plan_rfft!(X::PaddedRFFTArray{T,N}, region;
timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N}

(1 in region) || throw(ArgumentError("The first dimension must always be transformed"))
if flags&ESTIMATE != 0
p = rFFTWPlan{T,FORWARD,true,N}(real(X), complex_view(X), region, flags, timelimit)
else
x = similar(X)
p = rFFTWPlan{T,FORWARD,true,N}(real(x), complex_view(x), region, flags, timelimit)
end
return p
return rFFTWPlan{T,FORWARD,true,N}(real(X), complex_view(X), region, flags, timelimit)
end

plan_rfft!(f::PaddedRFFTArray;kws...) = plan_rfft!(f, 1:ndims(f); kws...)
Expand All @@ -204,12 +198,7 @@ function plan_brfft!(X::PaddedRFFTArray{T,N}, region;
flags::Integer=ESTIMATE,
timelimit::Real=NO_TIMELIMIT) where {T<:fftwReal,N}
(1 in region) || throw(ArgumentError("The first dimension must always be transformed"))
if flags&ESTIMATE != 0
return rFFTWPlan{Complex{T},BACKWARD,true,N}(complex_view(X), real(X), region, flags,timelimit)
else
a = similar(X)
return rFFTWPlan{Complex{T},BACKWARD,true,N}(complex_view(a), real(a), region, flags,timelimit)
end
return rFFTWPlan{Complex{T},BACKWARD,true,N}(complex_view(X), real(X), region, flags,timelimit)
end

plan_brfft!(f::PaddedRFFTArray;kws...) = plan_brfft!(f,1:ndims(f);kws...)
Expand Down