Skip to content

Commit

Permalink
remove counter & fix precompilation
Browse files Browse the repository at this point in the history
  • Loading branch information
glou-nes committed Dec 10, 2024
1 parent d350f82 commit 3f04cdc
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 15 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
NNlib = "0.9.24"
OrderedCollections = "1"
PrecompileTools = "1.2.1"
PrecompileTools = "1"
Preferences = "1.4"
ReactantCore = "0.1.2"
Reactant_jll = "0.0.26"
Expand Down
16 changes: 6 additions & 10 deletions src/XLA.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@ mutable struct Client
client::Ptr{Cvoid}

# Inner constructor: wrap a raw native client handle in a `Client` object.
#
# `client` must be a non-null pointer; the `@assert` below raises an
# `AssertionError` otherwise. A finalizer is registered on the new object so
# that `FreeClient` is invoked on the stored pointer when the wrapper is
# garbage-collected.
function Client(client::Ptr{Cvoid})
    @assert client != C_NULL
    # Use a separate name for the wrapper instead of rebinding the `client`
    # argument, so the pointer and the object are never confused.
    wrapped = new(client)
    finalizer(wrapped) do obj
        @ccall MLIR.API.mlir_c.FreeClient(obj.client::Ptr{Cvoid})::Cvoid
    end
    return wrapped
end
end

Expand Down Expand Up @@ -40,13 +41,8 @@ end

"""
    SetLogLevel(x)

Forward `x`, passed as a `Cint`, to the native `SetLogLevel` entry point in
`MLIR.API.mlir_c`. The foreign call is declared as returning `Cvoid`.
"""
function SetLogLevel(x)
    return @ccall MLIR.API.mlir_c.SetLogLevel(x::Cint)::Cvoid
end

const cpuclientcount = Ref(0)
# TODO: synchronization in asynchronous mode does not work, because `future` in `ConcreteRArray` is always `nothing`
function CPUClient(asynchronous=false, node_id=0, num_nodes=1)
global cpuclientcount
#@assert cpuclientcount[] == 0
cpuclientcount[] += 1

f = Libdl.dlsym(Reactant_jll.libReactantExtra_handle, "MakeCPUClient")
client = ccall(f, Ptr{Cvoid}, (UInt, Cint, Cint), asynchronous, node_id, num_nodes)
#client = @ccall MLIR.API.mlir_c.MakeCPUClient(asynchronous::UInt8, node_id::Cint, num_nodes::Cint)::Ptr{Cvoid}
Expand Down
13 changes: 9 additions & 4 deletions src/precompile.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
using PrecompileTools: @setup_workload, @compile_workload

# Precompilation workload built with PrecompileTools' `@setup_workload` and
# `@compile_workload` macros.
# NOTE(review): this listing comes from a rendered diff and appears to contain
# both removed and added lines of the commit (`Reactant.__init__()` occurs
# twice and `x` is assigned twice) — confirm the intended final contents
# against the repository before relying on the statement order below.
@setup_workload begin
Reactant.__init__()
XLA.__init__()
@compile_workload begin
x = Reactant.ConcreteRArray(randn(Float64, 2, 2))
@jit sum(x)
Reactant.__init__()
# Create a CPU client and pass it explicitly as the `client` keyword.
cpu = XLA.CPUClient()
x = Reactant.ConcreteRArray(randn(Float64, 2, 2); client=cpu)
# `@code_hlo` is applied to `sum(x)` with `optimize = false`.
@code_hlo optimize = false sum(x)
end

@compile_workload begin
# Run `Base.code_ircode` for `sum` on a 2-D `Float64` `TracedRArray`,
# using a `ReactantInterpreter` as the `interp` keyword.
interp = Reactant.ReactantInterpreter()
Base.code_ircode(sum, (Reactant.TracedRArray{Float64,2},); interp)
end
end

0 comments on commit 3f04cdc

Please sign in to comment.