Skip to content

Commit

Permalink
fix: ensure printing of wrapped ConcreteRArrays goes through our show (
Browse files Browse the repository at this point in the history
…#367)

* fix: ensure printing of wrapped ConcreteRArrays goes through our show

* fix: allow wrapped arrays in mapreduce
  • Loading branch information
avik-pal authored Dec 12, 2024
1 parent ea97be3 commit 8b90501
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 8 deletions.
10 changes: 6 additions & 4 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -195,16 +195,18 @@ function Base.show(io::IO, X::ConcreteRScalar{T}) where {T}
return nothing
end

function Base.print_array(io::IO, X::ConcreteRArray)
if X.data == XLA.AsyncEmptyBuffer
function Base.print_array(io::IO, X::AnyConcreteRArray)
data = ancestor(X).data
if data == XLA.AsyncEmptyBuffer
println(io, "<Empty buffer>")
return nothing
end
return Base.print_array(io, convert(Array, X))
end

function Base.show(io::IO, X::ConcreteRArray)
if X.data == XLA.AsyncEmptyBuffer
function Base.show(io::IO, X::AnyConcreteRArray)
data = ancestor(X).data
if data == XLA.AsyncEmptyBuffer
println(io, "<Empty buffer>")
return nothing
end
Expand Down
6 changes: 6 additions & 0 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,12 @@ function Enzyme.make_zero(
return res
end

function ancestor(x::AbstractArray)
p_x = parent(x)
p_x === x && return x
return ancestor(p_x)
end

include("mlir/MLIR.jl")
include("XLA.jl")
include("Interpreter.jl")
Expand Down
7 changes: 3 additions & 4 deletions src/TracedRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,6 @@ function set_mlir_data!(x::AnyTracedRArray, data)
return x
end

ancestor(x::TracedRArray) = x
ancestor(x::WrappedTracedRArray) = ancestor(parent(x))

get_ancestor_indices(::TracedRArray, indices...) = indices
function get_ancestor_indices(x::WrappedTracedRArray, indices...)
return get_ancestor_indices(parent(x), Base.reindex(parentindices(x), indices)...)
Expand Down Expand Up @@ -388,10 +385,12 @@ end
function Base.mapreduce(
@nospecialize(f),
@nospecialize(op),
@nospecialize(A::TracedRArray{T,N});
@nospecialize(A::AnyTracedRArray{T,N});
dims=:,
init=nothing,
) where {T,N}
A = materialize_traced_array(A)

if dims isa Int
dims = [dims]
end
Expand Down

0 comments on commit 8b90501

Please sign in to comment.