20 changes: 10 additions & 10 deletions benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py
@@ -12,7 +12,6 @@

import torch
from tabulate import tabulate
from torch.nn import functional as F
from tqdm import tqdm
from triton.testing import do_bench

@@ -72,7 +71,9 @@ def get_configs() -> List[ExperimentConfig]:
return configs


def run_experiment(config: ExperimentConfig, profile=False, use_compile=False) -> ExperimentResult:
def run_experiment(
config: ExperimentConfig, profile=False, use_compile=False
) -> ExperimentResult:
M, N, K = config.m, config.n, config.k
inputs = torch.randn(M, K, dtype=config.out_dtype, device="cuda")
bf16_linear = torch.nn.Linear(K, N, dtype=config.out_dtype, device="cuda")
@@ -87,24 +88,23 @@ def warmup(func, *args, **kwargs):
for _ in range(3):
func(*args, **kwargs)


# bfloat16 bench and profile
labels = inputs.new_empty(M, N).fill_(1.0)
bf16_linear_us = bench_fwd_bwd_microseconds(
bf16_linear,
inputs,
labels=labels,
bf16_linear,
inputs,
labels=labels,
use_compile=use_compile,
)
if profile:
print("Profiling bf16_linear")
profile_fwd_bwd(
bf16_linear,
inputs,
bf16_linear,
inputs,
labels=labels,
profile_name="bf16_linear_profile",
use_compile=use_compile,
)
)

# FP8 triton bench and profile
fp8_triton_linear_us = bench_fwd_bwd_microseconds(
@@ -189,7 +189,7 @@ def main(args: argparse.Namespace):


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser = argparse.ArgumentParser()
parser.add_argument("--profile", action="store_true", help="Enable profiling")
parser.add_argument("--compile", action="store_true", help="Enable compilation")
args = parser.parse_args()
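Note: the benchmark changes above are formatting-only (multi-line `run_experiment` signature, argument indentation) apart from dropping the `torch.nn.functional` import. Assuming a CUDA machine with triton, tabulate, and tqdm installed, the script can be run as `python benchmarks/prototype/blockwise_fp8_training/bench_linear_fwd_bwd.py --profile --compile`, where both flags are optional and map to the `profile` and `use_compile` arguments of `run_experiment`.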
75 changes: 48 additions & 27 deletions torchao/prototype/blockwise_fp8_training/kernels.py
@@ -264,11 +264,22 @@ def triton_fp8_gemm_1x128_128x1(
num_stages=stages,
)
for warps in [4, 8]
for stages in [2, 4]
]

quant_kernel_configs_with_groups = [
triton.Config(
{"NUM_GROUPS": groups},
num_warps=warps,
num_stages=stages,
)
for groups in [2, 16, 32, 64, 128]
for warps in [2, 4, 8]
for stages in [2, 4, 6]
]


@triton.autotune(configs=quant_kernel_configs, key=["K"])
@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
@triton.jit
def fp8_blockwise_act_quant_lhs_kernel(
x_ptr,
@@ -283,13 +294,14 @@ def fp8_blockwise_act_quant_lhs_kernel(
M,
K: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_GROUPS: tl.constexpr,
EPS: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_k = tl.program_id(axis=1)

# Load (1 x block_size) tile of x, where input is row major
m_offs = pid_m
# Load (num_groups x block_size) tile of x, where input is row major
m_offs = pid_m * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
k_offs = pid_k * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
@@ -298,8 +310,10 @@ def fp8_blockwise_act_quant_lhs_kernel(
# Perform scaling
max_fp8_e4m3 = 448.0
min_fp8_e4m3 = -448.0
amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
scale = (max_fp8_e4m3 / amax).to(tl.float32)

# Scales for (1 x block_size) groups, shape will be (NUM_GROUPS, 1)
amax = tl.clamp(tl.max(tl.abs(x), axis=1), min=EPS, max=float("inf")).to(tl.float64)
scale = (max_fp8_e4m3 / amax).to(tl.float32)[:, None]
y = x * scale
y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)

@@ -309,7 +323,7 @@ def fp8_blockwise_act_quant_lhs_kernel(
tl.store(y_ptr + y_offs, y, mask=y_mask)

# Write reciprocal scales
scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
scale_offs = m_offs[:, None] * s_stride_dim_0 + pid_k * s_stride_dim_1
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
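To make the grouped LHS quantization easier to follow, here is a minimal PyTorch sketch of the (1 x block_size) scaling scheme the kernel above implements. It ignores the Triton row grouping and the column-major scale layout, and the `eps` default is an assumption since the kernel's `EPS` value is not shown in this diff (requires a recent PyTorch with `torch.float8_e4m3fn`):

```python
import torch

def ref_fp8_blockwise_quant_lhs(x: torch.Tensor, block_size: int = 128, eps: float = 1e-12):
    # x: (M, K) activations in high precision; assumes K % block_size == 0.
    M, K = x.shape
    max_fp8_e4m3 = 448.0
    x_blocks = x.reshape(M, K // block_size, block_size)
    # One amax per (1 x block_size) group along K, floored like the kernel's EPS clamp.
    amax = x_blocks.abs().amax(dim=-1, keepdim=True).clamp(min=eps).to(torch.float64)
    scale = (max_fp8_e4m3 / amax).to(torch.float32)
    y = (x_blocks * scale).clamp(-max_fp8_e4m3, max_fp8_e4m3).to(torch.float8_e4m3fn)
    # The kernel writes reciprocal scales; shape (M, K // block_size).
    return y.reshape(M, K), (1.0 / scale).squeeze(-1)
```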


@@ -334,7 +348,10 @@ def fp8_blockwise_act_quant_lhs(
(M, K // block_size),
(1, M),
)
grid = lambda meta: (M, triton.cdiv(K, meta["BLOCK_SIZE"]))
grid = lambda meta: (
triton.cdiv(M, meta["NUM_GROUPS"]),
triton.cdiv(K, meta["BLOCK_SIZE"]),
)
fp8_blockwise_act_quant_lhs_kernel[grid](
x,
x.stride(0),
@@ -353,7 +370,7 @@ def fp8_blockwise_act_quant_lhs(
return y, s
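For a sense of what the new launch grid does: previously dimension 0 of the grid had one program per row (size M), while now each program covers a group of NUM_GROUPS rows, with NUM_GROUPS autotuned from the new config list. An illustrative calculation with made-up shapes:

```python
import math

M, K = 4096, 8192                 # example activation shape (illustrative only)
BLOCK_SIZE, NUM_GROUPS = 128, 32  # NUM_GROUPS is autotuned over {2, 16, 32, 64, 128}
old_grid = (M, math.ceil(K / BLOCK_SIZE))                           # (4096, 64)
new_grid = (math.ceil(M / NUM_GROUPS), math.ceil(K / BLOCK_SIZE))   # (128, 64)
```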


@triton.autotune(configs=quant_kernel_configs, key=["K"])
@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
@triton.jit
def fp8_blockwise_act_quant_rhs_kernel(
x_ptr,
@@ -368,33 +385,38 @@ def fp8_blockwise_act_quant_rhs_kernel(
M,
K: tl.constexpr,
BLOCK_SIZE: tl.constexpr,
NUM_GROUPS: tl.constexpr,
EPS: tl.constexpr,
):
pid_m = tl.program_id(axis=0)
pid_k = tl.program_id(axis=1)

# Load (block_size x 1) tile of x, where input is row major
# Load (block_size x block_size) tile of x, where input is row major.
# Each scaling group is (block_size x 1), but we load (block_size x block_size)
# to facilitate coalesced gmem accesses and improve efficiency.
m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
k_offs = pid_k
k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
x = tl.load(x_ptr + x_offs, mask=x_mask)

# Perform scaling
max_fp8_e4m3 = 448.0
min_fp8_e4m3 = -448.0
amax = tl.clamp(tl.max(tl.abs(x)), min=EPS, max=float("inf")).to(tl.float64)
scale = (max_fp8_e4m3 / amax).to(tl.float32)

# Column-wise scales for RHS operand, shape (1, block_size)
amax = tl.clamp(tl.max(tl.abs(x), axis=0), min=EPS, max=float("inf")).to(tl.float64)
scale = (max_fp8_e4m3 / amax).to(tl.float32)[None, :]
y = x * scale
y = tl.clamp(y, min=min_fp8_e4m3, max=max_fp8_e4m3).to(y_ptr.dtype.element_ty)

# Write output to column major fomrat
# Write output to column major format
y_offs = m_offs[:, None] * y_stride_dim_0 + k_offs[None, :] * y_stride_dim_1
y_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
tl.store(y_ptr + y_offs, y, mask=y_mask)

# Write scales
scale_offs = pid_m * s_stride_dim_0 + pid_k * s_stride_dim_1
scale_offs = pid_m * s_stride_dim_0 + k_offs[None, :] * s_stride_dim_1
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale))
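The RHS kernel mirrors the LHS one but reduces over a (block_size x 1) slice of a column rather than a (1 x block_size) slice of a row, and it writes the fp8 output in column-major layout. A rough sketch of just the scaling math, again with an assumed `eps` and the layout details omitted:

```python
import torch

# Illustrative only: (block_size x 1) column groups for the RHS operand.
M, K, block_size, eps = 256, 512, 128, 1e-12           # example sizes; eps value is an assumption
x = torch.randn(M, K, dtype=torch.bfloat16)
x_blocks = x.reshape(M // block_size, block_size, K)   # chunk rows into blocks of block_size
amax = x_blocks.abs().amax(dim=1, keepdim=True).clamp(min=eps).to(torch.float64)
scale = (448.0 / amax).to(torch.float32)               # 448 = max fp8 e4m3 magnitude
y = (x_blocks * scale).clamp(-448.0, 448.0).to(torch.float8_e4m3fn).reshape(M, K)
recip_scale = (1.0 / scale).squeeze(1)                 # shape (M // block_size, K)
```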


@@ -420,7 +442,7 @@ def fp8_blockwise_act_quant_rhs(

grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_SIZE"]),
K,
triton.cdiv(K, meta["NUM_GROUPS"]),
)
fp8_blockwise_act_quant_rhs_kernel[grid](
x,
@@ -440,7 +462,7 @@ def fp8_blockwise_act_quant_rhs(
return y, s


@triton.autotune(configs=quant_kernel_configs, key=["K"])
@triton.autotune(configs=quant_kernel_configs_with_groups, key=["K"])
@triton.jit
def fp8_blockwise_act_quant_transposed_lhs_kernel(
x_ptr,
@@ -454,8 +476,8 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
s_stride_dim_1,
M,
K: tl.constexpr,
SCALE_BLOCK_SIZE: tl.constexpr, # For scaling groups, not for grid/parallelization
BLOCK_SIZE_K: tl.constexpr, # For grid/parallelization, not for scaling groups
BLOCK_SIZE: tl.constexpr, # For scaling groups, not for grid/parallelization
NUM_GROUPS: tl.constexpr, # For grid/parallelization, not for scaling groups
EPS: tl.constexpr,
):
# This kernel reads data in row-major format, and writes to an output tensor with
@@ -465,12 +487,12 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(
pid_m = tl.program_id(axis=0)
pid_k = tl.program_id(axis=1)

# Load (block_size x block_size_k) block of input, where input is row major.
# Load (block_size x num_groups) block of input, where input is row major.
# We will be computing (block_size x 1) scaling factors (columns), and computing
# `block_size_k` at a time, so we aren't parallelizing with 1 thread per column,
# `num_groups` at a time, so we aren't parallelizing with 1 thread per column,
# which will fail to launch for large tensors, due to max block number of 65535.
m_offs = pid_m * SCALE_BLOCK_SIZE + tl.arange(0, SCALE_BLOCK_SIZE)
k_offs = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
m_offs = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
k_offs = pid_k * NUM_GROUPS + tl.arange(0, NUM_GROUPS)
x_offs = m_offs[:, None] * x_stride_dim_0 + k_offs[None, :] * x_stride_dim_1
x_mask = (m_offs[:, None] < M) & (k_offs[None, :] < K)
x = tl.load(x_ptr + x_offs, mask=x_mask)
@@ -496,7 +518,7 @@ def fp8_blockwise_act_quant_transposed_lhs_kernel(

# Scale tensor size is (K, M // SCALE_BLOCK_SIZE)
scale_offs = scale_k_offs * s_stride_dim_0 + scale_m_off * s_stride_dim_1
scale_mask = (scale_k_offs < K) & (scale_m_off < M // SCALE_BLOCK_SIZE)
scale_mask = (scale_k_offs < K) & (scale_m_off < M // BLOCK_SIZE)

# Write out reciprocal scales
tl.store(s_ptr + scale_offs, tl.div_rn(1.0, scale), mask=scale_mask)
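The comment above motivates processing NUM_GROUPS columns per program: with one program per column, K itself becomes a grid dimension and large tensors would exceed the 65535 limit the comment mentions. (The diff also renames SCALE_BLOCK_SIZE and BLOCK_SIZE_K to BLOCK_SIZE and NUM_GROUPS to match the other kernels.) A quick check of that arithmetic, with an illustrative K:

```python
import math

K = 131072                                    # example K larger than the 65535 grid limit noted above
NUM_GROUPS = 16                               # one of the autotuned group counts
programs_along_k = math.ceil(K / NUM_GROUPS)  # 8192 programs, comfortably under the limit
assert programs_along_k <= 65535
```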
@@ -524,8 +546,8 @@ def fp8_blockwise_act_quant_transposed_lhs(
(1, K), # stride
)
grid = lambda meta: (
triton.cdiv(M, meta["SCALE_BLOCK_SIZE"]),
triton.cdiv(K, meta["BLOCK_SIZE_K"]),
triton.cdiv(M, meta["BLOCK_SIZE"]),
triton.cdiv(K, meta["NUM_GROUPS"]),
)

fp8_blockwise_act_quant_transposed_lhs_kernel[grid](
@@ -540,8 +562,7 @@ def fp8_blockwise_act_quant_transposed_lhs(
s.stride(1),
M,
K=K,
SCALE_BLOCK_SIZE=block_size, # Scaling group size
BLOCK_SIZE_K=block_size, # Just for parallelize the work along K as well
BLOCK_SIZE=block_size, # Scaling group size
EPS=EPS,
)
return y, s
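End to end, these wrappers take a high-precision activation tensor and return the fp8 tensor plus reciprocal scales. A hypothetical call site is sketched below; the exact signature, defaults, and output dtype are not visible in this diff, so treat the `block_size=128` keyword and the printed shapes as assumptions to verify against the full source:

```python
import torch
from torchao.prototype.blockwise_fp8_training.kernels import fp8_blockwise_act_quant_lhs

x = torch.randn(4096, 8192, dtype=torch.bfloat16, device="cuda")
# Assumed signature: (x, block_size=128) -> (y, s); check the module before relying on it.
y, s = fp8_blockwise_act_quant_lhs(x, block_size=128)
print(y.shape)  # expected (4096, 8192), fp8 values
print(s.shape)  # expected (4096, 8192 // 128) == (4096, 64) reciprocal scales
```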