
Conversation

@danielvegamyhre (Contributor) commented Aug 16, 2025

Stacked PRs:

  • improve fp8 blockwise gemm perf

Summary

  • Add a benchmark script comparing the Triton kernel vs torch._scaled_mm for the blockwise 1x128_128x128 fp8 gemm
  • Minor autotuner config changes to the Triton gemm
  • scaled_mm is slower for very small shapes (e.g., dim=256) but ~6x faster for llama4 shapes (dim=5120)! The next step is to migrate the torchao blockwise autograd function to use scaled_mm (see the sketch after this list).
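
For context, here is a minimal sketch of what the scaled_mm path could look like. The quantization helpers and scale layouts below are illustrative assumptions (torchao has its own blockwise quantization utilities, and the exact scale layout torch._scaled_mm expects for blockwise fp8 may differ); this is not the benchmark code from this PR.

```python
import torch

# Assumed problem size from the benchmark tables below: output = input @ weight.t()
M, K, N = 16640, 8192, 5120
BLOCK = 128
FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_1x128(x: torch.Tensor):
    """Quantize a (rows, cols) tensor to fp8 with one fp32 scale per 1x128 tile along cols."""
    rows, cols = x.shape
    tiles = x.reshape(rows, cols // BLOCK, BLOCK).float()
    amax = tiles.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX                                    # (rows, cols // BLOCK, 1)
    x_fp8 = (tiles / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return x_fp8.reshape(rows, cols), scale.squeeze(-1)       # scales: (rows, cols // BLOCK)

def quantize_128x128(w: torch.Tensor):
    """Quantize an (N, K) weight to fp8 with one fp32 scale per 128x128 block."""
    n, k = w.shape
    blocks = w.reshape(n // BLOCK, BLOCK, k // BLOCK, BLOCK).float()
    amax = blocks.abs().amax(dim=(1, 3), keepdim=True).clamp(min=1e-12)
    scale = amax / FP8_MAX
    w_fp8 = (blocks / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return w_fp8.reshape(n, k), scale.reshape(n // BLOCK, k // BLOCK)

input_bf16 = torch.randn(M, K, device="cuda", dtype=torch.bfloat16)
weight_bf16 = torch.randn(N, K, device="cuda", dtype=torch.bfloat16)

input_fp8, input_scales = quantize_1x128(input_bf16)          # (M, K), (M, K // BLOCK)
weight_fp8, weight_scales = quantize_128x128(weight_bf16)     # (N, K), (N // BLOCK, K // BLOCK)

# Assumption: on a recent PyTorch build with Hopper-class hardware, torch._scaled_mm
# dispatches to a cuBLAS blockwise-scaled fp8 GEMM when given per-block scale tensors;
# the required scale strides/transposition may differ from what is shown here.
out = torch._scaled_mm(
    input_fp8,
    weight_fp8.t(),                 # mat2 must be (K, N); weight is stored as (N, K)
    scale_a=input_scales,           # 1x128 scales for the activation
    scale_b=weight_scales.t(),      # 128x128 scales for the weight, viewed as (K // BLOCK, N // BLOCK)
    out_dtype=torch.bfloat16,
)
```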

fp8 blockwise 1x128_128x128 gemm

  • This is the gemm used for:
    • output = input @ weight.t()
    • grad_input = grad_output @ weight
  • Benchmarking shapes representative of w1/w3 and w2 in llama4:
    M     N     K  out_dtype         bf16_mm_us    fp8_triton_us    fp8_scaled_mm_us    bf16 tflops/sec    triton tflops/sec    scaled_mm tflops/sec
-----  ----  ----  --------------  ------------  ---------------  ------------------  -----------------  -------------------  ----------------------
16640  5120  8192  torch.bfloat16       2294.62          3494.42             1383.54            608.319              399.456                 1008.91
16640  8192  5120  torch.bfloat16       2322.43          3800.48             1392.75            601.036              367.286                 1002.23
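
For reference, the tflops/sec columns follow directly from the measured latency: a GEMM of size (M, N, K) performs 2*M*N*K floating-point operations. A quick sanity check against the first row of the table above:

```python
def tflops_per_sec(M: int, N: int, K: int, time_us: float) -> float:
    """Convert a GEMM latency in microseconds to achieved TFLOPS (2*M*N*K flops per GEMM)."""
    return (2 * M * N * K) / (time_us * 1e-6) / 1e12

# M=16640, N=5120, K=8192 from the first row above
print(tflops_per_sec(16640, 5120, 8192, 1383.54))  # ~1008.9 -> matches the scaled_mm column
print(tflops_per_sec(16640, 5120, 8192, 3494.42))  # ~399.5  -> matches the triton column
```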

fp8 blockwise 1x128_128x1 gemm

  • This is the gemm used for:
    • grad_weight = grad_output_t @ input
  • Benchmarking shapes representative of w1/w3 and w2 in llama4:
    M     N     K  out_dtype         bf16_mm_us    fp8_triton_us    fp8_scaled_mm_us    bf16 tflops/sec    triton tflops/sec    scaled_mm tflops/sec
-----  ----  ----  --------------  ------------  ---------------  ------------------  -----------------  -------------------  ----------------------
16640  5120  8192  torch.bfloat16        2333.1          2572.58             1726.98            598.286              542.594                 808.271
16640  8192  5120  torch.bfloat16        2290.9          2754.64             1729.66            609.309              506.732                 807.015
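
Putting the two tables together: a linear layer's forward and backward decompose into exactly the three GEMMs benchmarked above. The bf16 reference below (a sketch, with no fp8 quantization) only makes the shape mapping explicit; torchao's blockwise autograd function additionally quantizes each operand with the granularities named in the section headings (1x128 x 128x128 for the forward and grad_input GEMMs, 1x128 x 128x1 for the grad_weight GEMM).

```python
import torch

class BlockwiseLinearReference(torch.autograd.Function):
    """bf16 reference for the three GEMMs that the blockwise fp8 kernels replace."""

    @staticmethod
    def forward(ctx, input, weight):            # input: (M, K), weight: (N, K)
        ctx.save_for_backward(input, weight)
        return input @ weight.t()               # output = input @ weight.t()         -> (M, N)

    @staticmethod
    def backward(ctx, grad_output):             # grad_output: (M, N)
        input, weight = ctx.saved_tensors
        grad_input = grad_output @ weight       # grad_input = grad_output @ weight   -> (M, K)
        grad_weight = grad_output.t() @ input   # grad_weight = grad_output_t @ input -> (N, K)
        return grad_input, grad_weight
```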

pytorch-bot (bot) commented Aug 16, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2784

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Cancelled Job

As of commit c288547 with merge base 1526dfe:

NEW FAILURE - The following job has failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

danielvegamyhre added a commit that referenced this pull request Aug 16, 2025
stack-info: PR: #2784, branch: danielvegamyhre/stack/43
@danielvegamyhre force-pushed the danielvegamyhre/stack/43 branch from a42133f to 5d45f01 on August 16, 2025 at 17:12
@meta-cla bot added the "CLA Signed" label Aug 16, 2025
@danielvegamyhre added the "topic: not user facing" label Aug 16, 2025
danielvegamyhre added a commit that referenced this pull request Aug 17, 2025
stack-info: PR: #2784, branch: danielvegamyhre/stack/43
@danielvegamyhre force-pushed the danielvegamyhre/stack/43 branch from 5d45f01 to da736d3 on August 17, 2025 at 00:08
stack-info: PR: #2784, branch: danielvegamyhre/stack/43
@danielvegamyhre force-pushed the danielvegamyhre/stack/43 branch from da736d3 to c288547 on August 17, 2025 at 16:16
@danielvegamyhre changed the title from "improve fp8 blockwise gemm perf" to "bench torch._scaled_mm for fp8 blockwise with llama4 shapes" Aug 18, 2025
@danielvegamyhre (Contributor, Author) commented Aug 18, 2025

@vkuzo I'm OOO Monday and Tuesday, let's cancel our meeting this week. (I'm not sure if meetings are autodeclined for PTO so I wanted to reach out to make sure)

@danielvegamyhre changed the title from "bench torch._scaled_mm for fp8 blockwise with llama4 shapes" to "[fp8 blockwise] bench torch._scaled_mm for fp8 blockwise with llama4 shapes" Aug 20, 2025
@danielvegamyhre changed the title from "[fp8 blockwise] bench torch._scaled_mm for fp8 blockwise with llama4 shapes" to "improve fp8 blockwise gemm perf" Aug 20, 2025
@danielvegamyhre (Contributor, Author) commented:

Confirmed the test failure is unrelated to this change and talked with the feature owner about fixing it.

@danielvegamyhre merged commit fbe08c3 into main Aug 20, 2025
16 of 18 checks passed
liangel-02 pushed a commit that referenced this pull request Aug 25, 2025