Skip to content

Commit

Permalink
implement Ops.hlo_call
Browse files Browse the repository at this point in the history
  • Loading branch information
Pangoraw committed Dec 10, 2024
1 parent 0c4c708 commit 910d141
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 0 deletions.
68 changes: 68 additions & 0 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1046,4 +1046,72 @@ function compare(
return TracedRArray{Bool,ndims(lhs)}((), res, size(lhs))
end

"""
Ops.hlo_call(mlir_code::String, args::Vararg{AnyTracedRArray}...) -> NTuple{N, AnyTracedRArray}
Given a MLIR module given as a string and containing a single function,
calls the given function with the provided arguments and return a tuple for each result of the call.
```julia-repl
julia> Reactant.@jit(
Ops.hlo_call(
\"\"\"
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
\"\"\",
Reactant.to_rarray(Float32[1, 2, 3]),
Reactant.to_rarray(Float32[1, 2, 3]),
)
)
(ConcreteRArray{Float32, 1}(Float32[2.0, 4.0, 6.0]),)
```
"""
function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__))
new_mod = parse(MLIR.IR.Module, code)
body = MLIR.IR.body(new_mod)
fn = MLIR.IR.first_op(body)
MLIR.IR.rmfromparent!(fn)

current_module = MLIR.IR.mmodule()
top_level_block = MLIR.IR.body(current_module)

orig_name = String(MLIR.IR.attr(fn, "sym_name"))
name = orig_name * "_" * string(gensym())

push!(top_level_block, fn)

ftype_attr = MLIR.IR.attr(fn, "function_type")
ftype = MLIR.IR.Type(ftype_attr)

MLIR.IR.attr!(fn, "sym_name", MLIR.IR.Attribute(name))

@assert all(Base.Fix2(isa, Reactant.AnyTracedRArray), args) "all inputs to hlo_call should be reactant arrays"
@assert MLIR.IR.ninputs(ftype) == length(args) "invalid number of argument for function $orig_name"

operands = [a.mlir_data for a in args]
call = MLIR.Dialects.func.call(
operands;
result_0=[MLIR.IR.result(ftype, i) for i in 1:MLIR.IR.nresults(ftype)],
callee=MLIR.IR.FlatSymbolRefAttribute(name),
location,
)

return ntuple(MLIR.IR.nresults(call)) do i
out = MLIR.IR.result(call, i)
ty = MLIR.IR.type(out)
sz = MLIR.IR.size(ty)
T = MLIR.IR.julia_type(eltype(ty))
N = length(sz)
if N == 0
Reactant.TracedRNumber{T}((), out)
else
Reactant.TracedRArray{T,N}((), out, sz)
end
end
end

end # module Ops
22 changes: 22 additions & 0 deletions test/ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -866,3 +866,25 @@ end
z = ConcreteRArray([1e-8, 0.001, 2.0])
@test SpecialFunctions.zeta.(Array(s), Array(z)) @jit Ops.zeta(s, z)
end

@testset "hlo_call" begin
x = Float32[1.0, 2.0, 50.0]
y = Float32[-4.0, 0.001, 2.0]
x_reactant = Reactant.to_rarray(x)
y_reactant = Reactant.to_rarray(y)

@test Reactant.@jit(
Ops.hlo_call(
"""
module {
func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
%0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
return %0 : tensor<3xf32>
}
}
""",
x_reactant,
y_reactant,
)
)[1] x .+ y
end

0 comments on commit 910d141

Please sign in to comment.