From 8b90501bdc8aaac863b863886792e83b00bdcf21 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Thu, 12 Dec 2024 10:11:00 +0530 Subject: [PATCH] fix: ensure printing of wrapped ConcreteRArrays goes through our show (#367) * fix: ensure printing of wrapped ConcreteRArrays goes through our show * fix: allow wrapped arrays in mapreduce --- src/ConcreteRArray.jl | 10 ++++++---- src/Reactant.jl | 6 ++++++ src/TracedRArray.jl | 7 +++---- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/ConcreteRArray.jl b/src/ConcreteRArray.jl index ceb08440..dac67bf6 100644 --- a/src/ConcreteRArray.jl +++ b/src/ConcreteRArray.jl @@ -195,16 +195,18 @@ function Base.show(io::IO, X::ConcreteRScalar{T}) where {T} return nothing end -function Base.print_array(io::IO, X::ConcreteRArray) - if X.data == XLA.AsyncEmptyBuffer +function Base.print_array(io::IO, X::AnyConcreteRArray) + data = ancestor(X).data + if data == XLA.AsyncEmptyBuffer println(io, "") return nothing end return Base.print_array(io, convert(Array, X)) end -function Base.show(io::IO, X::ConcreteRArray) - if X.data == XLA.AsyncEmptyBuffer +function Base.show(io::IO, X::AnyConcreteRArray) + data = ancestor(X).data + if data == XLA.AsyncEmptyBuffer println(io, "") return nothing end diff --git a/src/Reactant.jl b/src/Reactant.jl index 06fd59af..0fc900b2 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -85,6 +85,12 @@ function Enzyme.make_zero( return res end +function ancestor(x::AbstractArray) + p_x = parent(x) + p_x === x && return x + return ancestor(p_x) +end + include("mlir/MLIR.jl") include("XLA.jl") include("Interpreter.jl") diff --git a/src/TracedRArray.jl b/src/TracedRArray.jl index 0d73b109..5bc3ee30 100644 --- a/src/TracedRArray.jl +++ b/src/TracedRArray.jl @@ -122,9 +122,6 @@ function set_mlir_data!(x::AnyTracedRArray, data) return x end -ancestor(x::TracedRArray) = x -ancestor(x::WrappedTracedRArray) = ancestor(parent(x)) - get_ancestor_indices(::TracedRArray, indices...) = indices function get_ancestor_indices(x::WrappedTracedRArray, indices...) return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...) @@ -388,10 +385,12 @@ end function Base.mapreduce( @nospecialize(f), @nospecialize(op), - @nospecialize(A::TracedRArray{T,N}); + @nospecialize(A::AnyTracedRArray{T,N}); dims=:, init=nothing, ) where {T,N} + A = materialize_traced_array(A) + if dims isa Int dims = [dims] end