From ff729ce0f9f56712b38964a95e7249c765f76aaa Mon Sep 17 00:00:00 2001 From: Billy Moses Date: Tue, 10 Dec 2024 22:40:02 -0600 Subject: [PATCH] Cleanup --- Project.toml | 5 +- ext/ReactantCUDAExt.jl | 553 ----------------------------------------- src/Reactant.jl | 3 - src/utils.jl | 84 ++++--- test/Project.toml | 1 - test/runtests.jl | 24 +- 6 files changed, 61 insertions(+), 609 deletions(-) delete mode 100644 ext/ReactantCUDAExt.jl diff --git a/Project.toml b/Project.toml index a5d243705..dd4d67325 100644 --- a/Project.toml +++ b/Project.toml @@ -21,7 +21,6 @@ Scratch = "6c6a2e73-6563-6170-7368-637461726353" [weakdeps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" YaoBlocks = "418bc28f-b43b-5e0b-a6e7-61bbc1a2c1df" @@ -32,7 +31,6 @@ path = "lib/ReactantCore" [extensions] ReactantAbstractFFTsExt = "AbstractFFTs" ReactantArrayInterfaceExt = "ArrayInterface" -ReactantCUDAExt = "CUDA" ReactantNNlibExt = "NNlib" ReactantStatisticsExt = "Statistics" ReactantYaoBlocksExt = "YaoBlocks" @@ -43,7 +41,7 @@ Adapt = "4" ArrayInterface = "7.10" CEnum = "0.4, 0.5" Downloads = "1.6" -Enzyme = "0.13.21" +Enzyme = "0.13.22" EnzymeCore = "0.8.8" GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" @@ -60,5 +58,4 @@ julia = "1.10" [extras] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" diff --git a/ext/ReactantCUDAExt.jl b/ext/ReactantCUDAExt.jl deleted file mode 100644 index ad13922f3..000000000 --- a/ext/ReactantCUDAExt.jl +++ /dev/null @@ -1,553 +0,0 @@ -module ReactantCUDAExt - -using CUDA -using Reactant: - Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber -using ReactantCore: @trace - -using Adapt - -struct CuTracedArray{T,N,A,Size} <: DenseArray{T,N} - ptr::Core.LLVMPtr{T,A} -end - - -Base.show(io::IO, a::AT) where AT <: CuTracedArray = - CUDA.Printf.@printf(io, "%s cu traced array at %p", join(size(a), '×'), Int(pointer(a))) - -## array interface - -Base.elsize(::Type{<:CuTracedArray{T}}) where {T} = sizeof(T) -Base.size(g::CuTracedArray{T,N,A,Size}) where {T,N,A,Size} = Size -Base.sizeof(x::CuTracedArray) = Base.elsize(x) * length(x) -Base.pointer(x::CuTracedArray{T,<:Any,A}) where {T,A} = Base.unsafe_convert(Core.LLVMPtr{T,A}, x) -@inline function Base.pointer(x::CuTracedArray{T,<:Any,A}, i::Integer) where {T,A} - Base.unsafe_convert(Core.LLVMPtr{T,A}, x) + Base._memory_offset(x, i) -end - - -## conversions - -Base.unsafe_convert(::Type{Core.LLVMPtr{T,A}}, x::CuTracedArray{T,<:Any,A}) where {T,A} = - x.ptr - - -## indexing intrinsics - -CUDA.@device_function @inline function arrayref(A::CuTracedArray{T}, index::Integer) where {T} - @boundscheck checkbounds(A, index) - if Base.isbitsunion(T) - arrayref_union(A, index) - else - arrayref_bits(A, index) - end -end - -@inline function arrayref_bits(A::CuTracedArray{T}, index::Integer) where {T} - unsafe_load(pointer(A), index) -end - -@inline @generated function arrayref_union(A::CuTracedArray{T,<:Any,AS}, index::Integer) where {T,AS} - typs = Base.uniontypes(T) - - # generate code that conditionally loads a value based on the selector value. - # lacking noreturn, we return T to avoid inference thinking this can return Nothing. - ex = :(Base.llvmcall("unreachable", $T, Tuple{})) - for (sel, typ) in Iterators.reverse(enumerate(typs)) - ex = quote - if selector == $(sel-1) - ptr = reinterpret(Core.LLVMPtr{$typ,AS}, data_ptr) - unsafe_load(ptr, 1) - else - $ex - end - end - end - - quote - selector_ptr = typetagdata(A, index) - selector = unsafe_load(selector_ptr) - - data_ptr = pointer(A, index) - - return $ex - end -end - -CUDA.@device_function @inline function arrayset(A::CuTracedArray{T}, x::T, index::Integer) where {T} - @boundscheck checkbounds(A, index) - if Base.isbitsunion(T) - arrayset_union(A, x, index) - else - arrayset_bits(A, x, index) - end - return A -end - -@inline function arrayset_bits(A::CuTracedArray{T}, x::T, index::Integer) where {T} - unsafe_store!(pointer(A), x, index) -end - -@inline @generated function arrayset_union(A::CuTracedArray{T,<:Any,AS}, x::T, index::Integer) where {T,AS} - typs = Base.uniontypes(T) - sel = findfirst(isequal(x), typs) - - quote - selector_ptr = typetagdata(A, index) - unsafe_store!(selector_ptr, $(UInt8(sel-1))) - - data_ptr = pointer(A, index) - - unsafe_store!(reinterpret(Core.LLVMPtr{$x,AS}, data_ptr), x, 1) - return - end -end - -CUDA.@device_function @inline function const_arrayref(A::CuTracedArray{T}, index::Integer) where {T} - @boundscheck checkbounds(A, index) - unsafe_cached_load(pointer(A), index) -end - - -## indexing - -Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear() - -Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = - arrayref(A, i1) -Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = - arrayset(A, convert(T,x)::T, i1) - -# preserve the specific integer type when indexing device arrays, -# to avoid extending 32-bit hardware indices to 64-bit. -Base.to_index(::CuTracedArray, i::Integer) = i - -# Base doesn't like Integer indices, so we need our own ND get and setindex! routines. -# See also: https://github.com/JuliaLang/julia/pull/42289 -Base.@propagate_inbounds Base.getindex(A::CuTracedArray, - I::Union{Integer, CartesianIndex}...) = - A[Base._to_linear_index(A, to_indices(A, I)...)] -Base.@propagate_inbounds Base.setindex!(A::CuTracedArray, x, - I::Union{Integer, CartesianIndex}...) = - A[Base._to_linear_index(A, to_indices(A, I)...)] = x - - -## const indexing - -""" - Const(A::CuTracedArray) - -Mark a CuTracedArray as constant/read-only. The invariant guaranteed is that you will not -modify an CuTracedArray for the duration of the current kernel. - -This API can only be used on devices with compute capability 3.5 or higher. - -!!! warning - Experimental API. Subject to change without deprecation. -""" -struct Const{T,N,AS} <: DenseArray{T,N} - a::CuTracedArray{T,N,AS} -end -Base.Experimental.Const(A::CuTracedArray) = Const(A) - -Base.IndexStyle(::Type{<:Const}) = IndexLinear() -Base.size(C::Const) = size(C.a) -Base.axes(C::Const) = axes(C.a) -Base.@propagate_inbounds Base.getindex(A::Const, i1::Integer) = const_arrayref(A.a, i1) - -# deprecated -Base.@propagate_inbounds ldg(A::CuTracedArray, i1::Integer) = const_arrayref(A, i1) - - -## other - -@inline function Base.iterate(A::CuTracedArray, i=1) - if (i % UInt) - 1 < length(A) - (@inbounds A[i], i + 1) - else - nothing - end -end - -function Base.reinterpret(::Type{T}, a::CuTracedArray{S,N,A}) where {T,S,N,A} - err = GPUArrays._reinterpret_exception(T, a) - err === nothing || throw(err) - - if sizeof(T) == sizeof(S) # fast case - return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), size(a), a.maxsize) - end - - isize = size(a) - size1 = div(isize[1]*sizeof(S), sizeof(T)) - osize = tuple(size1, Base.tail(isize)...) - return CuTracedArray{T,N,A}(reinterpret(Core.LLVMPtr{T,A}, a.ptr), osize, a.maxsize) -end - - -## reshape - -function Base.reshape(a::CuTracedArray{T,M,A}, dims::NTuple{N,Int}) where {T,N,M,A} - if prod(dims) != length(a) - throw(DimensionMismatch("new dimensions (argument `dims`) must be consistent with array size (`size(a)`)")) - end - if N == M && dims == size(a) - return a - end - _derived_array(a, T, dims) -end - - - -function Adapt.adapt_storage(::CUDA.KernelAdaptor, xs::TracedRArray{T,N}) where {T,N} - res = CuTracedArray{T,N,CUDA.AS.Global, size(xs)}(Base.reinterpret(Core.LLVMPtr{T,CUDA.AS.Global}, Base.pointer_from_objref(xs))) - @show res, xs - return res -end - -const _kernel_instances = Dict{Any, Any}() - -struct LLVMFunc{F,tt} - f::Union{F, Nothing} - mod::String - image - entry::String -end - - -const GPUCompiler = CUDA.GPUCompiler -const LLVM = GPUCompiler.LLVM - - -GPULowerCPUFeaturesPass() = LLVM.NewPMModulePass("GPULowerCPUFeatures", GPUCompiler.cpu_features!) -GPULowerPTLSPass() = LLVM.NewPMModulePass("GPULowerPTLS", GPUCompiler.lower_ptls!) -GPULowerGCFramePass() = LLVM.NewPMFunctionPass("GPULowerGCFrame", GPUCompiler.lower_gc_frame!) -function noop_pass(x) - return false -end -function kern_pass(mod) - for fname in ("julia.gpu.state_getter",) - if LLVM.haskey(LLVM.functions(mod), fname) - fn = LLVM.functions(mod)[fname] - insts = LLVM.Instruction[] - for u in LLVM.uses(fn) - u = LLVM.user(u) - LLVM.replace_uses!(u, LLVM.UndefValue(LLVM.value_type(u))) - push!(insts, u) - end - for inst in insts - Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) - end - Reactant.Enzyme.Compiler.eraseInst(mod, fn) - end - end - - return true -end -AddKernelStatePass() = LLVM.NewPMModulePass("AddKernelStatePass", kern_pass) -LowerKernelStatePass() = LLVM.NewPMFunctionPass("LowerKernelStatePass", noop_pass) -CleanupKernelStatePass() = LLVM.NewPMModulePass("CleanupKernelStatePass", noop_pass) - -# compile to executable machine code -function compile(job) - - # lower to PTX - # TODO: on 1.9, this actually creates a context. cache those. - modstr, image, entry = GPUCompiler.JuliaContext() do ctx - mod, meta = GPUCompiler.compile(:llvm, job; optimize=false, cleanup=false, validate=false) - GPUCompiler.optimize_module!(job, mod) - opt_level = 2 - tm = GPUCompiler.llvm_machine(job.config.target) - LLVM.@dispose pb=LLVM.NewPMPassBuilder() begin - LLVM.register!(pb, GPULowerCPUFeaturesPass()) - LLVM.register!(pb, GPULowerPTLSPass()) - LLVM.register!(pb, GPULowerGCFramePass()) - LLVM.register!(pb, AddKernelStatePass()) - LLVM.register!(pb, LowerKernelStatePass()) - LLVM.register!(pb, CleanupKernelStatePass()) - - LLVM.add!(pb, LLVM.NewPMModulePassManager()) do mpm - GPUCompiler.buildNewPMPipeline!(mpm, job, opt_level) - end - LLVM.run!(pb, mod, tm) - end - GPUCompiler.optimize_module!(job, mod) - LLVM.run!(CUDA.GPUCompiler.DeadArgumentEliminationPass(), mod, tm) - - - for fname in ("gpu_report_exception", "gpu_signal_exception") - if LLVM.haskey(LLVM.functions(mod), fname) - fn = LLVM.functions(mod)[fname] - insts = LLVM.Instruction[] - for u in LLVM.uses(fn) - push!(insts, LLVM.user(u)) - end - for inst in insts - Reactant.Enzyme.Compiler.eraseInst(LLVM.parent(inst), inst) - end - Reactant.Enzyme.Compiler.eraseInst(mod, fn) - end - end - - LLVM.strip_debuginfo!(mod) - modstr = string(mod) - - # This is a bit weird since we're taking a module from julia's llvm into reactant's llvm version - # it is probably safer to reparse a string using the right llvm module api, so we will do that. - - println(string(modstr)) - mmod = MLIR.IR.Module(@ccall MLIR.API.mlir_c.ConvertLLVMStrToMLIR(modstr::Cstring, MLIR.IR.context()::MLIR.API.MlirContext)::MLIR.API.MlirModule) - @show mmod - - # check if we'll need the device runtime - undefined_fs = filter(collect(CUDA.LLVM.functions(meta.ir))) do f - CUDA.LLVM.isdeclaration(f) && !CUDA.LLVM.isintrinsic(f) - end - intrinsic_fns = ["vprintf", "malloc", "free", "__assertfail", - "__nvvm_reflect" #= TODO: should have been optimized away =#] - needs_cudadevrt = !isempty(setdiff(CUDA.LLVM.name.(undefined_fs), intrinsic_fns)) - - # prepare invocations of CUDA compiler tools - ptxas_opts = String[] - nvlink_opts = String[] - ## debug flags - if Base.JLOptions().debug_level == 1 - push!(ptxas_opts, "--generate-line-info") - elseif Base.JLOptions().debug_level >= 2 - push!(ptxas_opts, "--device-debug") - push!(nvlink_opts, "--debug") - end - ## relocatable device code - if needs_cudadevrt - push!(ptxas_opts, "--compile-only") - end - - ptx = job.config.params.ptx - cap = job.config.params.cap - arch = "sm_$(cap.major)$(cap.minor)" - - # validate use of parameter memory - argtypes = filter([CUDA.KernelState, job.source.specTypes.parameters...]) do dt - !CUDA.isghosttype(dt) && !Core.Compiler.isconstType(dt) - end - param_usage = sum(sizeof, argtypes) - param_limit = 4096 - if cap >= v"7.0" && ptx >= v"8.1" - param_limit = 32764 - end - if param_usage > param_limit - msg = """Kernel invocation uses too much parameter memory. - $(Base.format_bytes(param_usage)) exceeds the $(Base.format_bytes(param_limit)) limit imposed by sm_$(cap.major)$(cap.minor) / PTX v$(ptx.major).$(ptx.minor).""" - - try - details = "\n\nRelevant parameters:" - - source_types = job.source.specTypes.parameters - source_argnames = Base.method_argnames(job.source.def) - while length(source_argnames) < length(source_types) - # this is probably due to a trailing vararg; repeat its name - push!(source_argnames, source_argnames[end]) - end - - for (i, typ) in enumerate(source_types) - if CUDA.isghosttype(typ) || Core.Compiler.isconstType(typ) - continue - end - name = source_argnames[i] - details *= "\n [$(i-1)] $name::$typ uses $(Base.format_bytes(sizeof(typ)))" - end - details *= "\n" - - if cap >= v"7.0" && ptx < v"8.1" && param_usage < 32764 - details *= "\nNote: use a newer CUDA to support more parameters on your device.\n" - end - - msg *= details - catch err - @error "Failed to analyze kernel parameter usage; please file an issue with a reproducer." - end - error(msg) - end - - # compile to machine code - # NOTE: we use tempname since mktemp doesn't support suffixes, and mktempdir is slow - ptx_input = tempname(cleanup=false) * ".ptx" - ptxas_output = tempname(cleanup=false) * ".cubin" - write(ptx_input, asm) - - # we could use the driver's embedded JIT compiler, but that has several disadvantages: - # 1. fixes and improvements are slower to arrive, by using `ptxas` we only need to - # upgrade the toolkit to get a newer compiler; - # 2. version checking is simpler, we otherwise need to use NVML to query the driver - # version, which is hard to correlate to PTX JIT improvements; - # 3. if we want to be able to use newer (minor upgrades) of the CUDA toolkit on an - # older driver, we should use the newer compiler to ensure compatibility. - append!(ptxas_opts, [ - "--verbose", - "--gpu-name", arch, - "--output-file", ptxas_output, - ptx_input - ]) - proc, log = CUDA.run_and_collect(`$(CUDA.ptxas()) $ptxas_opts`) - log = strip(log) - if !success(proc) - reason = proc.termsignal > 0 ? "ptxas received signal $(proc.termsignal)" : - "ptxas exited with code $(proc.exitcode)" - msg = "Failed to compile PTX code ($reason)" - msg *= "\nInvocation arguments: $(join(ptxas_opts, ' '))" - if !isempty(log) - msg *= "\n" * log - end - msg *= "\nIf you think this is a bug, please file an issue and attach $(ptx_input)" - if parse(Bool, get(ENV, "BUILDKITE", "false")) - run(`buildkite-agent artifact upload $(ptx_input)`) - end - error(msg) - elseif !isempty(log) - @debug "PTX compiler log:\n" * log - end - rm(ptx_input) - - # link device libraries, if necessary - # - # this requires relocatable device code, which prevents certain optimizations and - # hurts performance. as such, we only do so when absolutely necessary. - # TODO: try LTO, `--link-time-opt --nvvmpath /opt/cuda/nvvm`. - # fails with `Ignoring -lto option because no LTO objects found` - if needs_cudadevrt - nvlink_output = tempname(cleanup=false) * ".cubin" - append!(nvlink_opts, [ - "--verbose", "--extra-warnings", - "--arch", arch, - "--library-path", dirname(libcudadevrt), - "--library", "cudadevrt", - "--output-file", nvlink_output, - ptxas_output - ]) - proc, log = run_and_collect(`$(CUDA.nvlink()) $nvlink_opts`) - log = strip(log) - if !success(proc) - reason = proc.termsignal > 0 ? "nvlink received signal $(proc.termsignal)" : - "nvlink exited with code $(proc.exitcode)" - msg = "Failed to link PTX code ($reason)" - msg *= "\nInvocation arguments: $(join(nvlink_opts, ' '))" - if !isempty(log) - msg *= "\n" * log - end - msg *= "\nIf you think this is a bug, please file an issue and attach $(ptxas_output)" - error(msg) - elseif !isempty(log) - @debug "PTX linker info log:\n" * log - end - rm(ptxas_output) - - image = read(nvlink_output) - rm(nvlink_output) - else - image = read(ptxas_output) - rm(ptxas_output) - end - - modstr, image, meta.entry - end - LLVMFunc{job.source.specTypes.parameters[1],job.source.specTypes}(nothing, modstr, image, CUDA.LLVM.name(entry)) -end - -# link into an executable kernel -function link(job, compiled) - # load as an executable kernel object - return compiled -end - -function transpose_val(val) - attr = MLIR.IR.DenseArrayAttribute( - Int64[reverse(0:(length(size(MLIR.IR.type(val))) - 1))...] - ) - return MLIR.IR.result(MLIR.Dialects.stablehlo.transpose(val; permutation=attr), 1) -end - -function (func::LLVMFunc{F,tt})(args...; convert=Val(false), blocks::CuDim=1, threads::CuDim=1, - cooperative::Bool=false, shmem::Integer=0, call_kwargs...) where{F, tt} - @show args - @show call_kwargs - - blockdim = CUDA.CuDim3(blocks) - threaddim = CUDA.CuDim3(threads) - - mlir_args = MLIR.IR.Value[] - restys = MLIR.IR.Type[] - aliases = MLIR.IR.Attribute[] - rarrays = TracedRArray[] - for (i, a) in enumerate(args) - @show a - @assert a isa CuTracedArray - ta = Base.unsafe_pointer_to_objref(Base.reinterpret(Ptr{Cvoid}, a.ptr))::TracedRArray - push!(rarrays, ta) - arg = ta.mlir_data - arg = transpose_val(arg) - @show arg - push!(restys, MLIR.IR.type(arg)) - push!(mlir_args, arg) - push!(aliases, - MLIR.IR.Attribute(MLIR.API.stablehloOutputOperandAliasGet( - MLIR.IR.context(), - length(args) == 1 ? 0 : 1, - length(args) == 1 ? C_NULL : Ref{Int64}(i-1), - i-1, - 0, - C_NULL - )) - ) - end - - output_operand_aliases=MLIR.IR.Attribute(aliases) - call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute("configstr")) - # call = MLIR.Dialects.stablehlo.custom_call(mlir_args; result_0=restys, call_target_name="reactant_gpu_call", output_operand_aliases, backend_config=MLIR.IR.Attribute(func.mod)) - for (i, res) in enumerate(rarrays) - res.mlir_data = transpose_val(MLIR.IR.result(call, i)) - end - #CUDA.cuLaunchKernel(f, - # blockdim.x, blockdim.y, blockdim.z, - # threaddim.x, threaddim.y, threaddim.z, - # shmem, stream, kernelParams, C_NULL) -end - -# cache of compilation caches, per context -const _compiler_caches = Dict{MLIR.IR.Context, Dict{Any, LLVMFunc}}(); -function compiler_cache(ctx::MLIR.IR.Context) - cache = get(_compiler_caches, ctx, nothing) - if cache === nothing - cache = Dict{Any, LLVMFunc}() - _compiler_caches[ctx] = cache - end - return cache -end - -Reactant.@reactant_override function CUDA.cufunction(f::F, tt::TT=Tuple{}; kwargs...) where {F,TT} - @show "recufunction", f, tt - res = Base.@lock CUDA.cufunction_lock begin - # compile the function - cache = compiler_cache(MLIR.IR.context()) - source = CUDA.methodinstance(F, tt) - - # cuda = CUDA.active_state() - device = nothing # cuda.device - # config = CUDA.compiler_config(device; kwargs...)::CUDA.CUDACompilerConfig - cuda_cap=v"5.0" - cuda_ptx=v"6.3" - llvm_cap=v"5.0" - llvm_ptx=v"6.3" - kernel=true - always_inline=false - name=nothing - debuginfo=false - config = CUDA.CompilerConfig(CUDA.PTXCompilerTarget(; cap=llvm_cap, ptx=llvm_ptx, debuginfo), CUDA.CUDACompilerParams(; cap=cuda_cap, ptx=cuda_ptx); kernel, name, always_inline) - CUDA.GPUCompiler.cached_compilation(cache, source, config, compile, link) - end - res -end - -function __init__() - -end - -end # module ReactantCUDAExt diff --git a/src/Reactant.jl b/src/Reactant.jl index 1623503df..06fd59aff 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -124,7 +124,4 @@ function set_default_backend(backend::String) return set_default_backend(XLA.backends[backend]) end -# include("../ext/ReactantCUDAExt.jl") - end # module - diff --git a/src/utils.jl b/src/utils.jl index db1de4505..c165ac402 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -60,6 +60,8 @@ end function rewrite_inst(inst, ir) if Meta.isexpr(inst, :call) + # Even if type unstable we do not want (or need) to replace intrinsic + # calls or builtins with our version. ft = Core.Compiler.widenconst(maybe_argextype(inst.args[1], ir)) if !(ft <: Core.IntrinsicFunction) && !(ft <: Core.Builtin) rep = Expr(:call, call_with_reactant, inst.args...) @@ -74,6 +76,8 @@ end const REDUB_ARGUMENTS_NAME = gensym("redub_arguments") +# From Julia's Base.Meta with fix from https://github.com/JuliaLang/julia/pull/56787 +# and additionally adds support for an argument rewriting into a slot function arg_partially_inline!(code::Vector{Any}, slot_replacements::Vector{Any}, @nospecialize(type_signature)#=::Type{<:Tuple}=#, static_param_values::Vector{Any}, @@ -218,12 +222,35 @@ function _arg_partially_inline!(@nospecialize(x), slot_replacements::Vector{Any} return x end + +""" + Reactant.REDUB_ARGUMENTS_NAME + +The variable name bound to `call_with_reactant`'s tuple of arguments in its +`@generated` method definition. + +This binding can be used to manually reference/destructure `call_with_reactants` arguments + +This is required because user arguments could have a name which clashes with whatever name we choose for +our argument. Thus we gensym to create it. + +This originates from https://github.com/JuliaLabs/Cassette.jl/blob/c29b237c1ec0deda3a1037ec519eebe216952bfe/src/overdub.jl#L154 +""" +const OVERDUB_ARGUMENTS_NAME = gensym("overdub_arguments") + +# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter +# In particular this entails two pieces: +# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance +# 2) Post type inference (using of course the reactant interpreter), all type unstable call functions are +# replaced with calls to `call_with_reactant`. This allows us to circumvent long standing issues in Julia +# using a custom interpreter in type unstable code. +# `redub_arguments` is `(typeof(original_function), map(typeof, original_args_tuple)...)` function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @nospecialize(redub_arguments)) @nospecialize args = redub_arguments - stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, :redub_arguments), Core.svec()) + stub = Core.GeneratedFunctionStub(identity, Core.svec(:call_with_reactant, OVERDUB_ARGUMENTS_NAME), Core.svec()) # look up the method match builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $redub_arguments"))) @@ -248,6 +275,7 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, matches = lookup_result.matches + # No method could be found (including in our method table), bail with an error if length(matches) != 1 return stub(world, source, method_error) end @@ -269,18 +297,15 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, @assert Core.Compiler.is_inferred(frame) method = match.method - @show mi - @show method + # The original julia code (on 1.11+) has the potential constprop, for now + # we assume this outermost function does not constprop, for ease. #if Core.Compiler.result_is_constabi(interp, frame.result) # rt = frame.result.result::Core.Compiler.Const # src = Core.Compiler.codeinfo_for_const(interp, frame.linfo, rt.val) #else opt = Core.Compiler.OptimizationState(frame, interp) - @show Core.Compiler.retrieve_code_info(mi, world) - @show opt.src - caller = frame.result @static if VERSION < v"1.11-" ir = Core.Compiler.run_passes(opt.src, opt, caller) @@ -289,11 +314,13 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, Core.Compiler.ipo_dataflow_analysis!(interp, ir, caller) end - @show ir + # Rewrite type unstable calls to recurse into call_with_reactant to ensure + # they continue to use our interpreter. Reset the derived return type + # to Any if our interpreter would change the return type of any result. + # Also rewrite invoke (type stable call) to be :call, since otherwise apparently + # screws up type inference after this (TODO this should be fixed). any_changed = false - for (i, inst) in enumerate(ir.stmts) - - + for (i, inst) in enumerate(ir.stmts) @static if VERSION < v"1.11" changed, next = rewrite_inst(inst[:inst], ir) Core.Compiler.setindex!(ir.stmts[i], next, :inst) @@ -307,10 +334,12 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, end end Core.Compiler.finish(interp, opt, ir, caller) - @show "post", ir src = Core.Compiler.ir_to_codeinf!(opt) - @show any_changed, src + # Julia hits various internal errors trying to re-perform type inference + # on type infered code (that we undo inference of), if there is no type unstable + # code to be rewritten, just use the default methodinstance (still using our methodtable), + # to improve compatibility as these bugs are fixed upstream. if !any_changed src = Core.Compiler.retrieve_code_info(mi, world) @show "post non change", src @@ -322,15 +351,16 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, signature = sig is_invoke = args[1] === typeof(Core.invoke) - # propagate edge metadata + # propagate edge metadata, this method is invalidated if the original function we are calling + # is invalidated code_info.edges = Core.MethodInstance[mi] code_info.min_world = lookup_result.valid_worlds.min_world code_info.max_world = lookup_result.valid_worlds.max_world - code_info.slotnames = Any[:call_with_reactant, :redub_arguments, code_info.slotnames...] + # Rewrite the arguments to this function, to prepend the two new arguments, the function :call_with_reactant, + # and the REDUB_ARGUMENTS_NAME tuple of input arguments + code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME, code_info.slotnames...] code_info.slotflags = UInt8[0x00, 0x00, code_info.slotflags...] - #code_info.slotnames = Any[:call_with_reactant, REDUB_ARGUMENTS_NAME] #code_info.slotnames...] - #code_info.slotflags = UInt8[0x00, 0x00] # code_info.slotflags...] n_prepended_slots = 2 overdub_args_slot = Core.SlotNumber(n_prepended_slots) @@ -339,6 +369,9 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, # the end of the pass, we'll reset `code_info` fields accordingly. overdubbed_code = Any[] overdubbed_codelocs = Int32[] + + # Rewire the arguments from our tuple input of fn and args, to the corresponding calling convention + # required by the base method. # destructure the generated argument slots into the overdubbed method's argument slots. n_actual_args = fieldcount(signature) @@ -384,24 +417,12 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, push!(fn_args, Core.SSAValue(length(overdubbed_code))) end - @show code_info.code - #=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===# - # substitute static parameters, offset slot numbers by number of added slots, and # offset statement indices by the number of additional statements arg_partially_inline!(code_info.code, fn_args, method.sig, Any[static_params...], n_prepended_slots, n_prepended_slots, length(overdubbed_code), :propagate) - #callexpr = Expr(:call, Core.OpaqueClosure(ir), fn_args...) - #push!(overdubbed_code, callexpr) - #push!(overdubbed_codelocs, code_info.codelocs[1]) - - #push!(new_ci.code, Core.Compiler.ReturnNode(Core.SSAValue(length(overdubbed_code)))) - #push!(overdubbed_codelocs, code_info.codelocs[1]) - - # original_code_start_index = length(overdubbed_code) + 1 - append!(overdubbed_code, code_info.code) append!(overdubbed_codelocs, code_info.codelocs) @@ -416,12 +437,10 @@ function call_with_reactant_generator(world::UInt, source::LineNumberNode, self, code_info.ssavaluetypes = length(overdubbed_code) code_info.ssaflags = [0x00 for _ in 1:length(overdubbed_code)] # XXX we need to copy flags that are set for the original code - @show code_info - return code_info end -@eval function call_with_reactant(redub_arguments...) +@eval function call_with_reactant($OVERDUB_ARGUMENTS_NAME...) $(Expr(:meta, :generated_only)) $(Expr(:meta, :generated, call_with_reactant_generator)) end @@ -517,9 +536,6 @@ function make_mlir_fn( end end - interp = ReactantInterpreter() - - # TODO replace with `Base.invoke_within` if julia#52964 lands # TODO fix it for kwargs if f === Reactant.apply call_with_reactant(f, traced_args[1], (traced_args[2:end]...,)) diff --git a/test/Project.toml b/test/Project.toml index 9956337ea..4b50a487f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,7 +1,6 @@ [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" -CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" FFTW = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" diff --git a/test/runtests.jl b/test/runtests.jl index 87e1a3702..fddc963ce 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -41,8 +41,6 @@ end const REACTANT_TEST_GROUP = lowercase(get(ENV, "REACTANT_TEST_GROUP", "all")) -include("cuda.jl") -@static if false @testset "Reactant.jl Tests" begin if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "core" @safetestset "Layout" include("layout.jl") @@ -62,19 +60,17 @@ include("cuda.jl") if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "integration" @safetestset "Linear Algebra" include("integration/linear_algebra.jl") - @safetestset "CUDA" include("cuda.jl") @safetestset "AbstractFFTs" include("integration/fft.jl") end - # if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" - # @testset "Neural Networks" begin - # @safetestset "NNlib Primitives" include("nn/nnlib.jl") - # @safetestset "Flux.jl Integration" include("nn/flux.jl") - # if Sys.islinux() - # @safetestset "LuxLib Primitives" include("nn/luxlib.jl") - # @safetestset "Lux Integration" include("nn/lux.jl") - # end - # end - # end -end + if REACTANT_TEST_GROUP == "all" || REACTANT_TEST_GROUP == "neural_networks" + @testset "Neural Networks" begin + @safetestset "NNlib Primitives" include("nn/nnlib.jl") + @safetestset "Flux.jl Integration" include("nn/flux.jl") + if Sys.islinux() + @safetestset "LuxLib Primitives" include("nn/luxlib.jl") + @safetestset "Lux Integration" include("nn/lux.jl") + end + end + end end