Skip to content

Commit

Permalink
Merge pull request #342 from EnzymeAD/ap/upsampling
Browse files Browse the repository at this point in the history
* fix: manually zero out the lower triangular and upper triangular values

* fix: only do it in tests

* revert: change in Ops.cholesky

* revert: remove unnecessary changes

* fix: preserve parent array tracking for reshape

* test: writing to a reshaped array

* test: upsample_nearest

* fix: test failures due to wrappers

* fix: handle lazy transpose/adjoint correctly

* fix: handle wrappers in NNlibExt correctly

* fix: more reshaped wrappers handling

* fix: dispatches to avoid ambiguity

* fix: handle diagonal wrapper gracefully

* fix: compile wrapped concrete array conversion to arrays

* feat: more wrapped ConcreteRArray handling

* chore: apply suggestions from code review

* refactor: rearrange the tests

* test: add test that fails on incorrect reshape dims ordering
  • Loading branch information
avik-pal authored Dec 11, 2024
2 parents 816e789 + 107040d commit 814e9c0
Show file tree
Hide file tree
Showing 12 changed files with 388 additions and 241 deletions.
4 changes: 2 additions & 2 deletions ext/ReactantArrayInterfaceExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ReactantArrayInterfaceExt

using ArrayInterface: ArrayInterface
using Reactant:
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray
Reactant, RArray, ConcreteRArray, ConcreteRNumber, TracedRNumber, TracedRArray, Ops

ArrayInterface.can_setindex(::Type{<:RArray}) = false
ArrayInterface.fast_scalar_indexing(::Type{<:RArray}) = false
Expand All @@ -14,7 +14,7 @@ function ArrayInterface.aos_to_soa(x::AbstractArray{<:ConcreteRNumber{T}}) where
end

function ArrayInterface.aos_to_soa(x::AbstractArray{<:TracedRNumber{T}}) where {T}
return reshape(vcat(x...), size(x))
return Ops.reshape(vcat(x...), size(x)...)
end

end
145 changes: 62 additions & 83 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@ module ReactantNNlibExt
using NNlib
using GPUArraysCore: @allowscalar
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
Reactant,
Ops,
TracedRArray,
AnyTracedRArray,
materialize_traced_array,
MLIR,
TracedRNumber,
get_mlir_data,
set_mlir_data!
using ReactantCore: @trace
using LinearAlgebra: LinearAlgebra, triu

