Skip to content

Commit

Permalink
precompile first interpretation (#353)
Browse files Browse the repository at this point in the history
* precompile first try

* add __init__ & assert remove

* remove counter & fix precompilation

* keep only one workload

* try to fix CI

* forgotten init

* compact

* apply CI check inside setup_workload

* reviews

* typo

* remove CI ENV hack

* Revert "remove CI ENV hack"

This reverts commit a688556.

* test CI

* test CI 2

* de initialize opaque closure cache

* ongoing

* more attempts

* fix

* fix

* fix

* Disable rules for now as unused

---------

Co-authored-by: William Moses <[email protected]>
  • Loading branch information
glou-nes and wsmoses authored Jan 4, 2025
1 parent 103b230 commit b32441e
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 40 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
Expand Down Expand Up @@ -60,6 +61,7 @@ LinearAlgebra = "1.10"
NNlib = "0.9.26"
OffsetArrays = "1"
OrderedCollections = "1"
PrecompileTools = "1"
Preferences = "1.4"
PythonCall = "0.9"
Random = "1.10"
Expand Down
23 changes: 16 additions & 7 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -334,12 +334,18 @@ const cuFunc = Ref{UInt}(0)
const cuModule = Ref{UInt}(0)

function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true, no_nan::Bool=false)
fnwrapped,
func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results = MLIR.IR.mmodule!(mod) do
MLIR.IR.block!(MLIR.IR.body(mod)) do
return Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
end
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues

MLIR.IR.activate!(mod)
MLIR.IR.activate!(MLIR.IR.body(mod))
fnwrapped, func2, traced_result, result, seen_args, ret, linear_args, in_tys,
linear_results =
try
Reactant.TracedUtils.make_mlir_fn(f, args, (), "main", true)
finally
MLIR.IR.deactivate!(MLIR.IR.body(mod))
MLIR.IR.deactivate!(mod)
end

concrete_seen = OrderedIdDict()
Expand Down Expand Up @@ -828,7 +834,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
ctx = MLIR.IR.Context(Reactant.registry[], false)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

