WIP: kernels #314

Draft · wants to merge 3 commits into base: main
Changes from 2 commits
7 changes: 5 additions & 2 deletions Project.toml
@@ -21,13 +21,15 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353"
[weakdeps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df"

[extensions]
ReactantAbstractFFTsExt = "AbstractFFTs"
ReactantArrayInterfaceExt = "ArrayInterface"
ReactantCUDAExt = "CUDA"
ReactantNNlibExt = "NNlib"
ReactantStatisticsExt = "Statistics"
ReactantYaoBlocksExt = "YaoBlocks"
@@ -54,7 +56,8 @@ julia = "1.10"
[extras]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"

[sources]
ReactantCore = { path = "lib/ReactantCore" }
[sources.ReactantCore]
path = "lib/ReactantCore"
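
For context (not part of this diff): with the [weakdeps] and [extensions] entries above, Julia 1.9+ loads ReactantCUDAExt automatically once both Reactant and CUDA are present in the session. A minimal sketch of how to verify that, assuming both packages are installed:

using Reactant   # Reactant alone does not load the extension
using CUDA       # loading the weak dependency triggers ext/ReactantCUDAExt.jl

# Base.get_extension (Julia >= 1.9) returns the extension module once loaded
ext = Base.get_extension(Reactant, :ReactantCUDAExt)
@assert ext !== nothing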
81 changes: 81 additions & 0 deletions ext/ReactantCUDAExt.jl
@@ -0,0 +1,81 @@
module ReactantCUDAExt

using CUDA
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber
using ReactantCore: @trace


const _kernel_instances = Dict{Any, Any}()
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
const _kernel_instances = Dict{Any, Any}()
const _kernel_instances = Dict{Any,Any}()


function recufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT}
cuda = CUDA.active_state()

# map Reactant-traced argument types back to their concrete counterparts so
# that GPUCompiler sees the signature a plain CUDA.jl launch would see
F2 = Reactant.traced_type(F, (), Val(Reactant.TracedToConcrete))
tt2 = Reactant.traced_type(tt, (), Val(Reactant.TracedToConcrete))


Base.@lock CUDA.cufunction_lock begin
# compile the function
cache = CUDA.compiler_cache(cuda.context)
source = CUDA.methodinstance(F2, tt2)
config = CUDA.compiler_config(cuda.device; kwargs...)::CUDA.CUDACompilerConfig
fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link)

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
fun = CUDA.GPUCompiler.cached_compilation(cache, source, config, CUDA.compile, CUDA.link)
fun = CUDA.GPUCompiler.cached_compilation(
cache, source, config, CUDA.compile, CUDA.link
)


@show fun
@show fun.mod
# create a callable object that captures the function instance. we don't need to think
Comment on lines +26 to +28

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@show fun
@show fun.mod
# create a callable object that captures the function instance. we don't need to think
@show fun
@show fun.mod
# create a callable object that captures the function instance. we don't need to think

# about world age here, as GPUCompiler already does and will return a different object
key = (objectid(source), hash(fun), f)
kernel = get(_kernel_instances, key, nothing)
if kernel === nothing
# create the kernel state object
state = CUDA.KernelState(CUDA.create_exceptions!(fun.mod), UInt32(0))

kernel = CUDA.HostKernel{F,tt}(f, fun, state)
_kernel_instances[key] = kernel
end
return kernel::CUDA.HostKernel{F,tt}
end
end

const CC = Core.Compiler

import Core.Compiler:
AbstractInterpreter,
abstract_call,
abstract_call_known,
ArgInfo,
StmtInfo,
AbsIntState,
get_max_methods,
CallMeta,
Effects,
NoCallInfo,
widenconst,
mapany,
MethodResultPure



[JuliaFormatter] reported by reviewdog 🐶

Suggested change

function Reactant.set_reactant_abi(
interp,
f::typeof(CUDA.cufunction),
arginfo::ArgInfo,
si::StmtInfo,
sv::AbsIntState,
max_methods::Int=get_max_methods(interp, f, sv),
)
(; fargs, argtypes) = arginfo

arginfo2 = ArgInfo(
if fargs isa Nothing
nothing
else
[:($(recufunction)), fargs[2:end]...]
end,
[Core.Const(recufunction), argtypes[2:end]...],
)
return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods)
Comment on lines +70 to +79

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
arginfo2 = ArgInfo(
if fargs isa Nothing
nothing
else
[:($(recufunction)), fargs[2:end]...]
end,
[Core.Const(recufunction), argtypes[2:end]...],
)
return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods)
arginfo2 = ArgInfo(
if fargs isa Nothing
nothing
else
[:($(recufunction)), fargs[2:end]...]
end,
[Core.Const(recufunction), argtypes[2:end]...],
)
return abstract_call_known(interp, recufunction, arginfo2, si, sv, max_methods)

end

end # module ReactantCUDAExt
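
In plain terms, the set_reactant_abi overload above makes the abstract interpreter treat every call to CUDA.cufunction as a call to recufunction with the same arguments, so kernel compilation goes through Reactant's traced-to-concrete type mapping. A conceptual sketch of the rewrite (illustrative only; the real substitution happens through ArgInfo during abstract interpretation, as shown above):

# hypothetical helper, not part of the extension
function redirect_cufunction(f, args...; kwargs...)
    # swap the callee, keep every argument unchanged
    f === CUDA.cufunction && return recufunction(args...; kwargs...)
    return f(args...; kwargs...)
end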
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,6 +1,7 @@
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
23 changes: 23 additions & 0 deletions test/cuda.jl
@@ -0,0 +1,23 @@
using Reactant
using Test
using CUDA

function square_kernel!(x)
i = threadIdx().x
x[i] *= x[i]
sync_threads()
return nothing
end

# basic squaring on GPU
function square!(x)
@cuda blocks = 1 threads = length(x) square_kernel!(x)
return nothing
end

@testset "Square Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
func = @compile square!(A)
func(A) # @compile only builds the compiled thunk; call it to run the kernel
@test all(A .≈ (oA .* oA))
end
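
For comparison (not part of this diff), the same kernel launched directly through CUDA.jl without Reactant tracing; this assumes a CUDA-capable GPU is available:

using CUDA, Test

x = CuArray(Float32.(1:64))
@cuda blocks = 1 threads = length(x) square_kernel!(x)   # plain CUDA.jl launch
@test Array(x) ≈ Float32.(1:64) .^ 2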
59 changes: 30 additions & 29 deletions test/runtests.jl
@@ -42,35 +42,36 @@ end
const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all"))

@testset "Reactant.jl Tests" begin
if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core"
@safetestset "Layout" include("layout.jl")
@safetestset "Tracing" include("tracing.jl")
@safetestset "Basic" include("basic.jl")
@safetestset "Autodiff" include("autodiff.jl")
@safetestset "Complex" include("complex.jl")
@safetestset "Broadcast" include("bcast.jl")
@safetestset "Struct" include("struct.jl")
@safetestset "Closure" include("closure.jl")
@safetestset "Compile" include("compile.jl")
@safetestset "Buffer Donation" include("buffer_donation.jl")
@safetestset "Shortcuts to MLIR ops" include("ops.jl")
@safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
@safetestset "Control Flow" include("control_flow.jl")
@safetestset "Linear Algebra" include("linear_algebra.jl")
end
#if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core"
@safetestset "CUDA" include("cuda.jl")

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
@safetestset "CUDA" include("cuda.jl")
@safetestset "CUDA" include("cuda.jl")

# @safetestset "Layout" include("layout.jl")
# @safetestset "Tracing" include("tracing.jl")
# @safetestset "Basic" include("basic.jl")
# @safetestset "Autodiff" include("autodiff.jl")
# @safetestset "Complex" include("complex.jl")
# @safetestset "Broadcast" include("bcast.jl")
# @safetestset "Struct" include("struct.jl")
# @safetestset "Closure" include("closure.jl")
# @safetestset "Compile" include("compile.jl")
# @safetestset "Buffer Donation" include("buffer_donation.jl")
# @safetestset "Shortcuts to MLIR ops" include("ops.jl")
# @safetestset "Wrapped Arrays" include("wrapped_arrays.jl")
# @safetestset "Control Flow" include("control_flow.jl")
# @safetestset "Linear Algebra" include("linear_algebra.jl")
# end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
@safetestset "AbstractFFTs" include("integration/fft.jl")
end
# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration"
# @safetestset "AbstractFFTs" include("integration/fft.jl")
# end

if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
@testset "Neural Networks" begin
@safetestset "NNlib Primitives" include("nn/nnlib.jl")
@safetestset "Flux.jl Integration" include("nn/flux.jl")
if Sys.islinux()
@safetestset "LuxLib Primitives" include("nn/luxlib.jl")
@safetestset "Lux Integration" include("nn/lux.jl")
end
end
end
# if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks"
# @testset "Neural Networks" begin
# @safetestset "NNlib Primitives" include("nn/nnlib.jl")
# @safetestset "Flux.jl Integration" include("nn/flux.jl")
# if Sys.islinux()
# @safetestset "LuxLib Primitives" include("nn/luxlib.jl")
# @safetestset "Lux Integration" include("nn/lux.jl")
# end
# end
# end
end