Skip to content

Commit

Permalink
Make transforms work on CUDA arrays (#48)
Browse files Browse the repository at this point in the history
* Allow using GPU arrays as FFT buffers

* Explicitly list array A in `plan` signature

* Don't pass FFT kwargs for GPU arrays

In particular, `plan_rfft(::CuArray, args...; kws...)` fails when `kws`
is not empty.
Therefore, for GPU arrays, we avoid passing any keyword arguments to
planner functions.

* Enlarge Transforms.Plan union type

* Fix assertions

* Avoid scalar indexing in backwards transforms

* v0.14.0
  • Loading branch information
jipolanco authored Jun 21, 2022
1 parent cbe89dc commit cc8c381
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 24 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PencilFFTs"
uuid = "4a48f351-57a6-4416-9ec4-c37015456aae"
authors = ["Juan Ignacio Polanco <[email protected]>"]
version = "0.13.6"
version = "0.14.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -16,7 +16,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
AbstractFFTs = "1"
FFTW = "1"
MPI = "0.19"
PencilArrays = "0.16, 0.17"
PencilArrays = "0.17"
Reexport = "1"
TimerOutputs = "0.5"
julia = "1.6"
7 changes: 4 additions & 3 deletions src/Transforms/Transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and [`FFTW.jl`](https://juliamath.github.io/FFTW.jl/stable/fft/).
"""
module Transforms

using AbstractFFTs
using FFTW

# Operations defined for custom plans (currently IdentityPlan).
Expand Down Expand Up @@ -39,19 +40,19 @@ The only custom plan defined in this module is [`IdentityPlan`](@ref).
The user can define other custom plans that are also subtypes of
`AbstractCustomPlan`.
Note that [`plan`](@ref) returns a subtype of either `FFTW.FFTWPlan` or
Note that [`plan`](@ref) returns a subtype of either `AbstractFFTs.Plan` or
`AbstractCustomPlan`.
"""
abstract type AbstractCustomPlan end

"""
Plan = Union{FFTW.FFTWPlan, AbstractCustomPlan}
Plan = Union{AbstractFFTs.Plan, AbstractCustomPlan}
Union type representing any plan returned by [`plan`](@ref).
See also [`AbstractCustomPlan`](@ref).
"""
const Plan = Union{FFTW.FFTWPlan, AbstractCustomPlan}
const Plan = Union{AbstractFFTs.Plan, AbstractCustomPlan}

"""
plan(transform::AbstractTransform, A, [dims];
Expand Down
8 changes: 4 additions & 4 deletions src/Transforms/c2c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ length_output(::TransformC2C, length_in::Integer) = length_in
eltype_output(::TransformC2C, ::Type{Complex{T}}) where {T <: FFTReal} = Complex{T}
eltype_input(::TransformC2C, ::Type{T}) where {T <: FFTReal} = Complex{T}

plan(::FFT, args...; kwargs...) = FFTW.plan_fft(args...; kwargs...)
plan(::FFT!, args...; kwargs...) = FFTW.plan_fft!(args...; kwargs...)
plan(::BFFT, args...; kwargs...) = FFTW.plan_bfft(args...; kwargs...)
plan(::BFFT!, args...; kwargs...) = FFTW.plan_bfft!(args...; kwargs...)
plan(::FFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_fft(A, args...; kwargs...)
plan(::FFT!, A::AbstractArray, args...; kwargs...) = FFTW.plan_fft!(A, args...; kwargs...)
plan(::BFFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_bfft(A, args...; kwargs...)
plan(::BFFT!, A::AbstractArray, args...; kwargs...) = FFTW.plan_bfft!(A, args...; kwargs...)

binv(::FFT, d) = BFFT()
binv(::FFT!, d) = BFFT!()
Expand Down
4 changes: 2 additions & 2 deletions src/Transforms/r2c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ eltype_output(::BRFFT, ::Type{Complex{T}}) where {T <: FFTReal} = T
eltype_input(::RFFT, ::Type{T}) where {T <: FFTReal} = T
eltype_input(::BRFFT, ::Type{T}) where {T <: FFTReal} = Complex{T}

plan(::RFFT, args...; kwargs...) = FFTW.plan_rfft(args...; kwargs...)
plan(::RFFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_rfft(A, args...; kwargs...)

# NOTE: unlike most FFTW plans, this function also requires the length `d` of
# the transform output along the first transformed dimension.
function plan(tr::BRFFT, A, dims; kwargs...)
function plan(tr::BRFFT, A::AbstractArray, dims; kwargs...)
Nin = size(A, first(dims)) # input length along first dimension
d = length_output(tr, Nin)
FFTW.plan_brfft(A, d, dims; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/Transforms/r2r.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ eltype_output(::AnyR2R, ::Type{T}) where {T} = T
# `length(dims)` is known by the compiler. This will be the case if `dims` is a
# tuple or a scalar value (e.g. `(1, 3)` or `1`), but not if it is a range (e.g.
# `2:3`).
function plan(transform::AnyR2R, A, dims; kwargs...)
function plan(transform::AnyR2R, A::AbstractArray, dims; kwargs...)
kd = kind(transform)
K = ntuple(_ -> kd, length(dims))
R = FFTW.r2rFFTWPlan{T,K} where {T} # try to guess the return type
Expand Down
7 changes: 4 additions & 3 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function _apply_plans!(

if dir === Val(FFTW.BACKWARD)
# Scale transform.
ldiv!(scale_factor(full_plan), y)
y ./= scale_factor(full_plan)
end

y
Expand All @@ -165,7 +165,7 @@ function _apply_plans!(

if dir === Val(FFTW.BACKWARD)
# Scale transform.
ldiv!(scale_factor(full_plan), first(A))
first(A) ./= scale_factor(full_plan)
end

A
Expand Down Expand Up @@ -252,7 +252,8 @@ end
_make_pairs(::Tuple{}, ::Tuple{}) = ()

@inline function _temporary_pencil_array(
::Type{T}, p::Pencil, buf::Vector{UInt8}, extra_dims::Dims) where {T}
::Type{T}, p::Pencil, buf::DenseVector{UInt8}, extra_dims::Dims,
) where {T}
# Create "unsafe" pencil array wrapping buffer data.
dims = (size_local(p, MemoryOrder())..., extra_dims...)
nb = prod(dims) * sizeof(T)
Expand Down
40 changes: 31 additions & 9 deletions src/plans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ struct PencilFFTPlan{
G <: GlobalFFTParams,
P <: NTuple{Nt, PencilPlan1D},
TransposeMethod <: AbstractTransposeMethod,
Buffer <: DenseVector{UInt8},
} <: AbstractFFTs.Plan{T}

global_params :: G
Expand All @@ -227,9 +228,10 @@ struct PencilFFTPlan{
# `method` parameter passed to `transpose!`
transpose_method :: TransposeMethod

# TODO can I reuse the Pencil buffers (send_buf, recv_buf) to reduce allocations?
# Temporary data buffers.
ibuf :: Vector{UInt8}
obuf :: Vector{UInt8}
ibuf :: Buffer
obuf :: Buffer

# Runtime timing.
# Should be used along with the @timeit_debug macro, to be able to turn it
Expand All @@ -244,14 +246,15 @@ struct PencilFFTPlan{
transpose_method::AbstractTransposeMethod =
Transpositions.PointToPoint(),
timer::TimerOutput = TimerOutput(),
ibuf = UInt8[], obuf = UInt8[], # temporary data buffers
ibuf = _make_fft_buffer(A), obuf = _make_fft_buffer(A),
)
T = eltype(A)
dims_global = size_global(pencil(A), LogicalOrder())
pen = pencil(A)
dims_global = size_global(pen, LogicalOrder())
g = GlobalFFTParams(dims_global, transforms, real(T))
check_input_array(A, g)
inplace = is_inplace(g)
fftw_kw = (; flags = fftw_flags, timelimit = fftw_timelimit)
fftw_kw = _make_fft_kwargs(pen; flags = fftw_flags, timelimit = fftw_timelimit)

# Options for creation of 1D plans.
plans = _create_plans(
Expand All @@ -266,7 +269,10 @@ struct PencilFFTPlan{

# If the plan is in-place, the buffers won't be needed anymore, so we
# free the memory.
# TODO this assumes that buffers are not shared with the Pencil object!
if inplace
@assert all(x -> x !== ibuf, (pen.send_buf, pen.recv_buf))
@assert all(x -> x !== obuf, (pen.send_buf, pen.recv_buf))
resize!.((ibuf, obuf), 0)
end

Expand All @@ -279,15 +285,18 @@ struct PencilFFTPlan{
TM = typeof(transpose_method)
t = topology(A)
Nd = ndims(t)
Buffer = typeof(ibuf)

new{T, N, inplace, Nt, Nd, Ne, G, P, TM}(
g, t, edims, plans, scale, transpose_method, ibuf, obuf, timer)
new{T, N, inplace, Nt, Nd, Ne, G, P, TM, Buffer}(
g, t, edims, plans, scale, transpose_method, ibuf, obuf, timer,
)
end
end

function PencilFFTPlan(
pen::Pencil{Nt}, transforms::AbstractTransformList{Nt}, ::Type{Tr} = Float64;
extra_dims::Dims = (), timer = TimerOutput(), ibuf = UInt8[], kws...,
extra_dims::Dims = (), timer = TimerOutput(), ibuf = _make_fft_buffer(pen),
kws...,
) where {Nt, Tr <: FFTReal}
T = _input_data_type(Tr, transforms...)
A = _temporary_pencil_array(T, pen, ibuf, extra_dims)
Expand All @@ -310,6 +319,18 @@ function PencilFFTPlan(A, transform::AbstractTransform, args...; kws...)
PencilFFTPlan(A, transforms, args...; kws...)
end

_make_fft_buffer(p::Pencil) = similar(p.send_buf, UInt8, 0) :: DenseVector{UInt8}
_make_fft_buffer(A::PencilArray) = _make_fft_buffer(pencil(A))

# We decide on passing FFTW flags or not depending on the type of underlying array.
# In particular, note that CUFFT doesn't support keyword arguments (such as
# FFTW.MEASURE), and therefore we silently suppress them.
# TODO
# - use a more generic way of differentiating between CPU and GPU arrays
_make_fft_kwargs(p::Pencil; kws...) = _make_fft_kwargs(p.send_buf; kws...)
_make_fft_kwargs(::Array; kws...) = kws # CPU arrays
_make_fft_kwargs(::AbstractArray; kws...) = (;) # GPU arrays: suppress keyword arguments

@inline _ndims_transformable(dims::Dims) = length(dims)
@inline _ndims_transformable(p::Pencil) = ndims(p)
@inline _ndims_transformable(A::PencilArray) = _ndims_transformable(pencil(A))
Expand Down Expand Up @@ -501,7 +522,8 @@ end

function _make_1d_fft_plan(
dim::Val{n}, ::Type{Ti}, A_fw::PencilArray, A_bw::PencilArray,
transform_fw::AbstractTransform; fftw_kw) where {n, Ti}
transform_fw::AbstractTransform; fftw_kw,
) where {n, Ti}
Pi = pencil(A_fw)
Po = pencil(A_bw)
perm = permutation(Pi)
Expand Down

2 comments on commit cc8c381

@jipolanco
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/62783

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.0 -m "<description of version>" cc8c3813f771cb6b0e385544cd2ede1e1a756588
git push origin v0.14.0

Please sign in to comment.