diff --git a/Project.toml b/Project.toml index 472c89172..61e11d42e 100644 --- a/Project.toml +++ b/Project.toml @@ -48,7 +48,7 @@ GPUArraysCore = "0.1.6, 0.2" LinearAlgebra = "1.10" NNlib = "0.9.24" OrderedCollections = "1" -PrecompileTools = "1.2.1" +PrecompileTools = "1" Preferences = "1.4" ReactantCore = "0.1.2" Reactant_jll = "0.0.26" diff --git a/src/XLA.jl b/src/XLA.jl index 4857dc64d..af2b9539a 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -6,11 +6,12 @@ mutable struct Client client::Ptr{Cvoid} function Client(client::Ptr{Cvoid}) - return new(client) - #@assert client != C_NULL - #finalizer(new(client)) do client - # @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid - #end + @assert client != C_NULL + client = new(client) + finalizer(client) do client + @ccall MLIR.API.mlir_c.FreeClient(client.client::Ptr{Cvoid})::Cvoid + end + return client end end @@ -40,13 +41,8 @@ end SetLogLevel(x) = @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid -const cpuclientcount = Ref(0) # TODO synchronization when async is not working because `future` in `ConcreteRArray` is always `nothing` function CPUClient(asynchronous=false, node_id=0, num_nodes=1) - global cpuclientcount - #@assert cpuclientcount[] == 0 - cpuclientcount[] += 1 - f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient") client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes) #client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid} diff --git a/src/precompile.jl b/src/precompile.jl index 532c03361..dc09d21c1 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,10 +1,15 @@ using PrecompileTools: @setup_workload, @compile_workload @setup_workload begin - Reactant.__init__() - XLA.__init__() @compile_workload begin - x = Reactant.ConcreteRArray(randn(Float64, 2, 2)) - @jit sum(x) + Reactant.__init__() + cpu = XLA.CPUClient() + x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu) + @code_hlo optimize = false sum(x) + end + + @compile_workload begin + interp = Reactant.ReactantInterpreter() + Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp) end end