diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py new file mode 100644 index 0000000000..000b6d3326 --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x128_gemms.py @@ -0,0 +1,197 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.blockwise_fp8_training.kernels import ( + blockwise_fp8_gemm_1x128_128x128, + fp8_blockwise_act_quant_lhs, + fp8_blockwise_weight_quant_transposed_rhs, +) + +device = torch.device("cuda") + +# This benchmark requires CUDA 12.9+ +assert torch.version.cuda is not None, "CUDA is not available" +cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")) +assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required" + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + out_dtype: torch.dtype + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_mm_us: float + fp8_triton_us: float + fp8_scaled_mm_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + mnk_list = [ + # Llama4 shapes + (16640, 5120, 8192), + (16640, 8192, 5120), + ] + out_dtypes = [torch.float32, torch.bfloat16] + configs = [] + for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): + m, n, k = mnk + configs.append( + ExperimentConfig( + out_dtype=out_dtype, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + # Simulate `grad_input = grad_output @ weight` + M, N, K = config.m, config.n, config.k + A = torch.randn(M, K, dtype=config.out_dtype, device="cuda") + B = torch.randn(N, K, dtype=config.out_dtype, device="cuda") + A_q, A_s = fp8_blockwise_act_quant_lhs(A, dtype=torch.float8_e4m3fn) + B_t_q, B_t_s = fp8_blockwise_weight_quant_transposed_rhs( + B, dtype=torch.float8_e4m3fn + ) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + # Warmup then run bf16 torch.mm + warmup(torch.mm, A, B.t()) + + bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A, B.t()) + + # Warm up then run triton bench + warmup( + blockwise_fp8_gemm_1x128_128x128, + A_q, + 1.0 / A_s, + B_t_q, + 1.0 / B_t_s, + ) + + fp8_triton_us = benchmark_cuda_function_in_microseconds( + blockwise_fp8_gemm_1x128_128x128, + A_q, + 1.0 / A_s, + B_t_q, + 1.0 / B_t_s, + ) + + # Warm up then run torch bench + # scaled_mm requires A_s and B_t_s be in column-major format + A_s = A_s.t().contiguous().t() + + warmup( + torch._scaled_mm, + A_q, + B_t_q, + 1.0 / A_s, + 1.0 / B_t_s, + out_dtype=config.out_dtype, + ) + + fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds( + torch._scaled_mm, + A_q, + B_t_q, + 1.0 / A_s, + 1.0 / B_t_s, + out_dtype=config.out_dtype, + ) + + return ExperimentResult( + bf16_mm_us=bf16_mm_us, + fp8_triton_us=fp8_triton_us, + fp8_scaled_mm_us=fp8_scaled_mm_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M", + "N", + "K", + "out_dtype", + "bf16_mm_us", + "fp8_triton_us", + "fp8_scaled_mm_us", + "bf16 tflops/sec", + "triton tflops/sec", + "scaled_mm tflops/sec", + ] + rows = [] + for experiment in experiments: + m, n, k = experiment.config.m, experiment.config.n, experiment.config.k + flops = 2 * m * n * k + bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6) + triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6) + scaled_mm_tflops_per_sec = (flops / 1e12) / ( + experiment.result.fp8_scaled_mm_us / 1e6 + ) + rows.append( + [ + m, + n, + k, + experiment.config.out_dtype, + experiment.result.bf16_mm_us, + experiment.result.fp8_triton_us, + experiment.result.fp8_scaled_mm_us, + bf16_mm_tflops_per_sec, + triton_tflops_per_sec, + scaled_mm_tflops_per_sec, + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py new file mode 100644 index 0000000000..6873ee2eae --- /dev/null +++ b/benchmarks/prototype/blockwise_fp8_training/bench_1x128_128x1_gemms.py @@ -0,0 +1,195 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. +# this benchmarking script is a modified version of the original script from: https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/utils/benchmark.py + +import itertools +from dataclasses import dataclass +from typing import List + +import torch +from tabulate import tabulate +from tqdm import tqdm +from triton.testing import do_bench + +from torchao.prototype.blockwise_fp8_training.kernels import ( + blockwise_fp8_gemm_1x128_128x1, + fp8_blockwise_act_quant_rhs, + fp8_blockwise_act_quant_transposed_lhs, +) + +device = torch.device("cuda") + +# This benchmark requires CUDA 12.9+ +assert torch.version.cuda is not None, "CUDA is not available" +cuda_major, cuda_minor = map(int, torch.version.cuda.split(".")) +assert cuda_major >= 12 and cuda_minor >= 9, "CUDA 12.9+ is required" + +# Needed since changing args to function causes recompiles +torch._dynamo.config.cache_size_limit = 1000 + + +@dataclass(frozen=True) +class ExperimentConfig: + out_dtype: torch.dtype + m: int + n: int + k: int + + +@dataclass(frozen=True) +class ExperimentResult: + bf16_mm_us: float + fp8_triton_us: float + fp8_scaled_mm_us: float + + +@dataclass(frozen=True) +class Experiment: + config: ExperimentConfig + result: ExperimentResult + + +def get_configs() -> List[ExperimentConfig]: + mnk_list = [ + # Llama4 shapes + (16640, 5120, 8192), + (16640, 8192, 5120), + ] + out_dtypes = [torch.float32, torch.bfloat16] + configs = [] + for mnk, out_dtype in itertools.product(mnk_list, out_dtypes): + m, n, k = mnk + configs.append( + ExperimentConfig( + out_dtype=out_dtype, + m=m, + n=n, + k=k, + ) + ) + return configs + + +def run_experiment(config: ExperimentConfig) -> ExperimentResult: + # Simulate `grad_weight = grad_output_t @ input` + M, N, K = config.m, config.n, config.k + A = torch.randn(M, N, dtype=config.out_dtype, device="cuda") + B = torch.randn(M, K, dtype=config.out_dtype, device="cuda") + A_t_q, A_t_s = fp8_blockwise_act_quant_transposed_lhs(A, dtype=torch.float8_e4m3fn) + B_q, B_s = fp8_blockwise_act_quant_rhs(B, dtype=torch.float8_e4m3fn) + + def warmup(func, *args, **kwargs): + for _ in range(10): + func(*args, **kwargs) + + # Warmup then run bf16 torch.mm + warmup(torch.mm, A.t(), B) + + bf16_mm_us = benchmark_cuda_function_in_microseconds(torch.mm, A.t(), B) + + # Warm up then run triton bench + warmup( + blockwise_fp8_gemm_1x128_128x1, + A_t_q, + 1.0 / A_t_s, + B_q, + 1.0 / B_s, + ) + + fp8_triton_us = benchmark_cuda_function_in_microseconds( + blockwise_fp8_gemm_1x128_128x1, + A_t_q, + 1.0 / A_t_s, + B_q, + 1.0 / B_s, + ) + + # torch._scaled_mm requires A_s and B_t_s be in column-major format + A_t_s = A_t_s.t().contiguous().t() + + # Warm up then run torch bench + warmup( + torch._scaled_mm, + A_t_q, + B_q, + 1.0 / A_t_s, + 1.0 / B_s, + out_dtype=config.out_dtype, + ) + + fp8_scaled_mm_us = benchmark_cuda_function_in_microseconds( + torch._scaled_mm, + A_t_q, + B_q, + 1.0 / A_t_s, + 1.0 / B_s, + out_dtype=config.out_dtype, + ) + + return ExperimentResult( + bf16_mm_us=bf16_mm_us, + fp8_triton_us=fp8_triton_us, + fp8_scaled_mm_us=fp8_scaled_mm_us, + ) + + +def print_results(experiments: List[Experiment]): + headers = [ + "M", + "N", + "K", + "out_dtype", + "bf16_mm_us", + "fp8_triton_us", + "fp8_scaled_mm_us", + "bf16 tflops/sec", + "triton tflops/sec", + "scaled_mm tflops/sec", + ] + rows = [] + for experiment in experiments: + m, n, k = experiment.config.m, experiment.config.n, experiment.config.k + flops = 2 * m * n * k + bf16_mm_tflops_per_sec = (flops / 1e12) / (experiment.result.bf16_mm_us / 1e6) + triton_tflops_per_sec = (flops / 1e12) / (experiment.result.fp8_triton_us / 1e6) + scaled_mm_tflops_per_sec = (flops / 1e12) / ( + experiment.result.fp8_scaled_mm_us / 1e6 + ) + rows.append( + [ + m, + n, + k, + experiment.config.out_dtype, + experiment.result.bf16_mm_us, + experiment.result.fp8_triton_us, + experiment.result.fp8_scaled_mm_us, + bf16_mm_tflops_per_sec, + triton_tflops_per_sec, + scaled_mm_tflops_per_sec, + ] + ) + print(tabulate(rows, headers=headers)) + + +def benchmark_cuda_function_in_microseconds(f, *args, **kwargs): + return do_bench(lambda: f(*args, **kwargs), return_mode="median") * 1e3 + + +def main(): + torch.random.manual_seed(123) + configs = get_configs() + results = [] + for config in tqdm(configs): + result = run_experiment(config) + results.append(Experiment(config=config, result=result)) + + # Use Tabulate to print results + print_results(results) + + +if __name__ == "__main__": + main() diff --git a/torchao/prototype/blockwise_fp8_training/kernels.py b/torchao/prototype/blockwise_fp8_training/kernels.py index a0b29be541..515886ec1d 100644 --- a/torchao/prototype/blockwise_fp8_training/kernels.py +++ b/torchao/prototype/blockwise_fp8_training/kernels.py @@ -10,20 +10,20 @@ import triton import triton.language as tl +from torchao.prototype.moe_training.utils import ( + _is_column_major, + _is_row_major, +) + fp8_gemm_configs_max_autotune = [ - # Small - triton.Config({"BLOCK_SIZE_M": 32, "BLOCK_SIZE_N": 64}, num_warps=2), - # Medium - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 128}, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64}, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 256}, num_warps=8), - # Large - triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 64}, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 128}, num_warps=8), - triton.Config({"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256}, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=4), - triton.Config({"BLOCK_SIZE_M": 256, "BLOCK_SIZE_N": 128}, num_warps=8), + triton.Config( + {"BLOCK_SIZE_M": block_size, "BLOCK_SIZE_N": block_size}, + num_warps=num_warps, + num_stages=num_stages, + ) + for block_size in [64, 128, 256] + for num_warps in [4, 8] + for num_stages in [2, 4] ] # For fast compile times during development. @@ -57,6 +57,7 @@ def blockwise_fp8_gemm_1x128_128x128_kernel( M, N: tl.constexpr, K: tl.constexpr, + out_dtype: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, @@ -81,18 +82,16 @@ def blockwise_fp8_gemm_1x128_128x128_kernel( a_s_base_ptr = a_s_ptr + offs_m * a_s_stride_dim_0 b_s_base_ptr = b_s_ptr + (offs_n // BLOCK_SIZE_K) * b_s_stride_dim_1 accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) + b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N) for k in range(0, k_num_blocks): - a_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K) a = tl.load(a_ptrs, mask=a_mask, other=0.0) - - b_mask = (offs_k[:, None] < K) & (offs_n[None, :] < N) b = tl.load(b_ptrs, mask=b_mask, other=0.0) # Reciprocal scales to scale back to dynamic range of output dtype a_s = tl.load(a_s_base_ptr + k * a_s_stride_dim_1) b_s = tl.load(b_s_base_ptr + k * b_s_stride_dim_0) - - accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + accumulator += tl.dot(a, b) * a_s[:, None] * b_s a_ptrs += BLOCK_SIZE_K * a_stride_dim_1 b_ptrs += BLOCK_SIZE_K * b_stride_dim_0 @@ -109,14 +108,22 @@ def blockwise_fp8_gemm_1x128_128x128( b: torch.Tensor, # (K, N) b_s: torch.Tensor, # (K // block_size, N // block_size) block_size: int = 128, + out_dtype: torch.dtype = torch.float32, ): # 'a' must be in row-major layout, 'b' must be in column-major layout - assert a.is_contiguous() and not b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() + assert _is_row_major(a) and _is_column_major(b), ( + "a must be row-major, b must be column-major" + ) + + # a_scales must be row-major, b_scales must be column-major + assert _is_row_major(a_s) and _is_column_major(b_s), ( + "a_s must be row-major, b_s must be column-major" + ) + M = a.size(0) K = a.size(1) N = b.size(1) - c = a.new_empty(M, N, dtype=torch.bfloat16) + c = a.new_empty(M, N, dtype=out_dtype) grid = lambda META: ( triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), @@ -140,6 +147,7 @@ def blockwise_fp8_gemm_1x128_128x128( M, N, K, + out_dtype=out_dtype, BLOCK_SIZE_K=block_size, ) return c @@ -217,6 +225,7 @@ def blockwise_fp8_gemm_1x128_128x1( b: torch.Tensor, # (K, N) b_s: torch.Tensor, # (K // block_size, N) reciprocals of scales block_size: int = 128, + out_dtype: torch.dtype = torch.float32, ): # 'a' must be in row-major layout, 'b' must be in column-major layout assert a.is_contiguous() and not b.is_contiguous() @@ -224,7 +233,7 @@ def blockwise_fp8_gemm_1x128_128x1( M = a.size(0) K = a.size(1) N = b.size(1) - c = a.new_empty(M, N, dtype=torch.bfloat16) + c = a.new_empty(M, N, dtype=out_dtype) grid = lambda META: ( triton.cdiv(M, META["BLOCK_SIZE_M"]), triton.cdiv(N, META["BLOCK_SIZE_N"]), @@ -674,8 +683,10 @@ def fp8_blockwise_weight_quant_transposed_rhs( M, N = x.size() y = torch.empty(N, M, dtype=dtype, device=x.device) y = y.as_strided(y.size(), (1, y.size(0))) # Column major - s = x.new_empty( - triton.cdiv(N, block_size), triton.cdiv(M, block_size), dtype=torch.float32 + n_blocks, m_blocks = triton.cdiv(N, block_size), triton.cdiv(M, block_size) + s = x.new_empty(n_blocks, m_blocks, dtype=torch.float32).as_strided( + (n_blocks, m_blocks), # shape + (1, n_blocks), # stride ) grid = lambda meta: ( triton.cdiv(M, meta["BLOCK_SIZE"]), diff --git a/torchao/prototype/moe_training/utils.py b/torchao/prototype/moe_training/utils.py index dc13dfea33..ab648280ea 100644 --- a/torchao/prototype/moe_training/utils.py +++ b/torchao/prototype/moe_training/utils.py @@ -290,7 +290,21 @@ def _is_column_major(x: torch.Tensor) -> bool: A boolean indicating whether the input tensor is column-major. """ assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" - return x.stride(-2) == 1 and x.stride(-1) > 1 + return x.stride(-2) == 1 + + +def _is_row_major(x: torch.Tensor) -> bool: + """ + This function checks if the input tensor is row-major. + + Args: + x (torch.Tensor): The input tensor to be checked. + + Returns: + A boolean indicating whether the input tensor is row-major. + """ + assert x.ndim == 2 or x.ndim == 3, "input tensor must be 2D or 3D" + return x.stride(-1) == 1 def generate_jagged_offs(E, M, multiple_of=16, dtype=torch.int32, device="cuda"):