Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: automatic batching of code [currently very wip] #233

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

avik-pal
Copy link
Collaborator

@avik-pal avik-pal commented Nov 6, 2024

@wsmoses what is the exact specification for the batch_shape argument?

@avik-pal
Copy link
Collaborator Author

avik-pal commented Nov 6, 2024

more concretely if I have:

  func.func private @"-_broadcast_scalar"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
    %1 = stablehlo.negate %0 : tensor<f32>
    %2 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
    %3 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
    return %2, %3 : tensor<f32>, tensor<f32>
  }
  func.func private @"-_mapslice"(%arg0: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    %1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    %2:2 = enzyme.batch @"-_broadcast_scalar"(%1) {batch_shape = array<i64: 5>} : (tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
    %3 = stablehlo.transpose %2#0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    %4 = stablehlo.transpose %0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
    return %3, %4 : tensor<5xf32>, tensor<5xf32>
  }

I want to batch -_mapslice along dims = (1, 3) in the following

  func.func @main(%arg0: tensor<2x5x4xf32>) -> (tensor<2x5x4xf32>, tensor<2x5x4xf32>) {
    %0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x5x4xf32>) -> tensor<4x5x2xf32>
    %1:2 = enzyme.batch @"-_mapslice"(%0) {batch_shape = array<i64: 4, 5, 2>} : (tensor<4x5x2xf32>) -> (tensor<4x5x2xf32>, tensor<4x5x2xf32>)
    %2 = stablehlo.transpose %1#0, dims = [2, 1, 0] : (tensor<4x5x2xf32>) -> tensor<2x5x4xf32>
    %3 = stablehlo.transpose %1#1, dims = [2, 1, 0] : (tensor<4x5x2xf32>) -> tensor<2x5x4xf32>
    return %2, %3 : tensor<2x5x4xf32>, tensor<2x5x4xf32>
  }

Copy link
Contributor

@github-actions github-actions bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reactant.jl Benchmarks

Benchmark suite Current: 0eaf929 Previous: 7492957 Ratio
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1322438196 ns 1348963538 ns 0.98
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1316616944 ns 1391311320 ns 0.95
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1382717636 ns 1338620474 ns 1.03
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3205652086 ns 3244388370 ns 0.99
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux 312859638 ns 241874702 ns 1.29
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 5214679683 ns 5188886490 ns 1.00
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant 5391065775 ns 6246539107 ns 0.86
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 5213241200 ns 5084093600 ns 1.03
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 7702321219 ns 7518949936 ns 1.02
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux 32302105402 ns 36850329779 ns 0.88
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1303829729 ns 1317971201 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1294724033 ns 1310258445 ns 0.99
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1366131062.5 ns 1324769950 ns 1.03
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3187279151 ns 3116535236 ns 1.02
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux 8674339 ns 13540610 ns 0.64
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1571074487 ns 1560502992 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant 1548963490 ns 1551892442 ns 1.00
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1561060149 ns 1542824845 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3298074342 ns 3281095729 ns 1.01
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux 2869244150 ns 2720425528 ns 1.05
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1322333000 ns 1322064276 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1307890478 ns 1335917070 ns 0.98
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1291726367 ns 1300736588 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3154878431 ns 3166175682 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux 22766838 ns 22202743 ns 1.03
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 2177974901 ns 2145782423 ns 1.02
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant 2160680379 ns 2167690743 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 2148416243 ns 2149486683 ns 1.00
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 3890960910 ns 3922128663 ns 0.99
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux 5919682668.5 ns 6198372133.5 ns 0.96
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1323854214 ns 1348660676 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1364588511.5 ns 1334165596 ns 1.02
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1326865864.5 ns 1337531256 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3178970766 ns 3231963773 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux 7555941 ns 7533302.5 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1433740585 ns 1444700329 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant 1430726259 ns 1440226485 ns 0.99
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1425878660 ns 1425208361 ns 1.00
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3140899166 ns 3190100551 ns 0.98
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux 1710140106 ns 1405673485 ns 1.22
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1525601209 ns 1308447986 ns 1.17
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1332066970.5 ns 1312105536.5 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1330883945 ns 1370086342.5 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3183142321 ns 3275216428 ns 0.97
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux 14107204.5 ns 15341602 ns 0.92
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 1765120978 ns 1710785385 ns 1.03
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant 1716175449 ns 1690355099 ns 1.02
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 1714874668 ns 1692688324 ns 1.01
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3428465051 ns 3469285688 ns 0.99
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux 3287670889.5 ns 3370052934.5 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1319205772 ns 1334400878 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1325339613 ns 1306011648 ns 1.01
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1448741523 ns 1324306907 ns 1.09
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3262740904 ns 3240905394 ns 1.01
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux 25574211.5 ns 25731279 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 2163282741 ns 2183232912 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant 2123709459 ns 2164389519 ns 0.98
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 2158670774 ns 2169681573 ns 0.99
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 3933710244 ns 3937216620 ns 1.00
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux 7830953027.5 ns 6330955348 ns 1.24
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) 1373695348 ns 1299440434 ns 1.06
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant 1319882898 ns 1330626200 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) 1344069864 ns 1313109408 ns 1.02
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) 3165833610 ns 3193101313 ns 0.99
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux 54164248.5 ns 72107323.5 ns 0.75
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) 3081915232 ns 2971168120 ns 1.04
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant 3069794354 ns 2905935051 ns 1.06
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) 3073262827 ns 2979248113 ns 1.03
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) 4887622365 ns 4903544029 ns 1.00
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux 14133325079 ns 15868504232 ns 0.89
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) 1317901088 ns 1351000519 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant 1436404573 ns 1361060741 ns 1.06
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) 1320045192 ns 1282590367 ns 1.03
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) 3437504052 ns 3183466738 ns 1.08
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux 79939174 ns 68241711.5 ns 1.17
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) 3103689416 ns 3167721219 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant 3160035788 ns 3257591458 ns 0.97
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) 3180371400 ns 3236033003 ns 0.98
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) 5094434549 ns 4991442581 ns 1.02
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux 12511079079 ns 12380624409 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) 1336177364 ns 1313391133 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant 1323073468 ns 1334090768 ns 0.99
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) 1354149615 ns 1304546461 ns 1.04
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) 3217243816 ns 3179701072 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux 20311218 ns 23252444 ns 0.87
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) 1872910555 ns 1844700794 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant 1866569372 ns 1823452266 ns 1.02
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) 1841211551 ns 1816473309 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) 3596184702 ns 3573136282 ns 1.01
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux 3421263128 ns 3670592736 ns 0.93

This comment was automatically generated by workflow using github-action-benchmark.

@wsmoses
Copy link
Member

wsmoses commented Nov 6, 2024

Currently it’s roughly akin to prepending the shape to each arg in the function being called.

so it basically assumes the all dims at the end are batched.

fixable, but requires some API changes

@wsmoses
Copy link
Member

wsmoses commented Nov 6, 2024

There’s a temporary lazy option of doing transpose to get things to the end, batching, then transposing back

@avik-pal
Copy link
Collaborator Author

avik-pal commented Nov 6, 2024

There’s a temporary lazy option of doing transpose to get things to the end, batching, then transposing back

This is similar to what most of the stablehlo ops that support dims do, so I think we can just follow that for now.

@@ -333,6 +327,7 @@ function make_tracer(
mode;
toscalar=false,
tobatch=nothing,
batchdims=nothing,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why add a batchdims to linearization? shouldn't be needed and can make it harder to linearize

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

temporary means for prototyping. It is just a generalization of toscalar, so I want to fuse those options

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants