Commit 7927d66

[fp8 blockwise] wrap triton quantization kernels in custom ops for torch.compile compatibility
stack-info: PR: #2829, branch: danielvegamyhre/stack/47
1 parent c77a12e commit 7927d66

5 files changed (+64, -54 lines)

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py

Lines changed: 4 additions & 4 deletions
@@ -15,8 +15,8 @@
 from triton.testing import do_bench

 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_lhs,
-    fp8_blockwise_weight_quant_transposed_rhs,
+    triton_fp8_blockwise_act_quant_lhs,
+    triton_fp8_blockwise_weight_quant_transposed_rhs,
     triton_fp8_gemm_1x128_128x128,
 )

@@ -78,8 +78,8 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     M, N, K = config.m, config.n, config.k
     A = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
     B = torch.randn(N, K, dtype=config.out_dtype, device="cuda")
-    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
-    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(
+    A_q, A_s = triton_fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn)
+    B_t_q, B_t_s = triton_fp8_blockwise_weight_quant_transposed_rhs(
         B, dtype=torch.float8_e4m3fn
     )

benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py

Lines changed: 6 additions & 4 deletions
@@ -15,8 +15,8 @@
 from triton.testing import do_bench

 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_rhs,
-    fp8_blockwise_act_quant_transposed_lhs,
+    triton_fp8_blockwise_act_quant_rhs,
+    triton_fp8_blockwise_act_quant_transposed_lhs,
     triton_fp8_gemm_1x128_128x1,
 )

@@ -78,8 +78,10 @@ def run_experiment(config: ExperimentConfig) -> ExperimentResult:
     M, N, K = config.m, config.n, config.k
     A = torch.randn(M, N, dtype=config.out_dtype, device="cuda")
     B = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
-    A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=torch.float8_e4m3fn)
-    B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn)
+    A_t_q, A_t_s = triton_fp8_blockwise_act_quant_transposed_lhs(
+        A, dtype=torch.float8_e4m3fn
+    )
+    B_q, B_s = triton_fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn)

     def warmup(func, *args, **kwargs):
         for _ in range(10):

test/prototype/blockwise_fp8_training/test_blockwise_kernels.py

Lines changed: 14 additions & 14 deletions
@@ -12,14 +12,14 @@
 from packaging import version
 from torchao.float8.float8_utils import compute_error
 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_lhs,
-    fp8_blockwise_act_quant_rhs,
-    fp8_blockwise_act_quant_transposed_lhs,
-    fp8_blockwise_weight_quant_rhs,
-    fp8_blockwise_weight_quant_transposed_rhs,
     torch_blockwise_scale_act_quant_lhs,
     torch_blockwise_scale_act_quant_rhs,
     torch_blockwise_scale_weight_quant,
+    triton_fp8_blockwise_act_quant_lhs,
+    triton_fp8_blockwise_act_quant_rhs,
+    triton_fp8_blockwise_act_quant_transposed_lhs,
+    triton_fp8_blockwise_weight_quant_rhs,
+    triton_fp8_blockwise_weight_quant_transposed_rhs,
     triton_fp8_gemm_1x128_128x1,
     triton_fp8_gemm_1x128_128x128,
 )
@@ -51,8 +51,8 @@ def test_triton_fp8_gemm_1x128_128x128(M, N, K, dtype):
     A = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
     B = torch.randn(N, K, dtype=torch.bfloat16, device="cuda")
     C = A @ B.T
-    A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=dtype)
-    B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
+    A_q, A_s = triton_fp8_blockwise_act_quant_lhs(A, dtype=dtype)
+    B_t_q, B_t_s = triton_fp8_blockwise_weight_quant_transposed_rhs(B, dtype=dtype)
     C_q = triton_fp8_gemm_1x128_128x128(
         A_q, B_t_q, A_s, B_t_s, out_dtype=torch.bfloat16
     )
@@ -76,8 +76,8 @@ def test_triton_fp8_gemm_1x128_128x1(M, N, K, dtype):
     A = torch.randn(K, M, dtype=torch.bfloat16, device="cuda")
     B = torch.randn(K, N, dtype=torch.bfloat16, device="cuda")
     C = A.T @ B
-    A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
-    B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=dtype)
+    A_t_q, A_t_s = triton_fp8_blockwise_act_quant_transposed_lhs(A, dtype=dtype)
+    B_q, B_s = triton_fp8_blockwise_act_quant_rhs(B, dtype=dtype)
     C_q = triton_fp8_gemm_1x128_128x1(A_t_q, B_q, A_t_s, B_s, out_dtype=torch.bfloat16)

     assert not C_q.isnan().any(), "C_q must not contain NaNs"
@@ -102,7 +102,7 @@ def test_triton_quantize_fp8_act_quant_lhs(block_size):
     x[0, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_act_quant_lhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_lhs(
         x,
         block_size=block_size,
     )
@@ -149,7 +149,7 @@ def test_triton_quantize_fp8_act_quant_rhs(block_size: int):
     x[:block_size, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_act_quant_rhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_rhs(
         x,
         block_size=block_size,
     )
@@ -196,7 +196,7 @@ def test_triton_quantize_fp8_act_quant_transposed_lhs(M, K, block_size: int):
     x[0, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_act_quant_transposed_lhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_act_quant_transposed_lhs(
         x,
         block_size=block_size,
     )
@@ -245,7 +245,7 @@ def test_triton_quantize_fp8_weight_quant_rhs(M, K, block_size: int):
     x[:block_size, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_weight_quant_rhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_weight_quant_rhs(
         x,
         block_size=block_size,
     )
@@ -292,7 +292,7 @@ def test_triton_quantize_fp8_weight_quant_transposed_rhs(block_size: int):
     x[:block_size, :block_size] = 0.0

     # Get the quantized tensor and reciprocal scales using triton implementation
-    triton_fp8, triton_scale = fp8_blockwise_weight_quant_transposed_rhs(
+    triton_fp8, triton_scale = triton_fp8_blockwise_weight_quant_transposed_rhs(
         x,
         block_size=block_size,
     )

torchao/prototype/blockwise_fp8_training/kernels.py

Lines changed: 25 additions & 19 deletions
@@ -9,6 +9,7 @@
 import torch
 import triton
 import triton.language as tl
+from torch.library import triton_op, wrap_triton

 from torchao.prototype.moe_training.utils import (
     _is_column_major,
@@ -119,7 +120,7 @@ def triton_fp8_gemm_1x128_128x128(
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    triton_fp8_gemm_1x128_128x128_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x128_kernel)[grid](
         a,
         a.stride(0),
         a.stride(1),
@@ -234,7 +235,7 @@ def triton_fp8_gemm_1x128_128x1(
         triton.cdiv(M, META["BLOCK_SIZE_M"]),
         triton.cdiv(N, META["BLOCK_SIZE_N"]),
     )
-    triton_fp8_gemm_1x128_128x1_kernel[grid](
+    wrap_triton(triton_fp8_gemm_1x128_128x1_kernel)[grid](
         a,
         a.stride(0),
         a.stride(1),
@@ -281,7 +282,7 @@

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_lhs_kernel(
+def triton_fp8_blockwise_act_quant_lhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -327,7 +328,8 @@ def fp8_blockwise_act_quant_lhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_lhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -352,7 +354,7 @@ def fp8_blockwise_act_quant_lhs(
         triton.cdiv(M, meta["NUM_GROUPS"]),
         triton.cdiv(K, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_act_quant_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_lhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -372,7 +374,7 @@

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_rhs_kernel(
+def triton_fp8_blockwise_act_quant_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -420,7 +422,8 @@ def fp8_blockwise_act_quant_rhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))


-def fp8_blockwise_act_quant_rhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_rhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
@@ -444,7 +447,7 @@ def fp8_blockwise_act_quant_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(K, meta["NUM_GROUPS"]),
     )
-    fp8_blockwise_act_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -464,7 +467,7 @@

 @triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
 @triton.jit
-def fp8_blockwise_act_quant_transposed_lhs_kernel(
+def triton_fp8_blockwise_act_quant_transposed_lhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -524,7 +527,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_act_quant_transposed_lhs(
+@triton_op("torchao::triton_fp8_blockwise_act_quant_transposed_lhs", mutates_args={})
+def triton_fp8_blockwise_act_quant_transposed_lhs(
     x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
@@ -550,7 +554,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
         triton.cdiv(K, meta["NUM_GROUPS"]),
     )

-    fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_act_quant_transposed_lhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -570,7 +574,7 @@

 @triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
 @triton.jit
-def fp8_blockwise_weight_quant_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -615,8 +619,9 @@ def fp8_blockwise_weight_quant_rhs_kernel(
     tl.store(s_ptr + scale_m_off + scale_n_off, tl.div_rn(1.0, scale))


-def fp8_blockwise_weight_quant_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.dim() == 2, "Input tensor must have 2 dimensions"
@@ -638,7 +643,7 @@ def fp8_blockwise_weight_quant_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(N, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_weight_quant_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
@@ -658,7 +663,7 @@

 @triton.autotune(configs=quant_kernel_configs, key=["M", "N"])
 @triton.jit
-def fp8_blockwise_weight_quant_transposed_rhs_kernel(
+def triton_fp8_blockwise_weight_quant_transposed_rhs_kernel(
     x_ptr,
     x_stride_dim_0,
     x_stride_dim_1,
@@ -719,8 +724,9 @@ def fp8_blockwise_weight_quant_transposed_rhs_kernel(
     tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)


-def fp8_blockwise_weight_quant_transposed_rhs(
-    x: torch.Tensor, block_size: int = 128, dtype=torch.float8_e4m3fn
+@triton_op("torchao::triton_fp8_blockwise_weight_quant_transposed_rhs", mutates_args={})
+def triton_fp8_blockwise_weight_quant_transposed_rhs(
+    x: torch.Tensor, block_size: int = 128, dtype: torch.dtype = torch.float8_e4m3fn
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.is_contiguous(), "Input tensor must be contiguous"
     assert x.dim() == 2, "Input tensor must have 2 dimensions"
@@ -742,7 +748,7 @@ def fp8_blockwise_weight_quant_transposed_rhs(
         triton.cdiv(M, meta["BLOCK_SIZE"]),
         triton.cdiv(N, meta["BLOCK_SIZE"]),
     )
-    fp8_blockwise_weight_quant_transposed_rhs_kernel[grid](
+    wrap_triton(triton_fp8_blockwise_weight_quant_transposed_rhs_kernel)[grid](
         x,
         x.stride(0),
         x.stride(1),
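Because each wrapper in kernels.py allocates and returns new tensors, mutates_args={} registers the ops as purely functional, which is what lets the compiler reason about them like any other custom op. A rough usage sketch of one of the registered ops under torch.compile follows; the shapes and the fullgraph=True expectation are assumptions for illustration, not something the commit itself asserts.

```python
import torch
from torchao.prototype.blockwise_fp8_training.kernels import (
    triton_fp8_blockwise_act_quant_lhs,
)


def quantize(x: torch.Tensor):
    # Returns the fp8 tensor and its reciprocal blockwise scales.
    return triton_fp8_blockwise_act_quant_lhs(
        x, block_size=128, dtype=torch.float8_e4m3fn
    )


# With the Triton launch wrapped in a custom op, the call should no longer
# force a graph break during compilation.
compiled_quantize = torch.compile(quantize, fullgraph=True)
x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
x_fp8, x_scale = compiled_quantize(x)
```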

torchao/prototype/blockwise_fp8_training/linear.py

Lines changed: 15 additions & 13 deletions
@@ -9,11 +9,11 @@

 from torchao.core.config import AOBaseConfig
 from torchao.prototype.blockwise_fp8_training.kernels import (
-    fp8_blockwise_act_quant_lhs,
-    fp8_blockwise_act_quant_rhs,
-    fp8_blockwise_act_quant_transposed_lhs,
-    fp8_blockwise_weight_quant_rhs,
-    fp8_blockwise_weight_quant_transposed_rhs,
+    triton_fp8_blockwise_act_quant_lhs,
+    triton_fp8_blockwise_act_quant_rhs,
+    triton_fp8_blockwise_act_quant_transposed_lhs,
+    triton_fp8_blockwise_weight_quant_rhs,
+    triton_fp8_blockwise_weight_quant_transposed_rhs,
     triton_fp8_gemm_1x128_128x1,
     triton_fp8_gemm_1x128_128x128,
 )
@@ -33,10 +33,10 @@ def forward(ctx, x, weight, block_size, out_dtype=torch.bfloat16, use_triton=Fal
         x = x.reshape(-1, x_orig_shape[-1])

         # Cast inputs to fp8 blockwise using (1, block_size) scaling granularity in row major format.
-        x_fp8, x_scale = fp8_blockwise_act_quant_lhs(x, block_size)
+        x_fp8, x_scale = triton_fp8_blockwise_act_quant_lhs(x, block_size)

         # Cast weight to fp8 blockwise using (block_size, block_size) scaling granularity, with transposed dims in column major format.
-        weight_t_fp8, weight_t_scale = fp8_blockwise_weight_quant_transposed_rhs(
+        weight_t_fp8, weight_t_scale = triton_fp8_blockwise_weight_quant_transposed_rhs(
             weight,
             block_size=block_size,
         )
@@ -74,13 +74,13 @@ def backward(ctx, grad_output):
         assert grad_output.shape[1] % 128 == 0, "unsupported"

         # Cast grad_output to fp8 blockwise 1x128 since it is the grad of the output activation.
-        grad_output_fp8, grad_output_scale = fp8_blockwise_act_quant_lhs(
+        grad_output_fp8, grad_output_scale = triton_fp8_blockwise_act_quant_lhs(
             grad_output,
             block_size,
         )

         # Cast weight to fp8 blockwise to 128x128 in column major format.
-        weight_fp8, weight_scale = fp8_blockwise_weight_quant_rhs(
+        weight_fp8, weight_scale = triton_fp8_blockwise_weight_quant_rhs(
             weight,
             block_size=block_size,
         )
@@ -100,15 +100,17 @@ def backward(ctx, grad_output):
         # Cast grad_output_t to fp8 blockwise with (1 x block_size) scaling groups, since it is
         # the grad of the output activation.
         # Write directly with transposed dims in row major format, as needed for dW calc.
-        grad_output_t_fp8, grad_output_t_scale = fp8_blockwise_act_quant_transposed_lhs(
-            grad_output,
-            block_size,
+        grad_output_t_fp8, grad_output_t_scale = (
+            triton_fp8_blockwise_act_quant_transposed_lhs(
+                grad_output,
+                block_size,
+            )
         )

         # Cast x to fp8 blockwise with (block_size x 1) scaling groups, in column major format.
         # RHS should have groupwise scales calculated colwise, so scaling groups do not cross the
         # contracting (K) dim.
-        x_fp8, x_scale = fp8_blockwise_act_quant_rhs(x, block_size)
+        x_fp8, x_scale = triton_fp8_blockwise_act_quant_rhs(x, block_size)

         # grad_weight = grad_output.T @ x
         fp8_gemm_1x128_128x1 = (
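For orientation, the forward-pass dataflow described by the linear.py comments can be sketched directly with the renamed ops: the activation is quantized with (1, block_size) groups as the GEMM LHS, the weight with (block_size, block_size) groups and transposed dims as the RHS, and both feed the 1x128-by-128x128 GEMM. The shapes below are illustrative only.

```python
import torch
from torchao.prototype.blockwise_fp8_training.kernels import (
    triton_fp8_blockwise_act_quant_lhs,
    triton_fp8_blockwise_weight_quant_transposed_rhs,
    triton_fp8_gemm_1x128_128x128,
)

block_size = 128
x = torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda")  # (M, K) activation
weight = torch.randn(2048, 1024, dtype=torch.bfloat16, device="cuda")  # (N, K) weight

# (1, block_size) scaling groups, row major, for the LHS operand.
x_fp8, x_scale = triton_fp8_blockwise_act_quant_lhs(x, block_size)

# (block_size, block_size) scaling groups, transposed dims, for the RHS operand.
w_t_fp8, w_t_scale = triton_fp8_blockwise_weight_quant_transposed_rhs(
    weight, block_size=block_size
)

# out ~= x @ weight.T in bf16, computed from the fp8 operands and their scales.
out = triton_fp8_gemm_1x128_128x128(
    x_fp8, w_t_fp8, x_scale, w_t_scale, out_dtype=torch.bfloat16
)
```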
