Skip to content

Commit

Permalink
fix: only compile non-CPU broadcasting
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Dec 8, 2024
1 parent bfabd26 commit b2d9a9f
Showing 1 changed file with 21 additions and 0 deletions.
21 changes: 21 additions & 0 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,27 @@ end

# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}})
foreach(bc.args) do x
x isa ConcreteRArray && XLA.await(x.data)
end

all_on_cpu = all(bc.args) do x
x isa ConcreteRArray && return XLA.BufferOnCPU(x.data.buffer)
return true
end
if all_on_cpu
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
if !Base.isconcretetype(ElType)
throw(
ErrorException(
"`copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
),
)
end
aux = copyto!(similar(Array{ElType}, axes(bc)), bc)
return ConcreteRArray(aux)
end

fn = Reactant.compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,))
return fn(bc.args...)
end
Expand Down

0 comments on commit b2d9a9f

Please sign in to comment.