return MLIR.IR.context!(ctx) do
MLIR.IR.activate!(ctx)
return try
# compile function to MLIR module
mod = MLIR.IR.Module(MLIR.IR.Location())
linear_args, linear_results, preserved_args, seen_args, concrete_result, isclosure = compile_mlir!(
Expand All @@ -851,6 +858,8 @@ function compile_xla(f, args; client=nothing, optimize=true, no_nan=false)
return exec,
linear_args, linear_results, preserved_args, seen_args, concrete_result,
isclosure
finally
MLIR.IR.deactivate!(ctx)
end
end

Expand Down
12 changes: 6 additions & 6 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
function ConcreteRNumber{T}(
data::T2; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
data::T2; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
) where {T<:Number,T2<:Number}
data = convert(T, data)
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
end
function ConcreteRNumber(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
) where {T<:Number}
crarray = ConcreteRArray(fill(data); client, idx, device)
return ConcreteRNumber{T}(crarray.data)
Expand Down Expand Up @@ -37,7 +37,7 @@ end
Base.convert(::Type{T}, x::ConcreteRNumber) where {T<:Number} = convert(T, to_number(x))

function ConcreteRArray(
data::T; client=XLA.default_backend[], idx=XLA.default_device_idx[], device=nothing
data::T; client::XLA.Client=XLA.default_backend[], idx::Int=XLA.default_device_idx[], device::Union{Nothing, XLA.Device}=nothing
) where {T<:Number}
Base.depwarn(
"ConcreteRArray(data::Number) is deprecated, use ConcreteRNumber(data) instead",
Expand All @@ -52,9 +52,9 @@ Adapt.adapt_storage(::Type{T}, x::AbstractArray) where {T<:ConcreteRArray} = T(x

function ConcreteRArray(
data::Array{T,N};
client=XLA.default_backend[],
idx=XLA.default_device_idx[],
device=nothing,
client::XLA.Client=XLA.default_backend[],
idx::Int=XLA.default_device_idx[],
device::Union{Nothing, XLA.Device}=nothing,
) where {T,N}
device = device === nothing ? XLA.ClientGetDevice(client, idx) : device
return ConcreteRArray{T,N}(
Expand Down
8 changes: 4 additions & 4 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ end
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
true, #=forward_rules=#
true, #=reverse_rules=#
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
Expand All @@ -80,8 +80,8 @@ else
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
true, #=forward_rules=#
true, #=forward_rules=#
false, #=forward_rules=#
false, #=forward_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
)
Expand Down
66 changes: 66 additions & 0 deletions src/Precompile.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using PrecompileTools: @setup_workload, @compile_workload

function infer_sig(sig)
interp = ReactantInterpreter()

min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))

lookup_result = Reactant.lookup_world(
sig, interp.world, Core.Compiler.method_table(interp), min_world, max_world
)
match = lookup_result::Core.MethodMatch
# look up the method and code instance
mi = ccall(
:jl_specializations_get_linfo,
Ref{Core.MethodInstance},
(Any, Any, Any),
match.method,
match.spec_types,
match.sparams,
)

@static if VERSION < v"1.11"
# For older Julia versions, we vendor in some of the code to prevent
# having to build the MethodInstance twice.
result = CC.InferenceResult(mi, CC.typeinf_lattice(interp))
frame = CC.InferenceState(result, :no, interp)
@assert !isnothing(frame)
CC.typeinf(interp, frame)
ir = CC.run_passes(frame.src, CC.OptimizationState(frame, interp), result, nothing)
rt = CC.widenconst(CC.ignorelimited(result.result))
else
ir, rt = CC.typeinf_ircode(interp, mi, nothing)
end
end

@setup_workload begin
initialize_dialect()
client = XLA.CPUClient(; checkcount=false)
@compile_workload begin
# Precompilation on 1.10 hits an apparent bug: https://github.com/JuliaLang/julia/issues/56947
@static if VERSION < v"1.11"
else
# infer_sig(Tuple{typeof(Base.sum), Reactant.TracedRArray{Float64, 2}})
# infer_sig(Tuple{typeof(Base.sin), Reactant.TracedRNumber{Float64}})
x = ConcreteRNumber(2.0; client)
Reactant.compile(sin, (x,); client)

y = ConcreteRArray([2.0]; client)
Reactant.compile(Base.sum, (y,); client)
end
end
XLA.free_client(client)
client.client = C_NULL
deinitialize_dialect()
# Opaque closures capture the worldage of their compilation and thus are not relocatable
# Therefore we explicitly purge all OC's we have created here
for v in oc_capture_vec
if v isa Base.RefValue
p = Ptr{Ptr{Cvoid}}(pointer_from_objref(v))
Base.atomic_pointerset(p, C_NULL, :monotonic)
else
empty!(v)
end
end
end
15 changes: 13 additions & 2 deletions src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -228,14 +228,23 @@ end
using .Compiler: @compile, @code_hlo, @jit, traced_getfield, create_result, compile
export ConcreteRArray, ConcreteRNumber, @compile, @code_hlo, @jit, @trace

const registry = Ref{MLIR.IR.DialectRegistry}()
function __init__()
const registry = Ref{Union{Nothing,MLIR.IR.DialectRegistry}}()

function initialize_dialect()
registry[] = MLIR.IR.DialectRegistry()
@ccall MLIR.API.mlir_c.InitializeRegistryAndPasses(
registry[]::MLIR.API.MlirDialectRegistry
)::Cvoid
end

function deinitialize_dialect()
return registry[] = nothing
end

function __init__()
return initialize_dialect()
end

function set_default_backend(backend::XLA.Client)
return XLA.default_backend[] = backend
end
Expand All @@ -244,4 +253,6 @@ function set_default_backend(backend::String)
return set_default_backend(XLA.backends[backend])
end

include("Precompile.jl")

end # module
24 changes: 14 additions & 10 deletions src/TracedUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,9 @@ function make_mlir_fn(

N = length(args)
seen_args = OrderedIdDict()
traced_args = ntuple(N) do i
return Reactant.make_tracer(
traced_args = Vector{Any}(undef, N)
for i in 1:N
@inbounds traced_args[i] = Reactant.make_tracer(
seen_args,
args[i],
(:args, i),
Expand Down Expand Up @@ -166,7 +167,10 @@ function make_mlir_fn(

@assert MLIR.IR._has_block()

result = MLIR.IR.block!(fnbody) do
# Explicitly don't use block! to avoid creating a closure, which creates
# both compile-time and relocatability issues
MLIR.IR.activate!(fnbody)
result = try
for (i, arg) in enumerate(linear_args)
if construct_function_without_args
arg.mlir_data = args[i].mlir_data
Expand All @@ -177,12 +181,9 @@ function make_mlir_fn(
end
end

# TODO fix it for kwargs
#if concretein
Reactant.call_with_reactant(f, traced_args...)
#else
# f(traced_args...)
#end
finally
MLIR.IR.deactivate!(fnbody)
end

seen_results = OrderedIdDict()
Expand Down Expand Up @@ -215,7 +216,8 @@ function make_mlir_fn(

out_tys = [transpose_ty(Ops.mlir_type(arg)) for arg in linear_results]

ret = MLIR.IR.block!(fnbody) do
MLIR.IR.activate!(fnbody)
ret = try
vals = MLIR.IR.Value[]
for res in linear_results
col_maj = if res isa MissingTracedValue
Expand All @@ -230,7 +232,9 @@ function make_mlir_fn(
!no_args_in_result && @assert length(vals) == length(linear_results)

dialect = getfield(MLIR.Dialects, return_dialect)
return dialect.return_(vals)
dialect.return_(vals)
finally
MLIR.IR.deactivate!(fnbody)
end

name2 = name
Expand Down
19 changes: 10 additions & 9 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@ mutable struct Client
client::Ptr{Cvoid}

function Client(client::Ptr{Cvoid})
@assert client != C_NULL
return new(client)
#@assert client != C_NULL
#finalizer(new(client)) do client
# @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
#end
end
end

@inline function free_client(client::Client)
@ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid
end

function to_row_major(x::Array{T,N}) where {T,N}
return permutedims(x, reverse(Base.OneTo(N)))
end
Expand Down Expand Up @@ -42,11 +43,11 @@ SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid

const cpuclientcount = Ref(0)
# TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing`
function CPUClient(asynchronous=false, node_id=0, num_nodes=1)
global cpuclientcount
@assert cpuclientcount[] == 0
cpuclientcount[] += 1

function CPUClient(asynchronous=false, node_id=0, num_nodes=1; checkcount=true)
if checkcount
@assert cpuclientcount[] == 0
cpuclientcount[] += 1
end
f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
Expand Down
5 changes: 3 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,8 @@ function make_oc_dict(
)::Core.OpaqueClosure where {FT}
key = f
if haskey(oc_captures, key)
return oc_captures[key]
oc = oc_captures[key]
oc
else
ores = ccall(
:jl_new_opaque_closure_from_code_info,
Expand Down Expand Up @@ -527,7 +528,7 @@ function call_with_reactant_generator(
# octup = Tuple{method.sig.parameters[2:end]...}
octup = Tuple{tys[2:end]...}
ocva = false

# jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right
# inner code during compilation without special handling (i.e. call_in_world_total).
# Opaque closures also require taking the function argument. We can work around the latter
Expand Down

0 comments on commit b32441e

Please sign in to comment.