
feat: add Ops.batch #535

Draft · avik-pal wants to merge 13 commits into main from ap/batch
Conversation

@avik-pal (Collaborator) commented Jan 15, 2025

upstream issues: EnzymeAD/Enzyme-JAX#239

Example

using Reactant
using Reactant: Ops

x_ra = Reactant.to_rarray(rand(Float32, 4, 16))
y_ra = Reactant.to_rarray(rand(Float32, 16, 4))

begin
    dot6(x, y) = (Ops.multiply(x.x, y.x[1]),)
    bfn1(x, y) = Ops.batch(dot6, (; x=x), (; x=(y,)); batch_dims=((; x=2), (; x=(1,))))
    @code_hlo optimize = false bfn1(x_ra, y_ra)
end

This prints the unoptimized MLIR:
module {
  func.func private @dot6_batch(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) {
    %0 = stablehlo.multiply %arg0, %arg1 : tensor<4xf32>
    return %0, %arg0, %arg1 : tensor<4xf32>, tensor<4xf32>, tensor<4xf32>
  }
  func.func @main(%arg0: tensor<16x4xf32>, %arg1: tensor<4x16xf32>) -> (tensor<16x4xf32>, tensor<16x4xf32>, tensor<4x16xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<16x4xf32>) -> tensor<4x16xf32>
    %1 = stablehlo.transpose %arg1, dims = [1, 0] : (tensor<4x16xf32>) -> tensor<16x4xf32>
    %2 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x16xf32>) -> tensor<16x4xf32>
    %3 = stablehlo.transpose %1, dims = [0, 1] : (tensor<16x4xf32>) -> tensor<16x4xf32>
    %4:3 = enzyme.batch @dot6_batch(%2, %3) {batch_shape = array<i64: 16>} : (tensor<16x4xf32>, tensor<16x4xf32>) -> (tensor<16x4xf32>, tensor<16x4xf32>, tensor<16x4xf32>)
    %5 = stablehlo.transpose %4#0, dims = [1, 0] : (tensor<16x4xf32>) -> tensor<4x16xf32>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<4x16xf32>) -> tensor<16x4xf32>
    %7 = stablehlo.transpose %0, dims = [1, 0] : (tensor<4x16xf32>) -> tensor<16x4xf32>
    %8 = stablehlo.transpose %1, dims = [1, 0] : (tensor<16x4xf32>) -> tensor<4x16xf32>
    return %6, %7, %8 : tensor<16x4xf32>, tensor<16x4xf32>, tensor<4x16xf32>
  }
}
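Conceptually, Ops.batch maps the wrapped function over a chosen batch dimension of each argument, similar to jax.vmap, except that it emits a single enzyme.batch op rather than an explicit loop. A minimal NumPy sketch of that semantics (batch_apply is a hypothetical illustration, not Reactant API; dims here are 0-indexed, unlike Julia's 1-indexed batch_dims):

```python
import numpy as np

def batch_apply(fn, x, y, x_dim, y_dim):
    # Hypothetical sketch of vmap-style batching semantics.
    # Move each argument's batch dimension to the front...
    xb = np.moveaxis(x, x_dim, 0)
    yb = np.moveaxis(y, y_dim, 0)
    assert xb.shape[0] == yb.shape[0], "batch sizes must match"
    # ...then apply fn slice-by-slice and stack the results.
    return np.stack([fn(xs, ys) for xs, ys in zip(xb, yb)])

x = np.random.rand(4, 16).astype(np.float32)  # batched along dim 1
y = np.random.rand(16, 4).astype(np.float32)  # batched along dim 0
out = batch_apply(np.multiply, x, y, x_dim=1, y_dim=0)
# out[i] == x[:, i] * y[i, :], so out has shape (16, 4)
```

The explicit moveaxis step corresponds to the transposes visible in the HLO above, which appear to normalize the batch dimension into position before the enzyme.batch call.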

TODOs:

  • tests
  • docs
  • support linearization and delinearization

@avik-pal changed the title from "feat: implement sort" to "feat: add Ops.batch" on Jan 15, 2025
@avik-pal avik-pal linked an issue Jan 15, 2025 that may be closed by this pull request
Base automatically changed from ap/sorting to main January 16, 2025 01:13
@avik-pal force-pushed the ap/batch branch 3 times, most recently from d230dac to 93c8534 on January 16, 2025 03:03
src/Ops.jl — outdated review thread (resolved)
@avik-pal force-pushed the ap/batch branch 2 times, most recently from 8379431 to 037fd11 on January 16, 2025 12:58
Project.toml (outdated)
@@ -9,6 +9,7 @@ CEnum = "fa961155-64e5-5f13-b03f-caf6b980ea82"
 Downloads = "f43a241f-c20a-4ad4-852c-f6b1247861c6"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
+Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
@avik-pal (Collaborator, Author):
I tried not using Functors (9ef1fff, #535), but it gets quite messy and I couldn't figure out how to reconstruct the arguments with the corrected batch dims.
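For context on the Functors dependency: reconstructing a nested argument container while pairing each leaf with its batch dim is a structure-preserving map, which Functors.fmap provides in Julia. A rough Python sketch of the idea (names are illustrative; only dicts and tuples are handled):

```python
def tree_map(f, args, dims):
    # Walk two parallel nested containers and rebuild the structure,
    # applying f at each (value, batch_dim) leaf pair.
    if isinstance(args, dict):
        return {k: tree_map(f, args[k], dims[k]) for k in args}
    if isinstance(args, tuple):
        return tuple(tree_map(f, a, d) for a, d in zip(args, dims))
    return f(args, dims)  # leaf

# Mirrors the example's (; x=...) nested arguments and batch_dims:
moved = tree_map(lambda v, d: (v, d), {"x": (10, 20)}, {"x": (0, 1)})
# moved == {"x": ((10, 0), (20, 1))}
```

Doing this by hand for every container shape is what gets messy without a generic fmap-style traversal.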

@avik-pal (Collaborator, Author):
This is going to be a fun rebase 😓

@wsmoses (Member) commented Jan 17, 2025

oof yeah, the "let's speed up to_rarray" PR just got in (and cut the time from minutes to seconds for the big cases I tested on)

@avik-pal force-pushed the ap/batch branch 2 times, most recently from 96edbe0 to a5ee635 on January 17, 2025 21:58
@@ -0,0 +1,2 @@
+using Reactant, Test

[JuliaFormatter] reported by reviewdog 🐶

Suggested change

@@ -138,7 +138,7 @@ function make_mlir_fn(
args[i],
(:args, i),
concretein ? Reactant.ConcreteToTraced : Reactant.TracedSetPath;
toscalar,
batchmode
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
- batchmode
+ batchmode,

@avik-pal avik-pal force-pushed the ap/batch branch 2 times, most recently from 4c24ec2 to 38a29ed Compare January 19, 2025 03:16
Comment on lines 551 to +552
     TracedRArray{unwrapped_eltype(dest),ndims(dest)},
-    TracedUtils.elem_apply(bc.f, args...),
+    Ops.elem_apply(bc.f, args...),
[JuliaFormatter] reported by reviewdog 🐶

Suggested change
-    TracedRArray{unwrapped_eltype(dest),ndims(dest)},
-    Ops.elem_apply(bc.f, args...),
+    TracedRArray{unwrapped_eltype(dest),ndims(dest)}, Ops.elem_apply(bc.f, args...)

Development

Successfully merging this pull request may close these issues.

Implement Reactant.batch function for better batching (vmap too!)
2 participants