Skip to content

Commit

Permalink
Merge pull request #308 from yuehhua/cartesian
Browse files Browse the repository at this point in the history
scatter and gather support element type of idx to be CartesianIndex
  • Loading branch information
CarloLucibello authored Apr 21, 2021
2 parents a04e916 + e499d3a commit b1633f5
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 59 deletions.
19 changes: 5 additions & 14 deletions src/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,11 @@ or multiple `dst` columns.
See [`gather`](@ref) for an allocating version.
"""
function gather!(dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tdst, Tsrc, Ndst, Nsrc, Nidx, Tidx <: IntOrIntTuple}

M = typelength(Tidx)
d = Ndst - Nidx
d == Nsrc - M || throw(ArgumentError("Incompatible input shapes."))
size(dst)[1:d] == size(src)[1:d] || throw(ArgumentError("Incompatible input shapes."))
size(dst)[d+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))

colons = ntuple(i -> Colon(), d)
function gather!(dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dims = _check_dims(src, dst, idx)
colons = ntuple(i -> Colon(), dims)
for k in CartesianIndices(idx)
view(dst, colons..., k) .= view(src, colons..., idx[k]...)
_view(dst, colons, k) .= _view(src, colons, idx[k])
end
return dst
end
Expand Down Expand Up @@ -64,7 +55,7 @@ See [`gather!`](@ref) for an in-place version.
"""
function gather(src::AbstractArray{Tsrc, Nsrc},
idx::AbstractArray{Tidx, Nidx}) where
{Tsrc, Nsrc, Nidx, Tidx<:IntOrIntTuple}
{Tsrc, Nsrc, Nidx, Tidx}

M = typelength(Tidx)
dstsize = (size(src)[1:Nsrc-M]..., size(idx)...)
Expand Down
100 changes: 55 additions & 45 deletions src/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,39 @@
# - ∇scatter_src!
#

function _check_dims(Ndst, Nsrc, N, Nidx)
@assert Ndst - N == Nsrc - Nidx "Incompatible input shapes of (dst, src, idx) = ($Ndst, $Nsrc, $Nidx)."
dims = Ndst - N
if dims < 0
throw(ArgumentError("dims must be non-negative but got dims=$dims."))
end
typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M
typelength(::Type{CartesianIndex{M}}) where M = M

function _check_dims(X::AbstractArray{Tx,Nx},
Y::AbstractArray{Ty,Ny},
idx::AbstractArray{Tidx,Nidx}) where
{Tx,Ty,Tidx<:IntOrIntTuple,Nx,Ny,Nidx}
M = typelength(Tidx)
dims = _check_dims(Nx, Ny, M, Nidx)
size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
return dims
end

typelength(::Type{<:Number}) = 1
typelength(::Type{<:NTuple{M}}) where M = M
function _check_dims(X::AbstractArray{Tx,Nx},
Y::AbstractArray{Ty,Ny},
idx::AbstractArray{CartesianIndex{M},Nidx}) where {Tx,Ty,Nx,Ny,M,Nidx}
dims = _check_dims(Nx, Ny, M, Nidx)
size(X)[1:dims] == size(Y)[1:dims] || throw(ArgumentError("Incompatible input shapes."))
size(Y)[dims+1:end] == size(idx) || throw(ArgumentError("Incompatible input shapes."))
return dims
end

function _check_dims(Nx, Ny, M, Nidx)
@assert Nx - M == Ny - Nidx "Incompatible input shapes of (dst, src, idx) = ($Nx, $Ny, $Nidx)."
dims = Nx - M
dims < 0 && throw(ArgumentError("dims must be non-negative but got dims=$dims."))
return dims
end

_view(X, colons, k) = view(X, colons..., k...)
_view(X, colons, k::Union{Integer, CartesianIndex}) = view(X, colons..., k)

"""
scatter!(op, dst, src, idx)
Expand All @@ -42,30 +64,18 @@ index of `dst` and the value of `idx` must indicate the last few dimensions of `
Once the dimensions match, arrays are aligned automatically. The value of `idx` can be
`Int` or `Tuple` type.
"""
function scatter!(op,
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tdst,Tsrc,Tidx<:IntOrIntTuple,Ndst,Nsrc,Nidx}
M = typelength(Tidx)
dims = _check_dims(Ndst, Nsrc, M, Nidx)
scatter!(op, dst, src, idx, Val(dims))
end

