@trace function_call() to introduce function barrier #346

Open
Pangoraw opened this issue Dec 9, 2024 · 6 comments · May be fixed by #366

Comments

@Pangoraw
Collaborator

Pangoraw commented Dec 9, 2024

Currently, calling the same function multiple times generates its IR multiple times because of tracing.

One idea to circumvent this problem is to update the @trace macro to introduce an MLIR function call as the result op instead of tracing through the function. The cache for these outlined functions has to be specific about what it is keyed on:

@trace func_call()

would be outlined to something like:

func.func @func_call() {
  ...
  func.return
}

func.func @main() {
  ...
  func.call @func_call() : () -> ()
  ...
}

For example:

struct Foo
    x::Int
    y::TracedRArray{Float32, 2}
end

f(z::Foo) # one version where z.x == 1
f(z::Foo) # another version where z.x == 2
f(z::Foo) # one version where z.y has size 100x100
f(z::Foo) # another version where z.y has size 10x10

Both the values in Julia land and the MLIR types need to be checked in order to cache and reuse functions within a single trace.
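
A rough sketch of what such a cache key could look like (all names here are hypothetical, not actual Reactant internals):

# Hypothetical sketch: traced arrays contribute only their MLIR-level type
# information (element type + shape) to the key, while plain Julia values are
# compared by value, so different z.x values or z.y sizes each get their own
# outlined func.func.
cache_component(x::TracedRArray) = (eltype(x), size(x))
cache_component(x) = x

trace_cache_key(f, z::Foo) = (nameof(f), cache_component(z.x), cache_component(z.y))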

@mofeing
Collaborator

mofeing commented Dec 9, 2024

I've had to do something similar manually in #344, but I ran into some problems in the end:

  1. after every func.call, there is a series of ops stablehlo.transpose + stablehlo.reshape + stablehlo.transpose which are totally useless
  • they are later optimized away, but I feel like they shouldn't even appear in the first place
  2. after the optimization passes, func.calls appear inlined
  • is there some way to forbid inlining or, better, to improve the inliner cost model?

Aside from that, we should check how well enzyme.batch and enzyme.autodiff work on top of a func.call (I guess it works perfectly, but we need to check anyway).
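
As a quick way to check, something like this sketch (assuming the usual Reactant/Enzyme entry points work on traced inputs) could be used to inspect what autodiff on top of an outlined call lowers to:

using Reactant, Enzyme

g(x) = sum(abs2, x)
x = Reactant.to_rarray(rand(Float32, 10))

# Look at the unoptimized IR to see whether the gradient lowers to
# enzyme.autodiff wrapped around the outlined func.call for g.
@code_hlo optimize = false Enzyme.gradient(Reverse, g, x)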

@Pangoraw
Collaborator Author

Pangoraw commented Dec 9, 2024

Autodiff is already supported; it should not be too hard to support batch if it isn't yet (basically, apply enzyme.batch to the call).

stablehlo.transpose + stablehlo.reshape + stablehlo.transpose

This looks like the pattern that we emit for a reshape, are you sure that it is related to the added calls?

after optimization passes, func.calls appear inlined

The inline pass has parameters that we can tweak. Also, IIRC this pass was run before autodiff since autodiff of func.call was not yet supported.
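
For reference (exact option names depend on the MLIR revision bundled with Reactant), the upstream inliner exposes a couple of knobs in pass-pipeline syntax, e.g.:

builtin.module(inline{default-pipeline=canonicalize max-iterations=1})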

@mofeing
Collaborator

mofeing commented Dec 9, 2024

This looks like the pattern that we emit for a reshape, are you sure that it is related to the added calls?

Mmm, you're right. Using my PR #344 on this code

using Reactant
using YaoBlocks

θ = ConcreteRNumber(rand())
f(x) = mat(ComplexF64, Rz(x))
@code_hlo optimize = false f(θ)

generates the following MLIR:

module {
  func.func @rz_Float64_ComplexF64(%arg0: tensor<f64>) -> tensor<2x2xcomplex<f64>> {
    %cst = stablehlo.constant dense<(0.000000e+00,0.000000e+00)> : tensor<2x2xcomplex<f64>>
    %cst_0 = stablehlo.constant dense<(0.000000e+00,1.000000e+00)> : tensor<complex<f64>>
    %0 = stablehlo.convert %arg0 : (tensor<f64>) -> tensor<complex<f64>>
    %1 = stablehlo.multiply %cst_0, %0 : tensor<complex<f64>>
    %cst_1 = stablehlo.constant dense<(2.000000e+00,0.000000e+00)> : tensor<complex<f64>>
    %2 = stablehlo.divide %1, %cst_1 : tensor<complex<f64>>
    %3 = stablehlo.exponential %2 : tensor<complex<f64>>
    %4 = chlo.conj %3 : tensor<complex<f64>> -> tensor<complex<f64>>
    %5 = stablehlo.broadcast_in_dim %4, dims = [] : (tensor<complex<f64>>) -> tensor<1x1xcomplex<f64>>
    %c = stablehlo.constant dense<1> : tensor<i64>
    %c_2 = stablehlo.constant dense<1> : tensor<i64>
    %6 = stablehlo.subtract %c, %c_2 : tensor<i64>
    %c_3 = stablehlo.constant dense<1> : tensor<i64>
    %c_4 = stablehlo.constant dense<1> : tensor<i64>
    %7 = stablehlo.subtract %c_3, %c_4 : tensor<i64>
    %8 = stablehlo.dynamic_update_slice %cst, %5, %6, %7 : (tensor<2x2xcomplex<f64>>, tensor<1x1xcomplex<f64>>, tensor<i64>, tensor<i64>) -> tensor<2x2xcomplex<f64>>
    %9 = stablehlo.broadcast_in_dim %3, dims = [] : (tensor<complex<f64>>) -> tensor<1x1xcomplex<f64>>
    %c_5 = stablehlo.constant dense<2> : tensor<i64>
    %c_6 = stablehlo.constant dense<1> : tensor<i64>
    %10 = stablehlo.subtract %c_5, %c_6 : tensor<i64>
    %c_7 = stablehlo.constant dense<2> : tensor<i64>
    %c_8 = stablehlo.constant dense<1> : tensor<i64>
    %11 = stablehlo.subtract %c_7, %c_8 : tensor<i64>
    %12 = stablehlo.dynamic_update_slice %8, %9, %10, %11 : (tensor<2x2xcomplex<f64>>, tensor<1x1xcomplex<f64>>, tensor<i64>, tensor<i64>) -> tensor<2x2xcomplex<f64>>
    return %12 : tensor<2x2xcomplex<f64>>
  }
  func.func @main(%arg0: tensor<f64>) -> (tensor<2x2xcomplex<f64>>, tensor<f64>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f64>) -> tensor<f64>
    %1 = call @rz_Float64_ComplexF64(%0) : (tensor<f64>) -> tensor<2x2xcomplex<f64>>
    %2 = stablehlo.transpose %1, dims = [1, 0] : (tensor<2x2xcomplex<f64>>) -> tensor<2x2xcomplex<f64>>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<f64>) -> tensor<f64>
    return %2, %3 : tensor<2x2xcomplex<f64>>, tensor<f64>
  }
}

which doesn't have the transpose + reshape + transpose pattern (the transposes in there are just conversions between MLIR and Julia layouts).

It's probably me (Tenet) or Yao again; I'll have to check that.

The inline pass has parameters that we can tweak.

Yeah, but I was wondering if there's a way to mark a func.func or func.call as non-inlinable with an attribute.

Also IIRC this pass was run before autodiff since autodiff of func.call was not yet supported.

I don't understand: the forward/reverse rules for func.call are already implemented in EnzymeMLIR, right? Or do you mean that we have to bump the versions to be able to run them in Reactant?

@Pangoraw
Collaborator Author

Pangoraw commented Dec 9, 2024

Yeah, but I was wondering if there's a way to mark a func.func or func.call as non-inlinable with an attribute.

I think @wsmoses has been working on something like this (llvm/llvm-project#117392).
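
Something like this is the kind of thing it would enable (the attribute name below is purely hypothetical; the actual spelling is whatever lands in that PR):

// Hypothetical: tag the outlined function so the inliner skips it.
func.func @func_call() attributes {no_inline} {
  ...
  func.return
}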

I don't understand: the forward/reverse rules for func.call are already implemented in EnzymeMLIR, right? Or do you mean that we have to bump the versions to be able to run them in Reactant?

Yes, they are implemented. Not sure if this is up to date in Reactant_jll.

@mofeing
Collaborator

mofeing commented Dec 9, 2024

I think @wsmoses has been working on something like this (llvm/llvm-project#117392).

Cool!

Yes, they are implemented. Not sure if this is up to date in Reactant_jll.

Yes, it's in the latest Reactant_jll. The current Enzyme commit in Reactant_jll is f1f4d8e62856286efaa0df8c622711b17aa191c3, which contains the derivatives: https://github.com/EnzymeAD/Enzyme/blob/f1f4d8e62856286efaa0df8c622711b17aa191c3/enzyme/Enzyme/MLIR/Implementations/FuncAutoDiffOpInterfaceImpl.cpp

@wsmoses
Member

wsmoses commented Dec 10, 2024

Yeah, I need to finish up that LLVM PR but got distracted kernel'ing.

Pangoraw linked a pull request (#366) on Dec 12, 2024 that will close this issue.