Using triton_op + wrap_triton introduces kernel performance regression #2898

@danielvegamyhre

Description

Summary

In PR #2894, I wrapped a Triton kernel in triton_op and wrap_triton so it composes with torch.compile, and noticed that doing so severely degrades kernel performance.
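For context, the wrapping pattern in question looks roughly like the sketch below. The `add_one` kernel and the `mylib::add_one` op name are placeholders for illustration, not the actual blocked-swizzle scale kernel from #2894:

```python
import torch
import triton
import triton.language as tl
from torch.library import triton_op, wrap_triton


@triton.jit
def _add_one_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + 1, mask=mask)


# triton_op registers the op with torch.library; wrap_triton makes the raw
# Triton kernel launch traceable by torch.compile.
@triton_op("mylib::add_one", mutates_args={})
def add_one(x: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, 1024),)
    wrap_triton(_add_one_kernel)[grid](x, out, n, BLOCK_SIZE=1024)
    return out
```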

Example 1

Without wrap_triton:

(ao) [[email protected] ~/ao (8cfccaee|REBASE-i|danielvegamyhre/stack/64)]$ CUDA_VISIBLE_DEVICES=4 python benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py  -s

| input_shape  | torch_time_us | triton_time_us | torch_mem_bw_gbps | triton_mem_bw_gbps | triton_speedup |
|--------------|---------------|----------------|-------------------|--------------------|----------------|
| (16640, 160) | 844.832       | 63.456         | 6.448             | 85.85              | 13.31x         |

With wrap_triton:

(ao) [[email protected] ~/ao (8cfccaee|REBASE-i|danielvegamyhre/stack/64)]$ CUDA_VISIBLE_DEVICES=4 python benchmarks/prototype/moe_training/benchmark_2d_blocked_swizzle_scale_kernels.py  -s

| input_shape  | torch_time_us | triton_time_us | torch_mem_bw_gbps | triton_mem_bw_gbps | triton_speedup |
|--------------|---------------|----------------|-------------------|--------------------|----------------|
| (16640, 160) | 848.608       | 136.16         | 6.395             | 39.859             | 6.23x          |

Is this expected? It seems like a bug to me. cc @zou3519 @drisspg
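For reference, one way this kind of A/B timing could be set up with `triton.testing.do_bench`, reusing the `_add_one_kernel` / `add_one` sketch above. This is an illustrative harness only, not the actual benchmark script or its methodology:

```python
import torch
import triton


def bench_us(fn) -> float:
    # triton.testing.do_bench returns runtime in milliseconds; convert to us.
    return triton.testing.do_bench(fn) * 1e3


def add_one_raw(x: torch.Tensor) -> torch.Tensor:
    # Same launch as add_one, but calling the Triton kernel directly,
    # without going through triton_op / wrap_triton.
    out = torch.empty_like(x)
    n = x.numel()
    _add_one_kernel[(triton.cdiv(n, 1024),)](x, out, n, BLOCK_SIZE=1024)
    return out


x = torch.randn(16640, 160, device="cuda")

compiled_add_one = torch.compile(add_one)
compiled_add_one(x)  # warm up and trigger compilation outside the timed region

raw_us = bench_us(lambda: add_one_raw(x))
wrapped_us = bench_us(lambda: compiled_add_one(x))
print(f"raw launch: {raw_us:.2f} us, wrapped + compiled: {wrapped_us:.2f} us")
```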

Version

This is on a B200 devgpu with the following versions:

pytorch-triton            3.4.0+gitf7888497          pypi_0    pypi
torch                     2.9.0.dev20250822+cu128          pypi_0    pypi
torchao                   0.13.0+gitb663faf20           dev_0    <develop>
