Summary
In PR #2894, I noticed that wrapping a Triton kernel in triton_op
and wrap_triton
to compose with torch.compile severely degrades kernel performance.
Example 1
Without wrap_triton:
(ao) [[email protected] ~/ao (8cfccaee|REBASE-i|danielvegamyhre/stack/64)]$ CUDA_VISIBLE_DEVICES=4 python benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py -s
input_shape torch_time_us triton_time_us torch_mem_bw_gbps triton_mem_bw_gbps triton_speedup
------------- --------------- ---------------- ------------------- -------------------- ----------------
(16640, 160) 844.832 63.456 6.448 85.85 13.31x
With wrap_triton:
(ao) [[email protected] ~/ao (8cfccaee|REBASE-i|danielvegamyhre/stack/64)]$ CUDA_VISIBLE_DEVICES=4 python benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py -s
input_shape torch_time_us triton_time_us torch_mem_bw_gbps triton_mem_bw_gbps triton_speedup
------------- --------------- ---------------- ------------------- -------------------- ----------------
(16640, 160) 848.608 136.16 6.395 39.859 6.23x
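A quick sanity check on the two tables above: the wrapped kernel's time roughly doubles, and the achieved memory bandwidth drops by the same factor, so the regression is consistent across both metrics (numbers copied from the benchmark output).

```python
# Triton kernel times (us) and achieved bandwidth (GB/s) from the two runs
baseline_us, wrapped_us = 63.456, 136.16
baseline_bw, wrapped_bw = 85.85, 39.859

slowdown = wrapped_us / baseline_us   # ~2.15x slower with wrap_triton
bw_drop = baseline_bw / wrapped_bw    # ~2.15x less bandwidth, as expected
print(f"slowdown: {slowdown:.2f}x, bandwidth drop: {bw_drop:.2f}x")
```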
Is this expected? It seems like a bug to me. cc @zou3519 @drisspg
Version
This is on a B200 devgpu with the following versions:
pytorch-triton 3.4.0+gitf7888497 pypi_0 pypi
torch 2.9.0.dev20250822+cu128 pypi_0 pypi
torchao 0.13.0+gitb663faf20 dev_0 <develop>