
Add Ops.hlo_call(::String, args...) #358

Merged
merged 11 commits into from
Dec 10, 2024

Conversation

Pangoraw
Collaborator

No description provided.

@Pangoraw Pangoraw marked this pull request as ready for review December 10, 2024 12:29
@Pangoraw Pangoraw changed the title Add Ops.hlo_call(::String, args) Add Ops.hlo_call(::String, args...) Dec 10, 2024
@Pangoraw Pangoraw linked an issue Dec 10, 2024 that may be closed by this pull request
@Pangoraw
Collaborator Author

Pangoraw commented Dec 10, 2024

idea: we could hash the code and just call the function if a function with the same hash is already present in the module (preventing duplicate functions).

So that:

for i in 1:N
    x = Ops.hlo_call(code, x)
end

has the function just once (but multiple calls).


Edit: implemented this by using the code hash in the function names.
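A minimal sketch of that naming scheme (the `dedup_name` helper and its inputs are hypothetical; the actual implementation lives in src/Ops.jl):

```julia
# Hypothetical sketch of the dedup-by-name idea: derive the emitted
# function name from a hash of the HLO source, so repeated hlo_call
# invocations with identical code map onto a single function in the module.
dedup_name(orig_name::AbstractString, code::AbstractString) =
    string(orig_name, "_", string(hash(code); base=16))

name_a = dedup_name("main", "stablehlo.add ...")
name_b = dedup_name("main", "stablehlo.add ...")       # same code, same name
name_c = dedup_name("main", "stablehlo.multiply ...")  # different code, different name
```

Before emitting, the module can then be checked for an existing symbol with that name and the duplicate skipped.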

@mofeing
Collaborator

mofeing commented Dec 10, 2024

that can be done directly with a Dict{HLOIR,Function} (or sth like that) and then calling

get!(dict, code) do
    compile(code)
end
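A self-contained version of that caching pattern (`COMPILED` and `fake_compile` are illustrative stand-ins, not Reactant's actual state):

```julia
# Cache compiled artifacts keyed by the HLO source string; get! runs
# the do-block only when the key is not yet present.
const COMPILED = Dict{String,Any}()

fake_compile(code) = "compiled:" * code  # stand-in for real compilation

f1 = get!(COMPILED, "module {}") do
    fake_compile("module {}")
end
f2 = get!(COMPILED, "module {}") do
    fake_compile("module {}")  # not executed: key already cached
end
```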

@Pangoraw
Collaborator Author

Pangoraw commented Dec 10, 2024

that can be done directly with a Dict{HLOIR,Function} (or sth like that) and then calling

ah right, do we have a place to store such state within a single trace? Otherwise my idea is to name these functions like $(orig_name)_$(hash(code)) and check whether the name already exists within the module before adding it.

@mofeing
Collaborator

mofeing commented Dec 10, 2024

I have a question... What is hlo_call returning? A function? Wouldn't it be better to just inject the MLIR code in the current MLIR block?

@Pangoraw
Collaborator Author

Pangoraw commented Dec 10, 2024

No, it returns a tuple of the traced arrays produced by the call to the function;
i.e. Reactant.@jit Ops.hlo_call(code, x, y) yields a tuple of concrete arrays.

src/Ops.jl Outdated Show resolved Hide resolved
src/Ops.jl Outdated
function hlo_call(code, args...; location=mlir_stacktrace("hlo_call", @__FILE__, @__LINE__))
    new_mod = parse(MLIR.IR.Module, code)
    body = MLIR.IR.body(new_mod)
    fn = MLIR.IR.first_op(body)
Collaborator

what if the first op is not main? like what if the code was traced by us and we added some function barriers.

maybe we can add a kwarg for selecting the target function (and default it to main), so we just iterate over the ops doing first(Iterators.filter(op -> String(IR.attr(op, "sym_name")) == target_fn, OperationIterator(body)))
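The suggested lookup, sketched without MLIR.jl loaded (`FakeOp` stands in for an MLIR operation carrying a "sym_name" attribute):

```julia
# MLIR-free model of selecting the target function by symbol name.
struct FakeOp
    sym_name::String
end

body_ops = [FakeOp("helper"), FakeOp("main"), FakeOp("other")]
target_fn = "main"

# first(...) errors if nothing matches, mirroring a missing target function.
fn = first(Iterators.filter(op -> op.sym_name == target_fn, body_ops))
```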

Collaborator Author

@Pangoraw Pangoraw Dec 10, 2024

This is inside the code that was given by the caller. Currently, the expectation is that there is only one function inside the given module. We can certainly revisit that with a keyword, or a tuple as for Core.llvmcall.

Member

yeah, I would instead ideally have a kwarg fn="main", and we extract that fn as the top-level one

Collaborator Author

added a func_name::String kwarg for this

src/Ops.jl Outdated Show resolved Hide resolved
test/ops.jl Outdated Show resolved Hide resolved
Pangoraw and others added 2 commits December 10, 2024 16:13
Co-authored-by: Sergio Sánchez Ramírez <[email protected]>
MLIR.IR.rmfromparent!(fn)

current_module = MLIR.IR.mmodule()
top_level_block = MLIR.IR.body(current_module)
Member

we also need to mark all fn's as private, as well as make sure to move all fns in the module (e.g. the main function could call something)
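A toy model of that visibility rule (no MLIR involved; all names hypothetical): every imported function becomes private except the entry point, so internal helpers cannot collide with or leak into the outer module's public symbols.

```julia
# The entry function keeps public visibility; every other imported
# function is marked private.
visibility(fn_name::String, entry::String) =
    fn_name == entry ? "public" : "private"

vis = Dict(name => visibility(name, "main")
           for name in ["main", "helper1", "helper2"])
```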

Collaborator Author

Done. Do you know if we can encounter ops other than func.func (maybe gpu.func in the future?) and what to do with them?

Member

I think it’s fine to assume func for now, but if desired we could generalize to the function interface or whatnot

src/Ops.jl Show resolved Hide resolved
src/Ops.jl Outdated
new_mod = parse(MLIR.IR.Module, code)
body = MLIR.IR.body(new_mod)

for op in MLIR.IR.OperationIterator(body)
Member

Oh, we should import all the ops, not just func, but it’s okay to be limited to just func as the entry function. E.g. if main calls a gpu function, that would be fine. Or if it references a global constant op.

y_reactant = Reactant.to_rarray(y)

@test Reactant.@jit(
Ops.hlo_call(
Member

Can you add a test with multiple functions in the module?

and can you also add a test with two (different) hlo calls that happen to contain functions of the same name (to make sure we do the symbol rename properly)?

@Pangoraw
Collaborator Author

So hlo_call currently returns a tuple of arrays (one for each result); should we special-case nresults() == 1 to return just the array (no tuple)?

x, = Ops.hlo_call(
    """
    module {
      func.func @my_add(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
Member

Can you also add a version of this where the two definitions are different?

just because if we fix caching then we might not actually emit it twice (and thus not check things)

Collaborator Author

Added a test with the same name but different definitions:

Reactant.jl/test/ops.jl

Lines 945 to 970 in 8eb71cc

function f_multiple_hlo_calls(x, y)
    x, = Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.add %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
        """,
        x,
        y,
    )
    return Ops.hlo_call(
        """
        module {
          func.func @main(%arg0: tensor<3xf32>, %arg1: tensor<3xf32>) -> tensor<3xf32> {
            %0 = stablehlo.multiply %arg0, %arg1 : tensor<3xf32>
            return %0 : tensor<3xf32>
          }
        }
        """,
        x,
        y,
    )
end

@wsmoses
Member

wsmoses commented Dec 10, 2024

So hlo_call currently returns a tuple of arrays (one for each result), should we special case nresults() == 1 to return the only array (no tuple) ?

Honestly I think it’s better to always return the tuple. That way folks using it don’t need to special case if there are multiple returns or not
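The always-a-tuple convention can be exercised with a stand-in (`fake_hlo_call` is hypothetical; the real hlo_call traces MLIR): callers always destructure, whether there is one result or several.

```julia
# Always return a tuple, even for a single result, so call sites stay uniform.
fake_hlo_call(code::String, args...) = Tuple(args)

single, = fake_hlo_call("...", [1.0f0, 2.0f0])  # trailing comma unpacks the 1-tuple
a, b = fake_hlo_call("...", [1.0f0], [2.0f0])   # multiple results destructure the same way
```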

Member

@wsmoses wsmoses left a comment

Awesome stuff!

@wsmoses wsmoses merged commit 816e789 into EnzymeAD:main Dec 10, 2024
17 of 36 checks passed
@mofeing
Collaborator

mofeing commented Dec 10, 2024

this is amazing and opens a way to do things like #354

@Pangoraw Pangoraw deleted the hlo-cal branch December 10, 2024 21:59
@Pangoraw
Collaborator Author

Other potential applications:

  • PyTorch has pretty good StableHLO export with torch_xla nowadays.
  • ONNX via onnx-mlir

Strings really are the universal model format (we can add hlo_call(::Vector{UInt8}, ...) which takes MLIR bytecode).
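Multiple dispatch would make that bytecode variant straightforward; a hypothetical sketch (`hlo_call_demo` stands in for Ops.hlo_call, and the bodies are placeholders):

```julia
# The same entry point could accept either textual MLIR or serialized
# bytecode, distinguished purely by the argument type.
hlo_call_demo(code::String, args...) = (:text, args...)
hlo_call_demo(code::Vector{UInt8}, args...) = (:bytecode, args...)

r_text = hlo_call_demo("module {}", 1)
r_bc   = hlo_call_demo(UInt8[0x4d, 0x4c], 1)
```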


Successfully merging this pull request may close these issues.

Cannot compile function with Module as parameter
3 participants