function scatter!(op, dst::AbstractArray{Tdst}, src::AbstractArray{Tsrc}, idx::AbstractArray{<:IntOrIntTuple},
dims::Val{N}) where {Tdst,Tsrc,N}
function scatter!(op, dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
dims = _check_dims(dst, src, idx)
colons = Base.ntuple(_->Colon(), dims)
for k in CartesianIndices(idx)
dst_v = view(dst, colons..., idx[k]...)
src_v = view(src, colons..., k)
dst_v = _view(dst, colons, idx[k])
src_v = _view(src, colons, k)
dst_v .= (op).(dst_v, src_v)
end
dst
end

function scatter!(op::typeof(mean),
dst::AbstractArray{Tdst,Ndst},
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {Tdst,Tsrc,Ndst,Nsrc,Nidx}
function scatter!(op::typeof(mean), dst::AbstractArray, src::AbstractArray, idx::AbstractArray)
Ns = scatter!(+, zero(dst), one.(src), idx)
dst_ = scatter!(+, zero(dst), src, idx)
dst .+= safe_div.(dst_, Ns)
Expand Down Expand Up @@ -93,55 +103,55 @@ function scatter end

for op in [+, -]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, Base.reduce_empty(+, T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, Base.reduce_empty(+, Tsrc))
scatter!(op, dst, src, idx)
end
end

for op in [*, /]
@eval function scatter(op::typeof($op),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, Base.reduce_empty(*, T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, Base.reduce_empty(*, Tsrc))
scatter!(op, dst, src, idx)
end
end

function scatter(op::typeof(max),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, typemin(T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, typemin(Tsrc))
scatter!(op, dst, src, idx)
end

function scatter(op::typeof(min),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
fill!(dst, typemax(T))
dst = similar(src, Tsrc, dstsize)
fill!(dst, typemax(Tsrc))
scatter!(op, dst, src, idx)
end

function scatter(op::typeof(mean),
src::AbstractArray{T,Nsrc},
idx::AbstractArray{<:IntOrIntTuple,Nidx}) where {T,Nsrc,Nidx}
FT = float(T)
src::AbstractArray{Tsrc,Nsrc},
idx::AbstractArray{Tidx,Nidx}) where {Tsrc,Tidx,Nsrc,Nidx}
FT = float(Tsrc)
dims = Nsrc - Nidx
dstsize = (size(src)[1:dims]..., maximum_dims(idx)...)
dst = similar(src, T, dstsize)
dst = similar(src, Tsrc, dstsize)
fill!(dst, Base.reduce_empty(+, FT))
scatter!(op, dst, src, idx)
end
1 change: 1 addition & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ The maximum of each dimension in the element is computed.
"""
maximum_dims(dims::AbstractArray{<:Integer}) = (maximum(dims), )
maximum_dims(dims::AbstractArray{NTuple{N, T}}) where {N,T} = ntuple(i -> maximum(x->x[i], dims), N)
maximum_dims(dims::AbstractArray{CartesianIndex{N}}) where {N} = ntuple(i -> maximum(x->x[i], dims), N)
30 changes: 30 additions & 0 deletions test/gather.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,33 @@ end
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end

@testset "gather cartesian index" begin
T = Float32

## 2d src, 1d index of 2-tuples -> 1d output
src = T[3 5 7
4 6 8]

index = CartesianIndex.([(1,1), (1,2), (1,3), (2,1), (2,2), (2,3)])

output = T[3, 5, 7, 4, 6, 8]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,1}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
@test y == output

## 3d src, 2d index of 2-tuples -> 3d output
n1, nsrc, nidx = 2, 3, 6
src = rand(Float32, n1, nsrc, nsrc)
index = [CartesianIndex((rand(1:nsrc), rand(1:nsrc))) for i=1:nidx, j=1:nidx]

y = gather(src, index)
M = NNlib.typelength(eltype(index))
Nsrc = ndims(src)
@test y isa Array{T,3}
@test size(y) == (size(src)[1:Nsrc-M]..., size(index)...)
end
4 changes: 4 additions & 0 deletions test/scatter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ idxs = Dict(
:tup => [(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)],
:car => CartesianIndex.(
[(1,) (2,) (3,) (4,);
(4,) (2,) (1,) (3,);
(3,) (5,) (5,) (3,)]),
)
res = Dict(
(+, 0, true) => [5, 6, 9, 8, 9],
Expand Down
4 changes: 4 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,8 @@
ind3 = [(3,4,5) (1,2,3) (2,3,9);
(4,6,2) (5,3,2) (4,4,4)]
@test NNlib.maximum_dims(ind3) == (5,6,9)
ind4 = CartesianIndex.(
[(3,4,5) (1,2,3) (2,3,9);
(4,6,2) (5,3,2) (4,4,4)])
@test NNlib.maximum_dims(ind4) == (5,6,9)
end

2 comments on commit b1633f5

@DhairyaLGandhi
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/35128

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.19 -m "<description of version>" b1633f5f534ad1fd431e82331e2a5aa1c337edb1
git push origin v0.7.19

Please sign in to comment.