feat: automatic batching of code [currently very wip] #233
More concretely, if I have:
func.func private @"-_broadcast_scalar"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
%0 = stablehlo.transpose %arg0, dims = [] : (tensor<f32>) -> tensor<f32>
%1 = stablehlo.negate %0 : tensor<f32>
%2 = stablehlo.transpose %1, dims = [] : (tensor<f32>) -> tensor<f32>
%3 = stablehlo.transpose %0, dims = [] : (tensor<f32>) -> tensor<f32>
return %2, %3 : tensor<f32>, tensor<f32>
}
func.func private @"-_mapslice"(%arg0: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%0 = stablehlo.transpose %arg0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
%1 = stablehlo.broadcast_in_dim %0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
%2:2 = enzyme.batch @"-_broadcast_scalar"(%1) {batch_shape = array<i64: 5>} : (tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>)
%3 = stablehlo.transpose %2#0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
%4 = stablehlo.transpose %0, dims = [0] : (tensor<5xf32>) -> tensor<5xf32>
return %3, %4 : tensor<5xf32>, tensor<5xf32>
}
I want to batch:
func.func @main(%arg0: tensor<2x5x4xf32>) -> (tensor<2x5x4xf32>, tensor<2x5x4xf32>) {
%0 = stablehlo.transpose %arg0, dims = [2, 1, 0] : (tensor<2x5x4xf32>) -> tensor<4x5x2xf32>
%1:2 = enzyme.batch @"-_mapslice"(%0) {batch_shape = array<i64: 4, 5, 2>} : (tensor<4x5x2xf32>) -> (tensor<4x5x2xf32>, tensor<4x5x2xf32>)
%2 = stablehlo.transpose %1#0, dims = [2, 1, 0] : (tensor<4x5x2xf32>) -> tensor<2x5x4xf32>
%3 = stablehlo.transpose %1#1, dims = [2, 1, 0] : (tensor<4x5x2xf32>) -> tensor<2x5x4xf32>
return %2, %3 : tensor<2x5x4xf32>, tensor<2x5x4xf32>
}
Reactant.jl Benchmarks
Benchmark suite | Current: 0eaf929 | Previous: 7492957 | Ratio |
---|---|---|---|
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1322438196 ns | 1348963538 ns | 0.98 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant | 1316616944 ns | 1391311320 ns | 0.95 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1382717636 ns | 1338620474 ns | 1.03 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3205652086 ns | 3244388370 ns | 0.99 |
ViT base (256 x 256 x 3 x 32)/forward/CUDA/Lux | 312859638 ns | 241874702 ns | 1.29 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 5214679683 ns | 5188886490 ns | 1.00 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant | 5391065775 ns | 6246539107 ns | 0.86 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 5213241200 ns | 5084093600 ns | 1.03 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 7702321219 ns | 7518949936 ns | 1.02 |
ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux | 32302105402 ns | 36850329779 ns | 0.88 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1303829729 ns | 1317971201 ns | 0.99 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant | 1294724033 ns | 1310258445 ns | 0.99 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1366131062.5 ns | 1324769950 ns | 1.03 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3187279151 ns | 3116535236 ns | 1.02 |
ViT small (256 x 256 x 3 x 4)/forward/CUDA/Lux | 8674339 ns | 13540610 ns | 0.64 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1571074487 ns | 1560502992 ns | 1.01 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1548963490 ns | 1551892442 ns | 1.00 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1561060149 ns | 1542824845 ns | 1.01 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3298074342 ns | 3281095729 ns | 1.01 |
ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux | 2869244150 ns | 2720425528 ns | 1.05 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1322333000 ns | 1322064276 ns | 1.00 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant | 1307890478 ns | 1335917070 ns | 0.98 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1291726367 ns | 1300736588 ns | 0.99 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3154878431 ns | 3166175682 ns | 1.00 |
ViT tiny (256 x 256 x 3 x 32)/forward/CUDA/Lux | 22766838 ns | 22202743 ns | 1.03 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 2177974901 ns | 2145782423 ns | 1.02 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant | 2160680379 ns | 2167690743 ns | 1.00 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 2148416243 ns | 2149486683 ns | 1.00 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 3890960910 ns | 3922128663 ns | 0.99 |
ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux | 5919682668.5 ns | 6198372133.5 ns | 0.96 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1323854214 ns | 1348660676 ns | 0.98 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant | 1364588511.5 ns | 1334165596 ns | 1.02 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1326865864.5 ns | 1337531256 ns | 0.99 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3178970766 ns | 3231963773 ns | 0.98 |
ViT tiny (256 x 256 x 3 x 4)/forward/CUDA/Lux | 7555941 ns | 7533302.5 ns | 1.00 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1433740585 ns | 1444700329 ns | 0.99 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1430726259 ns | 1440226485 ns | 0.99 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1425878660 ns | 1425208361 ns | 1.00 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3140899166 ns | 3190100551 ns | 0.98 |
ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux | 1710140106 ns | 1405673485 ns | 1.22 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1525601209 ns | 1308447986 ns | 1.17 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant | 1332066970.5 ns | 1312105536.5 ns | 1.02 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1330883945 ns | 1370086342.5 ns | 0.97 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3183142321 ns | 3275216428 ns | 0.97 |
ViT tiny (256 x 256 x 3 x 16)/forward/CUDA/Lux | 14107204.5 ns | 15341602 ns | 0.92 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 1765120978 ns | 1710785385 ns | 1.03 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant | 1716175449 ns | 1690355099 ns | 1.02 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 1714874668 ns | 1692688324 ns | 1.01 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3428465051 ns | 3469285688 ns | 0.99 |
ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux | 3287670889.5 ns | 3370052934.5 ns | 0.98 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1319205772 ns | 1334400878 ns | 0.99 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant | 1325339613 ns | 1306011648 ns | 1.01 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1448741523 ns | 1324306907 ns | 1.09 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3262740904 ns | 3240905394 ns | 1.01 |
ViT small (256 x 256 x 3 x 16)/forward/CUDA/Lux | 25574211.5 ns | 25731279 ns | 0.99 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 2163282741 ns | 2183232912 ns | 0.99 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant | 2123709459 ns | 2164389519 ns | 0.98 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 2158670774 ns | 2169681573 ns | 0.99 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3933710244 ns | 3937216620 ns | 1.00 |
ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux | 7830953027.5 ns | 6330955348 ns | 1.24 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1373695348 ns | 1299440434 ns | 1.06 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant | 1319882898 ns | 1330626200 ns | 0.99 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1344069864 ns | 1313109408 ns | 1.02 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3165833610 ns | 3193101313 ns | 0.99 |
ViT small (256 x 256 x 3 x 32)/forward/CUDA/Lux | 54164248.5 ns | 72107323.5 ns | 0.75 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 3081915232 ns | 2971168120 ns | 1.04 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant | 3069794354 ns | 2905935051 ns | 1.06 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 3073262827 ns | 2979248113 ns | 1.03 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 4887622365 ns | 4903544029 ns | 1.00 |
ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux | 14133325079 ns | 15868504232 ns | 0.89 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1317901088 ns | 1351000519 ns | 0.98 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant | 1436404573 ns | 1361060741 ns | 1.06 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1320045192 ns | 1282590367 ns | 1.03 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3437504052 ns | 3183466738 ns | 1.08 |
ViT base (256 x 256 x 3 x 16)/forward/CUDA/Lux | 79939174 ns | 68241711.5 ns | 1.17 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 3103689416 ns | 3167721219 ns | 0.98 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant | 3160035788 ns | 3257591458 ns | 0.97 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 3180371400 ns | 3236033003 ns | 0.98 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 5094434549 ns | 4991442581 ns | 1.02 |
ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux | 12511079079 ns | 12380624409 ns | 1.01 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :after_enzyme) | 1336177364 ns | 1313391133 ns | 1.02 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant | 1323073468 ns | 1334090768 ns | 0.99 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :before_enzyme) | 1354149615 ns | 1304546461 ns | 1.04 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Reactant (optimize = :only_enzyme) | 3217243816 ns | 3179701072 ns | 1.01 |
ViT base (256 x 256 x 3 x 4)/forward/CUDA/Lux | 20311218 ns | 23252444 ns | 0.87 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1872910555 ns | 1844700794 ns | 1.02 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1866569372 ns | 1823452266 ns | 1.02 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1841211551 ns | 1816473309 ns | 1.01 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3596184702 ns | 3573136282 ns | 1.01 |
ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux | 3421263128 ns | 3670592736 ns | 0.93 |
This comment was automatically generated by workflow using github-action-benchmark.
Currently it’s roughly akin to prepending the shape to each arg in the function being called, so it basically assumes that all the dims at the end are batched. Fixable, but requires some API changes.
There’s a temporary lazy option of doing a transpose to get things to the end, batching, then transposing back.
This is similar to what most of the stablehlo ops that support dims do, so I think we can just follow that for now.
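To make the "transpose to the end, batch, transpose back" idea concrete, here is a minimal sketch on plain Julia arrays. It is not the implementation in this PR: `batch_via_permute` is a hypothetical helper, the loop over trailing dims only stands in for the batched call, and it assumes `f` returns an array with the same shape as the slice it receives.

```julia
# Hypothetical helper (not part of Reactant.jl): apply `f` to every slice of `x`
# selected by `batchdims`, using the permute / batch / permute-back strategy.
function batch_via_permute(f, x::AbstractArray, batchdims)
    bdims = collect(Int, batchdims)
    slicedims = setdiff(1:ndims(x), bdims)
    perm = vcat(slicedims, bdims)

    # 1. Move the batched dims to the end.
    xp = permutedims(x, perm)

    # 2. "Batch": loop over the trailing dims. Assumes `f` preserves the slice shape.
    yp = similar(xp)
    ns = length(slicedims)
    lead = ntuple(_ -> Colon(), ns)
    for I in CartesianIndices(size(xp)[(ns + 1):end])
        yp[lead..., I] = f(xp[lead..., I])
    end

    # 3. Undo the permutation so the result has the original dim order.
    return permutedims(yp, invperm(perm))
end

# Usage: batch a vector function over dims 1 and 3 of a 2x5x4 array.
x = reshape(collect(Float32, 1:40), 2, 5, 4)
y = batch_via_permute(cumsum, x, (1, 3))
```

With `batchdims = (1, 3)` the function sees each length-5 slice along dim 2, so the result should match `cumsum(x; dims=2)`. The two `permutedims` calls are the cost of the lazy approach; an API that takes the batched dims directly would avoid them.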
@@ -333,6 +327,7 @@ function make_tracer(
mode;
toscalar=false,
tobatch=nothing,
batchdims=nothing,
Why add a batchdims to linearization? It shouldn't be needed and can make it harder to linearize.
Temporary means for prototyping. It is just a generalization of toscalar, so I want to fuse those options.
@wsmoses what is the exact specification for the `batch_shape` argument?