Skip to content

Commit

Permalink
test: unbreak CUDA CI (#337)
Browse files Browse the repository at this point in the history
* chore: add compat entries

* feat: add copyto! for ConcreteRArray

* feat: `@jit` compile ConcreteRArray broadcasting

* fix: manually zero out the lower triangular and upper triangular values

* fix: only do it in tests

* feat: compile mapreduce for ConcreteRArray

* test: manual array conversion

* revert: change in Ops.cholesky

* revert: remove unnecessary changes

* fix: only compile non-CPU broadcasting

* fix: address reviewer comments

* chore: apply suggestions from code review
  • Loading branch information
avik-pal authored Dec 10, 2024
1 parent 66d6cfc commit d82b31a
Show file tree
Hide file tree
Showing 6 changed files with 88 additions and 22 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
- name: "Run CompatHelper"
run: |
import CompatHelper
CompatHelper.main()
CompatHelper.main(; subdirs=[".", "test", "lib/ReactantCore"])
shell: julia --color=yes {0}
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
Expand Down
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,10 @@ ArrayInterface = "7.10"
CEnum = "0.4, 0.5"
Downloads = "1.6"
Enzyme = "0.13.21"
EnzymeCore = "0.8.6, 0.8.7, 0.8.8"
EnzymeCore = "0.8.8"
GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.24"
NNlib = "0.9.26"
OrderedCollections = "1"
Preferences = "1.4"
ReactantCore = "0.1.2"
Expand Down
59 changes: 48 additions & 11 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
end

XLA.await(a.data)
if XLA.BufferOnCPU(a.data.buffer)
if buffer_on_cpu(a)
buf = a.data.buffer
GC.@preserve buf begin
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
Expand Down Expand Up @@ -246,7 +246,7 @@ function Base.setindex!(a::ConcreteRArray{T}, v, args::Vararg{Int,N}) where {T,N
end

XLA.await(a.data)
if XLA.BufferOnCPU(a.data.buffer)
if buffer_on_cpu(a)
buf = a.data.buffer
GC.@preserve buf begin
ptr = Base.unsafe_convert(Ptr{T}, XLA.UnsafeBufferPointer(buf))
Expand Down Expand Up @@ -289,15 +289,52 @@ end

# TODO replace this copy for `setindex!` maybe? how to copy data to already existing buffer? (i.e. `copyto!`)
function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteRArray}})
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
if !Base.isconcretetype(ElType)
throw(
ErrorException(
"`copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
),
)
for x in bc.args
x isa ConcreteRArray && XLA.await(x.data)
end

aux = copyto!(similar(Array{ElType}, axes(bc)), bc)
return ConcreteRArray(aux)
all_on_cpu = all(buffer_on_cpu, bc.args)
if all_on_cpu
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
if !Base.isconcretetype(ElType)
throw(
ErrorException(
"`copy` on `ConcreteRArray` for non-concrete eltype is not implemented"
),
)
end
aux = copyto!(similar(Array{ElType}, axes(bc)), bc)
return ConcreteRArray(aux)
end

fn = Reactant.compile(Broadcast.BroadcastFunction(bc.f), (bc.args...,))
return fn(bc.args...)
end

"""
    copyto!(dest::ConcreteRArray, src::ConcreteRArray)

Make `dest` hold the data of `src` and return `dest`.

No elements are copied: `dest.data` is rebound to `src.data`, so after the call
both arrays refer to the same data object.

# Throws
- `ArgumentError` if `dest` and `src` have different element types.
- `DimensionMismatch` if `dest` and `src` have different sizes.
"""
function Base.copyto!(dest::ConcreteRArray, src::ConcreteRArray)
    # The data object is rebound as-is, never converted. Element access on a
    # `ConcreteRArray{T}` reinterprets its buffer as `Ptr{T}`, so adopting a
    # buffer of another element type would silently produce garbage values.
    eltype(dest) === eltype(src) || throw(
        ArgumentError(
            "cannot `copyto!` a `ConcreteRArray` with eltype $(eltype(src)) into one " *
            "with eltype $(eltype(dest))",
        ),
    )
    # Rebinding the data cannot copy a prefix or resize, so the sizes must agree.
    size(dest) == size(src) || throw(
        DimensionMismatch(
            "destination has size $(size(dest)) but source has size $(size(src))"
        ),
    )
    # NOTE(review): this aliases the data rather than duplicating it; in-place
    # writes through one array may be visible through the other — confirm intended.
    dest.data = src.data
    return dest
end

"""
    mapreduce(f, op, A::ConcreteRArray; dims=:, init=nothing)

Run `mapreduce` on `A` through a compiled function.

`f`, `op`, `dims` and `init` are bundled into a `CallMapReduce` callable, which is
handed to `Reactant.compile` together with `A`. The compiled function is then
called on `A` and its result is returned.
"""
function Base.mapreduce(
    @nospecialize(f),
    @nospecialize(op),
    @nospecialize(A::ConcreteRArray{T,N});
    dims=:,
    init=nothing,
) where {T,N}
    # Package the reduction and its keyword arguments into a single callable so
    # the array is the only argument seen by the compiler.
    reducer = CallMapReduce(f, op, dims, init)
    compiled = Reactant.compile(reducer, (A,))
    return compiled(A)
end

"""
    CallMapReduce(f, op, dims, init)

Callable wrapper that stores the mapping function `f`, the reduction operator
`op`, and the `dims` and `init` keyword values of a `mapreduce` call.

Calling the object on an array `A` evaluates
`Base.mapreduce(f, op, A; dims=dims, init=init)`.
"""
struct CallMapReduce{F,O,D,I}
    f::F
    op::O
    dims::D
    init::I
end

# Apply the stored reduction to `A`, forwarding the stored keyword values.
function (mr::CallMapReduce)(A)
    return Base.mapreduce(mr.f, mr.op, A; dims=mr.dims, init=mr.init)
end

"""
    buffer_on_cpu(x)

For a `ConcreteRArray`, return the result of `XLA.BufferOnCPU` applied to the
array's underlying buffer (`x.data.buffer`). Every other value is reported as
being on the CPU.
"""
function buffer_on_cpu(x::ConcreteRArray)
    buf = x.data.buffer
    return XLA.BufferOnCPU(buf)
end

# Fallback for anything that is not a `ConcreteRArray`.
# NOTE(review): this also answers `true` for wrappers or nested broadcast
# objects that may contain a `ConcreteRArray` — confirm that is intended.
function buffer_on_cpu(::Any)
    return true
end
2 changes: 2 additions & 0 deletions src/TracedRNumber.jl
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,8 @@ for (jlop, hloop) in (
(:(Base.FastMath.exp_fast), :exponential),
(:(Base.log), :log),
(:(Base.sqrt), :sqrt),
(:(Base.ceil), :ceil),
(:(Base.floor), :floor),
)
@eval function $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T}
OutTy = $(hloop === :abs) ? real(T) : T
Expand Down
21 changes: 21 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,24 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[compat]
ArrayInterface = "7.10"
BenchmarkTools = "1.5"
Enzyme = "0.13.21"
FFTW = "1.8"
Flux = "0.15"
Functors = "0.5"
InteractiveUtils = "1.10"
LinearAlgebra = "1.10"
Lux = "1.4.1"
LuxLib = "1.3"
MLUtils = "0.4.4"
NNlib = "0.9.26"
OneHotArrays = "0.2.6"
Optimisers = "0.4"
Random = "1.10"
SafeTestsets = "0.1"
SpecialFunctions = "2.4"
Statistics = "1.10"
Test = "1.10"
22 changes: 14 additions & 8 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,14 +82,18 @@ end
end

@testset "cholesky" begin
g(x) = Ops.cholesky(x; lower=true)
# cholesky in stablehlo for the other triangle is implementation defined.
# See https://github.com/EnzymeAD/Reactant.jl/issues/338 for more details.
g1(x) = triu(Ops.cholesky(x))
g2(x) = tril(Ops.cholesky(x; lower=true))

x = ConcreteRArray([
10.0 2.0 3.0
2.0 5.0 6.0
3.0 6.0 9.0
])
@test cholesky(Array(x)).U ≈ @jit Ops.cholesky(x)
@test transpose(cholesky(Array(x)).U) ≈ @jit g(x)
@test cholesky(Array(x)).U ≈ @jit g1(x)
@test transpose(cholesky(Array(x)).U) ≈ @jit g2(x)

x = ConcreteRArray(
[
Expand All @@ -98,8 +102,9 @@ end
3.0+4.0im 3.0+2.0im 9.0+0.0im
],
)
@test cholesky(Array(x)).U ≈ @jit Ops.cholesky(x)
@test adjoint(cholesky(Array(x)).U) ≈ @jit g(x)

@test cholesky(Array(x)).U ≈ @jit g1(x)
@test adjoint(cholesky(Array(x)).U) ≈ @jit g2(x)
end

@testset "clamp" begin
Expand Down Expand Up @@ -210,13 +215,14 @@ end
]
# NOTE `LinearAlgebra.dot` is not equal to `sum(a .* b)` on complex numbers due to conjugation
@test sum(a .* b) ≈ @jit f1(a, b)
@test kron(reshape(a, length(a), 1), reshape(b, 1, length(b))) ≈ @jit fouter(a, b)
@test kron(reshape(Array(a), length(a), 1), reshape(Array(b), 1, length(b))) ≈
@jit fouter(a, b)
@test a .* b ≈ @jit fouter_batch1(a, b)
end

a = ConcreteRArray([1 2; 3 4])
b = ConcreteRArray([5 6; -7 -8])
@test a' * b == @jit f1(a, b)
@test Array(a)' * Array(b) == @jit f1(a, b)
end

@testset "einsum" begin
Expand All @@ -239,7 +245,7 @@ end
x = reshape(a, (2, 2))
y = reshape(b, (2, 2))
@test x .* y ≈ @jit f3(x, y)
@test x * y ≈ @jit f4(x, y)
@test Array(x) * Array(y) ≈ @jit f4(x, y)
end
end

Expand Down

0 comments on commit d82b31a

Please sign in to comment.