Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make transforms work on CUDA arrays #48

Merged
merged 7 commits into from
Jun 21, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "PencilFFTs"
uuid = "4a48f351-57a6-4416-9ec4-c37015456aae"
authors = ["Juan Ignacio Polanco <[email protected]>"]
version = "0.13.6"
version = "0.14.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand All @@ -16,7 +16,7 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
AbstractFFTs = "1"
FFTW = "1"
MPI = "0.19"
PencilArrays = "0.16, 0.17"
PencilArrays = "0.17"
Reexport = "1"
TimerOutputs = "0.5"
julia = "1.6"
7 changes: 4 additions & 3 deletions src/Transforms/Transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ and [`FFTW.jl`](https://juliamath.github.io/FFTW.jl/stable/fft/).
"""
module Transforms

using AbstractFFTs
using FFTW

# Operations defined for custom plans (currently IdentityPlan).
Expand Down Expand Up @@ -39,19 +40,19 @@ The only custom plan defined in this module is [`IdentityPlan`](@ref).
The user can define other custom plans that are also subtypes of
`AbstractCustomPlan`.

Note that [`plan`](@ref) returns a subtype of either `FFTW.FFTWPlan` or
Note that [`plan`](@ref) returns a subtype of either `AbstractFFTs.Plan` or
`AbstractCustomPlan`.
"""
abstract type AbstractCustomPlan end

"""
Plan = Union{FFTW.FFTWPlan, AbstractCustomPlan}
Plan = Union{AbstractFFTs.Plan, AbstractCustomPlan}

Union type representing any plan returned by [`plan`](@ref).

See also [`AbstractCustomPlan`](@ref).
"""
const Plan = Union{FFTW.FFTWPlan, AbstractCustomPlan}
const Plan = Union{AbstractFFTs.Plan, AbstractCustomPlan}

"""
plan(transform::AbstractTransform, A, [dims];
Expand Down
8 changes: 4 additions & 4 deletions src/Transforms/c2c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,10 @@ length_output(::TransformC2C, length_in::Integer) = length_in
eltype_output(::TransformC2C, ::Type{Complex{T}}) where {T <: FFTReal} = Complex{T}
eltype_input(::TransformC2C, ::Type{T}) where {T <: FFTReal} = Complex{T}

plan(::FFT, args...; kwargs...) = FFTW.plan_fft(args...; kwargs...)
plan(::FFT!, args...; kwargs...) = FFTW.plan_fft!(args...; kwargs...)
plan(::BFFT, args...; kwargs...) = FFTW.plan_bfft(args...; kwargs...)
plan(::BFFT!, args...; kwargs...) = FFTW.plan_bfft!(args...; kwargs...)
plan(::FFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_fft(A, args...; kwargs...)
plan(::FFT!, A::AbstractArray, args...; kwargs...) = FFTW.plan_fft!(A, args...; kwargs...)
plan(::BFFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_bfft(A, args...; kwargs...)
plan(::BFFT!, A::AbstractArray, args...; kwargs...) = FFTW.plan_bfft!(A, args...; kwargs...)

binv(::FFT, d) = BFFT()
binv(::FFT!, d) = BFFT!()
Expand Down
4 changes: 2 additions & 2 deletions src/Transforms/r2c.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,11 @@ eltype_output(::BRFFT, ::Type{Complex{T}}) where {T <: FFTReal} = T
eltype_input(::RFFT, ::Type{T}) where {T <: FFTReal} = T
eltype_input(::BRFFT, ::Type{T}) where {T <: FFTReal} = Complex{T}

plan(::RFFT, args...; kwargs...) = FFTW.plan_rfft(args...; kwargs...)
plan(::RFFT, A::AbstractArray, args...; kwargs...) = FFTW.plan_rfft(A, args...; kwargs...)

# NOTE: unlike most FFTW plans, this function also requires the length `d` of
# the transform output along the first transformed dimension.
function plan(tr::BRFFT, A, dims; kwargs...)
function plan(tr::BRFFT, A::AbstractArray, dims; kwargs...)
Nin = size(A, first(dims)) # input length along first dimension
d = length_output(tr, Nin)
FFTW.plan_brfft(A, d, dims; kwargs...)
Expand Down
2 changes: 1 addition & 1 deletion src/Transforms/r2r.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ eltype_output(::AnyR2R, ::Type{T}) where {T} = T
# `length(dims)` is known by the compiler. This will be the case if `dims` is a
# tuple or a scalar value (e.g. `(1, 3)` or `1`), but not if it is a range (e.g.
# `2:3`).
function plan(transform::AnyR2R, A, dims; kwargs...)
function plan(transform::AnyR2R, A::AbstractArray, dims; kwargs...)
kd = kind(transform)
K = ntuple(_ -> kd, length(dims))
R = FFTW.r2rFFTWPlan{T,K} where {T} # try to guess the return type
Expand Down
7 changes: 4 additions & 3 deletions src/operations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ function _apply_plans!(

if dir === Val(FFTW.BACKWARD)
# Scale transform.
ldiv!(scale_factor(full_plan), y)
y ./= scale_factor(full_plan)
end

y
Expand All @@ -165,7 +165,7 @@ function _apply_plans!(

if dir === Val(FFTW.BACKWARD)
# Scale transform.
ldiv!(scale_factor(full_plan), first(A))
first(A) ./= scale_factor(full_plan)
end

A
Expand Down Expand Up @@ -252,7 +252,8 @@ end
_make_pairs(::Tuple{}, ::Tuple{}) = ()

@inline function _temporary_pencil_array(
::Type{T}, p::Pencil, buf::Vector{UInt8}, extra_dims::Dims) where {T}
::Type{T}, p::Pencil, buf::DenseVector{UInt8}, extra_dims::Dims,
) where {T}
# Create "unsafe" pencil array wrapping buffer data.
dims = (size_local(p, MemoryOrder())..., extra_dims...)
nb = prod(dims) * sizeof(T)
Expand Down
40 changes: 31 additions & 9 deletions src/plans.jl
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ struct PencilFFTPlan{
G <: GlobalFFTParams,
P <: NTuple{Nt, PencilPlan1D},
TransposeMethod <: AbstractTransposeMethod,
Buffer <: DenseVector{UInt8},
} <: AbstractFFTs.Plan{T}

global_params :: G
Expand All @@ -227,9 +228,10 @@ struct PencilFFTPlan{
# `method` parameter passed to `transpose!`
transpose_method :: TransposeMethod

# TODO can I reuse the Pencil buffers (send_buf, recv_buf) to reduce allocations?
# Temporary data buffers.
ibuf :: Vector{UInt8}
obuf :: Vector{UInt8}
ibuf :: Buffer
obuf :: Buffer

# Runtime timing.
# Should be used along with the @timeit_debug macro, to be able to turn it
Expand All @@ -244,14 +246,15 @@ struct PencilFFTPlan{
transpose_method::AbstractTransposeMethod =
Transpositions.PointToPoint(),
timer::TimerOutput = TimerOutput(),
ibuf = UInt8[], obuf = UInt8[], # temporary data buffers
ibuf = _make_fft_buffer(A), obuf = _make_fft_buffer(A),
)
T = eltype(A)
dims_global = size_global(pencil(A), LogicalOrder())
pen = pencil(A)
dims_global = size_global(pen, LogicalOrder())
g = GlobalFFTParams(dims_global, transforms, real(T))
check_input_array(A, g)
inplace = is_inplace(g)
fftw_kw = (; flags = fftw_flags, timelimit = fftw_timelimit)
fftw_kw = _make_fft_kwargs(pen; flags = fftw_flags, timelimit = fftw_timelimit)

# Options for creation of 1D plans.
plans = _create_plans(
Expand All @@ -266,7 +269,10 @@ struct PencilFFTPlan{

# If the plan is in-place, the buffers won't be needed anymore, so we
# free the memory.
# TODO this assumes that buffers are not shared with the Pencil object!
if inplace
@assert all(x -> x !== ibuf, (pen.send_buf, pen.recv_buf))
@assert all(x -> x !== obuf, (pen.send_buf, pen.recv_buf))
resize!.((ibuf, obuf), 0)
end

Expand All @@ -279,15 +285,18 @@ struct PencilFFTPlan{
TM = typeof(transpose_method)
t = topology(A)
Nd = ndims(t)
Buffer = typeof(ibuf)

new{T, N, inplace, Nt, Nd, Ne, G, P, TM}(
g, t, edims, plans, scale, transpose_method, ibuf, obuf, timer)
new{T, N, inplace, Nt, Nd, Ne, G, P, TM, Buffer}(
g, t, edims, plans, scale, transpose_method, ibuf, obuf, timer,
)
end
end

function PencilFFTPlan(
pen::Pencil{Nt}, transforms::AbstractTransformList{Nt}, ::Type{Tr} = Float64;
extra_dims::Dims = (), timer = TimerOutput(), ibuf = UInt8[], kws...,
extra_dims::Dims = (), timer = TimerOutput(), ibuf = _make_fft_buffer(pen),
kws...,
) where {Nt, Tr <: FFTReal}
T = _input_data_type(Tr, transforms...)
A = _temporary_pencil_array(T, pen, ibuf, extra_dims)
Expand All @@ -310,6 +319,18 @@ function PencilFFTPlan(A, transform::AbstractTransform, args...; kws...)
PencilFFTPlan(A, transforms, args...; kws...)
end

# Build an empty (length 0) `UInt8` buffer of the same array family as the
# pencil's `send_buf`, so that temporary FFT buffers follow the pencil's
# underlying storage type. The type assertion checks that `similar` returned a
# dense byte vector.
function _make_fft_buffer(p::Pencil)
    buf = similar(p.send_buf, UInt8, 0)
    return buf::DenseVector{UInt8}
end

# For a `PencilArray`, defer to the method taking its associated pencil.
function _make_fft_buffer(A::PencilArray)
    return _make_fft_buffer(pencil(A))
end

# Whether FFTW options are forwarded to the 1D plans depends on the type of the
# underlying array. In particular, CUFFT doesn't support the FFTW planning
# keyword arguments (e.g. `flags = FFTW.MEASURE`), so for non-CPU arrays they
# are silently dropped.
# TODO: use a more generic way of differentiating between CPU and GPU arrays.

# For a `Pencil`, the decision is made from the type of its `send_buf` field.
function _make_fft_kwargs(p::Pencil; kws...)
    return _make_fft_kwargs(p.send_buf; kws...)
end
# CPU arrays (`Array`): the keyword arguments are returned untouched.
function _make_fft_kwargs(::Array; kws...)
    return kws
end

# Any other array type (GPU arrays): suppress keyword arguments by returning an
# empty named tuple.
function _make_fft_kwargs(::AbstractArray; kws...)
    return NamedTuple()
end

# For a plain tuple of dimensions, the count is simply the number of entries.
@inline function _ndims_transformable(dims::Dims)
    return length(dims)
end
# For a `Pencil`, the count is `ndims` of the pencil itself.
@inline function _ndims_transformable(p::Pencil)
    return ndims(p)
end

# For a `PencilArray`, use the pencil it is attached to.
@inline function _ndims_transformable(A::PencilArray)
    return _ndims_transformable(pencil(A))
end
Expand Down Expand Up @@ -501,7 +522,8 @@ end

function _make_1d_fft_plan(
dim::Val{n}, ::Type{Ti}, A_fw::PencilArray, A_bw::PencilArray,
transform_fw::AbstractTransform; fftw_kw) where {n, Ti}
transform_fw::AbstractTransform; fftw_kw,
) where {n, Ti}
Pi = pencil(A_fw)
Po = pencil(A_bw)
perm = permutation(Pi)
Expand Down