Expand All @@ -12,14 +20,7 @@ for (jlop, hloop) in (
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
return TracedRNumber{T}(
(),
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
)
end
@eval $(jlop)(x::TracedRNumber) = Ops.$(hloop)(x)
end

function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
Expand Down Expand Up @@ -82,13 +83,6 @@ function NNlib.conv!(
kernel_input_dim = N - 1
kernel_output_dim = N

output_spatial_shapes = map(input_spatial_dims) do i
K = kernel_size[i]
pl, pr = padding[2i - 1], padding[2i]
d = dilation[i]
s = stride[i]
return (size(x, i) + pl + pr - d * (K - 1) - 1) ÷ s + 1
end
output_batch_dim = input_batch_dim
output_feature_dim = input_feature_dim
output_spatial_dims = input_spatial_dims
Expand Down Expand Up @@ -119,8 +113,8 @@ function NNlib.conv!(
end

conv = Reactant.MLIR.Dialects.stablehlo.convolution(
x.mlir_data,
weight.mlir_data;
get_mlir_data(x),
get_mlir_data(weight);
result_0=result_type,
window_strides=collect(stride),
padding,
Expand All @@ -130,7 +124,7 @@ function NNlib.conv!(
feature_group_count,
batch_group_count=1,
)
y.mlir_data = Reactant.MLIR.IR.result(conv)
set_mlir_data!(y, Reactant.MLIR.IR.result(conv))
return y
end

Expand Down Expand Up @@ -165,7 +159,9 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
output_shape = (output_spatial_shapes..., size(x, N - 1), size(x, N))
result_type = Reactant.MLIR.IR.TensorType(output_shape, Reactant.MLIR.IR.Type(T))

unranked = Reactant.MLIR.IR.TensorType((), eltype(Reactant.MLIR.IR.type(x.mlir_data)))
unranked = Reactant.MLIR.IR.TensorType(
(), eltype(Reactant.MLIR.IR.type(get_mlir_data(x)))
)
body =
let body = Reactant.MLIR.IR.Region(),
loc = Reactant.MLIR.IR.Location(),
Expand All @@ -189,7 +185,7 @@ function reduce_window(f, x::AnyTracedRArray{T,N}, pdims; init) where {T,N}
Reactant.MLIR.Dialects.stablehlo.constant(; value=attr)
)
reduction = Reactant.MLIR.Dialects.stablehlo.reduce_window(
[x.mlir_data],
[get_mlir_data(x)],
[init_value];
result_0=[result_type],
window_dimensions,
Expand All @@ -205,24 +201,24 @@ end
function NNlib.maxpool!(
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
y.mlir_data =
reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
).mlir_data
res = reduce_window(
Reactant.MLIR.Dialects.stablehlo.maximum, T.(x), pdims; init=typemin(T)
)
set_mlir_data!(y, get_mlir_data(res))
return y
end

function NNlib.meanpool!(
y::TracedRArray{T}, x::AnyTracedRArray, pdims::NNlib.PoolDims
) where {T}
res = reduce_window(Reactant.MLIR.Dialects.stablehlo.add, T.(x), pdims; init=zero(T))
y.mlir_data = (res ./ T(prod(NNlib.kernel_size(pdims)))).mlir_data
set_mlir_data!(y, get_mlir_data(res ./ T(prod(NNlib.kernel_size(pdims)))))
return y
end

NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = permutedims(x, (2, 1, 3))
NNlib.batched_transpose(x::AnyTracedRArray{T,3}) where {T} = PermutedDimsArray(x, (2, 1, 3))
function NNlib.batched_adjoint(x::AnyTracedRArray{T,3}) where {T}
y = permutedims(x, (2, 1, 3))
y = NNlib.batched_transpose(x)
conj!(y)
return y
end
Expand All @@ -238,64 +234,47 @@ function NNlib.batched_mul!(
),
)
end

if size(x, 3) != size(y, 3)
B = max(size(x, 3), size(y, 3))
if size(x, 3) == 1
x = Reactant.broadcast_to_size(x, (size(x, 1), size(x, 2), B))
elseif size(y, 3) == 1
y = Reactant.broadcast_to_size(y, (size(y, 1), size(y, 2), B))
end
end

x = permutedims(x, (3, 1, 2))
y = permutedims(y, (3, 1, 2))

B = max(size(x, 1), size(y, 1))
out_shape = (B, size(x, 2), size(y, 3))
resty = MLIR.IR.TensorType(out_shape, eltype(MLIR.IR.type(res.mlir_data)))

if size(x, 1) != size(y, 1)
B = max(size(x, 1), size(y, 1))
if size(x, 1) == 1
x = Reactant.broadcast_to_size(x, (B, size(x, 2), size(x, 3)))
elseif size(y, 1) == 1
y = Reactant.broadcast_to_size(y, (B, size(y, 2), size(y, 3)))
end
end

dot_dimension_numbers = MLIR.API.stablehloDotDimensionNumbersGet(
MLIR.IR.context(), 1, [0], 1, [0], 1, [2], 1, [1]
tmp = Ops.dot_general(
T1.(materialize_traced_array(x)),
T1.(materialize_traced_array(y));
contracting_dimensions=([3], [2]),
batching_dimensions=([1], [1]),
)
set_mlir_data!(res, get_mlir_data(permutedims(tmp, (2, 3, 1))))

prec = MLIR.IR.Attribute(
MLIR.API.stablehloPrecisionAttrGet(MLIR.IR.context(), "DEFAULT")
)
tmp = TracedRArray{T1,3}(
(),
MLIR.IR.result(
MLIR.Dialects.stablehlo.dot_general(
x.mlir_data,
y.mlir_data;
result_0=resty,
dot_dimension_numbers=dot_dimension_numbers,
precision_config=prec,
),
1,
),
size(resty),
)
res.mlir_data = permutedims(tmp, (2, 3, 1)).mlir_data
return res
end

function NNlib.pad_constant(
x::TracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
x::AnyTracedRArray{T,N}, pad::NTuple{N,Tuple{Int,Int}}, value
) where {T,N}
value = Reactant.promote_to(TracedRNumber{T}, value)
edge_padding_low = [i[1] for i in pad]
edge_padding_high = [i[2] for i in pad]
interior_padding = [0 for i in pad]
res = MLIR.IR.result(
MLIR.Dialects.stablehlo.pad(
x.mlir_data,
value.mlir_data;
edge_padding_low,
edge_padding_high,
interior_padding,
),
1,
)
return TracedRArray{T,N}((), res, size(MLIR.IR.type(res)))
low = [i[1] for i in pad]
high = [i[2] for i in pad]
interior = [0 for i in pad]
return Ops.pad(materialize_traced_array(x), value; low, high, interior)
end

# XXX: reevaluate this manual optimization once
Expand All @@ -305,7 +284,7 @@ function NNlib.gather!(
src::AnyTracedRArray{T2,2},
idxs::Union{AbstractUnitRange{<:Number}},
) where {T1,T2}
dst.mlir_data = src[:, idxs].mlir_data
set_mlir_data!(dst, get_mlir_data(src[:, idxs]))
return dst
end

Expand All @@ -314,8 +293,8 @@ function NNlib.gather!(
) where {T1,T2}
dims = NNlib.scatter_dims(src, dst, idxs)
@assert dims == 1 # scatter_dims lets us do some size checks so we call that function
idxs = (Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1).mlir_data
slice_sizes = Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]).mlir_data
idxs = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, idxs) .- 1)
slice_sizes = get_mlir_data(Reactant.promote_to(TracedRArray{Int,1}, [size(src, 1), 1]))

