Skip to content

Commit

Permalink
Add Ops.hlo_call(::String, args...) (#358)
Browse files Browse the repository at this point in the history
* special case String and Module in make_tracer

* implement Ops.hlo_call

* formatting

* Update src/Ops.jl

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>

* SymbolTable: fix lookup

* cache and more validation, also specify name to call

* error if not func.func

* only do special things for func.func

* symbol_rename

* add multiple call test

* rename then remove from parsed module

---------

Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
  • Loading branch information
Pangoraw and mofeing authored Dec 10, 2024
1 parent e40d715 commit 816e789
Show file tree
Hide file tree
Showing 4 changed files with 249 additions and 3 deletions.
123 changes: 123 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1046,4 +1046,127 @@ function compare(
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
end

# Generate a unique name given a module hash and a function name.
function _hlo_call_name(orig_name, module_suffix)
return orig_name * "_hlo_call_" * module_suffix
end

"""
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...; func_name::String="main") -> NTuple{N, AnyTracedRArray}
Given a MLIR module given as a string, calls the function identified by the `func_name` keyword parameter (default "main")
with the provided arguments and return a tuple for each result of the call.
```julia-repl
julia> Reactant.@jit(
Ops.hlo_call(
\"\"\"
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
\"\"\",
Reactant.to_rarray(Float32[1, 2, 3]),
Reactant.to_rarray(Float32[1, 2, 3]),
)
)
(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
```
"""
function hlo_call(
code,
args...;
func_name="main",
location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__),
)
module_suffix = string(hash(code); base=16)
name_to_call = _hlo_call_name(func_name, module_suffix)

current_module = MLIR.IR.mmodule()
top_level_block = MLIR.IR.body(current_module)

symbol_attr_name = String(MLIR.API.mlirSymbolTableGetSymbolAttributeName())

fn = MLIR.IR.lookup(
MLIR.IR.SymbolTable(MLIR.IR.Operation(current_module)), name_to_call
)
if isnothing(fn)
new_mod = parse(MLIR.IR.Module, code)
new_mod_op = MLIR.IR.Operation(new_mod)
body = MLIR.IR.body(new_mod)

operations = collect(MLIR.IR.OperationIterator(body))
for op in operations
if MLIR.IR.name(op) == "func.func"
fn_name = String(MLIR.IR.attr(op, symbol_attr_name))
if fn_name == func_name
fn = op
end

new_name = _hlo_call_name(fn_name, module_suffix)
res = MLIR.IR.LogicalResult(
MLIR.API.mlirSymbolTableReplaceAllSymbolUses(
fn_name, new_name, new_mod_op
),
)
@assert res == MLIR.IR.success() "hlo_call: failed to rename $fn_name"

# Set function private
MLIR.IR.attr!(
op,
MLIR.API.mlirSymbolTableGetVisibilityAttributeName(),
MLIR.IR.Attribute("private"),
)

# Change function name
MLIR.IR.attr!(op, symbol_attr_name, MLIR.IR.Attribute(new_name))
end
end

for op in operations
MLIR.IR.rmfromparent!(op)
push!(top_level_block, op)
end
end

if isnothing(fn)
error("hlo_call: could not find function $func_name in the provided module")
end

ftype_attr = MLIR.IR.attr(fn, "function_type")
ftype = MLIR.IR.Type(ftype_attr)

@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "hlo_call: all inputs to hlo_call should be reactant arrays"
@assert MLIR.IR.ninputs(ftype) == length(args) "hlo_call: invalid number of arguments for function $func_name"

for (i, arg) in enumerate(args)
expected_type = MLIR.IR.input(ftype, i)
arg_type = MLIR.IR.type(arg.mlir_data)
@assert expected_type == arg_type "hlo_call: argument #$i has the wrong type (expected $expected_type, got $arg_type)"
end

operands = [a.mlir_data for a in args]
call = MLIR.Dialects.func.call(
operands;
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
callee=MLIR.IR.FlatSymbolRefAttribute(name_to_call),
location,
)

return ntuple(MLIR.IR.nresults(call)) do i
out = MLIR.IR.result(call, i)
ty = MLIR.IR.type(out)
sz = MLIR.IR.size(ty)
T = MLIR.IR.julia_type(eltype(ty))
N = length(sz)
if N == 0
Reactant.TracedRNumber{T}((), out)
else
Reactant.TracedRArray{T,N}((), out, sz)
end
end
end

end # module Ops
4 changes: 4 additions & 0 deletions src/Tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,10 @@ function make_tracer(
@assert Base.isconcretetype(RT)
nf = fieldcount(RT)

if TT === Module || TT === String
return prev
end

if ismutabletype(TT)
y = ccall(:jl_new_struct_uninit, Any, (Any,), TT)
seen[prev] = y
Expand Down
14 changes: 11 additions & 3 deletions src/mlir/IR/SymbolTable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@ Base.convert(::Core.Type{API.MlirSymbolTable}, st::SymbolTable) = st.st
Looks up a symbol with the given name in the given symbol table and returns the operation that corresponds to the symbol.
If the symbol cannot be found, returns a null operation.
"""
lookup(st::SymbolTable, name::AbstractString) =
Operation(API.mlirSymbolTableLookup(st, name))
Base.getindex(st::SymbolTable, name::AbstractString) = lookup(st, name)
function lookup(st::SymbolTable, name::AbstractString)
raw_op = API.mlirSymbolTableLookup(st, name)
if raw_op.ptr == C_NULL
nothing
else
Operation(raw_op, false)
end
end
function Base.getindex(st::SymbolTable, name::AbstractString)
@something(lookup(st, name), throw(KeyError(name)))
end

"""
push!(symboltable, operation)
Expand Down
111 changes: 111 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,114 @@ end
z = ConcreteRArray([1e-8, 0.001, 2.0])
@test SpecialFunctions.zeta.(Array(s), Array(z)) @jit Ops.zeta(s, z)
end

@testset "hlo_call" begin
x = Float32[1.0, 2.0, 50.0]
y = Float32[-4.0, 0.001, 2.0]
x_reactant = Reactant.to_rarray(x)
y_reactant = Reactant.to_rarray(y)

@test Reactant.@jit(
Ops.hlo_call(
"""
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
x_reactant,
y_reactant,
)
)[1] x .+ y
end

function f_repeat(x, y)
for _ in 1:3
x, = Ops.hlo_call(
"""
module {
func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
x,
y;
func_name="my_add",
)
end
return x
end

@testset "hlo_call: repeat" begin
x = Reactant.to_rarray(randn(Float32, 3))
y = Reactant.to_rarray(randn(Float32, 3))
mod = Reactant.@code_hlo optimize = false f_repeat(x, y)
hlo_ir = repr(mod)

add_pos = findfirst("stablehlo.add", hlo_ir)
@test !isnothing(add_pos)

add_pos = findfirst("stablehlo.add", hlo_ir[last(add_pos):end])
@test isnothing(add_pos)
end

@testset "hlo_call: multiple functions" begin
@test Reactant.@jit(
Ops.hlo_call(
"""
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = func.call @add(%arg0, %arg1) : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32>
return %0 : tensor<3xf32>
}
func.func @add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
Reactant.to_rarray(Float32[1, 2, 3]),
Reactant.to_rarray(Float32[1, 2, 3]),
)
)[1] Float32[2, 4, 6]
end

function f_multiple_hlo_calls(x, y)
x, = Ops.hlo_call(
"""
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
x,
y,
)
return Ops.hlo_call(
"""
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
x,
y,
)
end

@testset "hlo_call: multiple hlo_calls" begin
x = Float32[1.0, 2.0, 50.0]
y = Float32[-4.0, 0.001, 2.0]
x_reactant = Reactant.to_rarray(x)
y_reactant = Reactant.to_rarray(y)

@test Reactant.@jit(f_multiple_hlo_calls(x_reactant, y_reactant))[1] (x .+ y) .* y
end

0 comments on commit 816e789

Please sign in to comment.