#! format: off
dimension_numbers = MLIR.API.stablehloGatherDimensionNumbersGet(
Expand All @@ -331,11 +310,11 @@ function NNlib.gather!(

res = MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.dynamic_gather(
src.mlir_data, idxs, slice_sizes; dimension_numbers
get_mlir_data(src), idxs, slice_sizes; dimension_numbers
),
1,
)
dst.mlir_data = res
set_mlir_data!(dst, res)
return dst
end

Expand All @@ -354,7 +333,7 @@ function NNlib.gather!(dst::TracedRArray, src::AnyTracedRArray, idxs::AbstractAr
return reshape(res, start_sizes..., :)
end
res = reshape(cat(results...; dims=(dims + 1)), size(dst))
dst.mlir_data = res.mlir_data
set_mlir_data!(dst, get_mlir_data(res))
return dst
end

Expand All @@ -363,7 +342,7 @@ dilate_shape(s, d) = max(0, 1 + d * (s - 1))
# see lax._conv_general_dilated_transpose_rhs
# https://github.com/jax-ml/jax/blob/a1dfdc1d6164ad49afb337da9effd269d430d68b/jax/_src/lax/convolution.py#L495
function NNlib.∇conv_filter!(
dw::Reactant.TracedRArray{T,N},
dw::TracedRArray{T,N},
x::AnyTracedRArray,
dy::AnyTracedRArray,
cdims::NNlib.DenseConvDims,
Expand Down Expand Up @@ -437,8 +416,8 @@ function NNlib.∇conv_filter!(

result_type = Reactant.MLIR.IR.TensorType(size(dw), Reactant.MLIR.IR.Type(T))
conv = MLIR.Dialects.stablehlo.convolution(
x.mlir_data,
dy.mlir_data;
get_mlir_data(x),
get_mlir_data(dy);
result_0=result_type,
window_strides=collect(dilation),
padding,
Expand All @@ -447,11 +426,12 @@ function NNlib.∇conv_filter!(
feature_group_count,
batch_group_count,
)

dw.mlir_data = MLIR.IR.result(conv)
set_mlir_data!(dw, MLIR.IR.result(conv))

if !NNlib.flipkernel(cdims)
dw.mlir_data = Reactant.Ops.reverse(dw; dimensions=output_spatial_dims).mlir_data
set_mlir_data!(
dw, get_mlir_data(Reactant.Ops.reverse(dw; dimensions=output_spatial_dims))
)
end

return dw
Expand Down Expand Up @@ -553,8 +533,8 @@ function NNlib.∇conv_data!(
end

conv = MLIR.Dialects.stablehlo.convolution(
dy.mlir_data,
w.mlir_data;
get_mlir_data(dy),
get_mlir_data(w);
result_0=result_type,
window_strides=1,
padding,
Expand All @@ -564,8 +544,7 @@ function NNlib.∇conv_data!(
feature_group_count,
batch_group_count=1,
)

dx.mlir_data = MLIR.IR.result(conv)
set_mlir_data!(dx, MLIR.IR.result(conv))

return dx
end
Expand Down
Loading

0 comments on commit 814e9c0

Please sign in to comment.