import triton
import triton.language as tl
import torch

@triton.jit
def matmul_kernel(
    # Pointers to matrices
    A_ptr, B_ptr, C_ptr,
    # Matrix dimensions
    B, M, N, K,
    # The stride variables represent how much to increase the ptr by when moving by 1
    # element in a particular dimension. E.g. stride_am is how much to increase A_ptr
    # by to get the element one row down (A has M rows).
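    # For a contiguous (B, M, K) tensor, for example, the strides would be (M*K, K, 1).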
    stride_ab, stride_am, stride_ak,
    stride_bb, stride_bk, stride_bn,
    stride_cb, stride_cm, stride_cn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    ACTIVATION: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offs_b = tl.program_id(axis=1)
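    # Map the flat program id on axis 0 to a (pid_m, pid_n) tile index. Tiles are
    # visited in groups of GROUP_SIZE_M rows, so consecutive programs share the same
    # block of B and L2 cache reuse improves (the grouped ordering from the Triton
    # matmul tutorial).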
    num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + (pid % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m
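    # Wrap the row/column offsets with % M and % N so the pointers stay in bounds;
    # correctness at the edges is enforced by the store mask at the end.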
    offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
    offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    A_ptrs = A_ptr + (offs_b * stride_ab + offs_m[:, None] * stride_am + offs_k[None, :] * stride_ak)
    B_ptrs = B_ptr + (offs_b * stride_bb + offs_k[:, None] * stride_bk + offs_n[None, :] * stride_bn)
    # Initialize the accumulator and iteratively update it over the K dimension.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float64)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
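        # Mask out-of-range K elements on the last iteration (and out-of-range batch
        # programs), padding with 0.0 so they do not contribute to the dot product.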
        a = tl.load(A_ptrs, mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K) & (offs_b < B), other=0.0)
        b = tl.load(B_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_b < B), other=0.0)
        acc += tl.dot(a, b, allow_tf32=False)
        A_ptrs += BLOCK_SIZE_K * stride_ak
        B_ptrs += BLOCK_SIZE_K * stride_bk
    c = acc.to(tl.float32)
    C_ptr = C_ptr + (offs_b * stride_cb + offs_m[:, None] * stride_cm + offs_n[None, :] * stride_cn)
    c_mask = (offs_b < B) & (offs_m[:, None] < M) & (offs_n[None, :] < N)
    tl.store(C_ptr, c, mask=c_mask)

def matmul(a, b, activation=""):
    # Check constraints.
    assert a.shape[2] == b.shape[1], "Incompatible dimensions"
    assert a.is_contiguous(), "Matrix A must be contiguous"
    assert b.is_contiguous(), "Matrix B must be contiguous"
    B, M, K = a.shape
    B, K, N = b.shape
    # Allocates output.
    c = torch.empty((B, M, N), device=a.device, dtype=a.dtype)
    batch_size_n = triton.next_power_of_2(B)
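    # Launch a 2D grid: axis 0 enumerates the BLOCK_SIZE_M x BLOCK_SIZE_N output tiles,
    # axis 1 enumerates the batch (rounded up to a power of two; programs with
    # offs_b >= B are masked out inside the kernel).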
    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), batch_size_n)
    matmul_kernel[grid](
        a, b, c,
        B, M, N, K,
        a.stride(0), a.stride(1), a.stride(2),
        b.stride(0), b.stride(1), b.stride(2),
        c.stride(0), c.stride(1), c.stride(2),
        BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32,
        GROUP_SIZE_M=8,
        ACTIVATION=activation,
    )
    return c
torch.manual_seed(0)
# a = torch.randint(1, 10, (8, 512, 512), device='cuda', dtype=torch.float32)
# b = torch.randint(1, 10, (8, 512, 512), device='cuda', dtype=torch.float32)
a = torch.randn((8, 512, 512), device='cuda', dtype=torch.float32)
b = torch.randn((8, 512, 512), device='cuda', dtype=torch.float32)
tolerance = 1e-4
torch_output = torch.bmm(a, b)
triton_output = matmul(a, b, activation=None)
print(f"triton_output={triton_output}")
print(f"torch_output={torch_output}")
if torch.allclose(triton_output, torch_output, atol=tolerance, rtol=0):
    print("✅ Triton and Torch match")
else:
    print("❌ Triton and Torch differ")
    print("--" * 36)
    absolute_diff = torch.abs(torch_output - triton_output)
    non_equal_indices = torch.nonzero(absolute_diff > tolerance)
    for index in non_equal_indices:
        idx = tuple(index.tolist())
print(f"Index: {idx}, Value in torch: {torch_output[idx]}, Value in triton: {triton_output[idx]}") the output is: ❌ Triton and Torch differ
------------------------------------------------------------------------
Index: (0, 74, 465), Value in torch: 71.4826889038086, Value in triton: 71.48257446289062
Index: (0, 381, 249), Value in torch: 68.71340942382812, Value in triton: 68.71351623535156
Index: (1, 247, 145), Value in torch: -79.78231048583984, Value in triton: -79.78219604492188
Index: (6, 384, 160), Value in torch: -80.67752075195312, Value in triton: -80.67762756347656
Index: (6, 504, 126), Value in torch: 74.65267944335938, Value in triton: 74.65278625488281

Why are the matrix multiplications of Triton and Torch different in accuracy for float data?
Answered by gujiewen, Mar 13, 2024
You get the right answer when using acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32).
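In other words, accumulate in float32 instead of float64. A minimal sketch of the affected part of the kernel, assuming everything else stays exactly as posted in the question:

    # Accumulate in float32. torch.bmm (via cuBLAS) also accumulates float32 inputs
    # in float32 when TF32 is disabled, so the rounding behaviour of the two
    # implementations becomes comparable.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        a = tl.load(A_ptrs, mask=(offs_k[None, :] < K - k * BLOCK_SIZE_K) & (offs_b < B), other=0.0)
        b = tl.load(B_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_b < B), other=0.0)
        acc += tl.dot(a, b, allow_tf32=False)
        A_ptrs += BLOCK_SIZE_K * stride_ak
        B_ptrs += BLOCK_SIZE_K * stride_bk
    # The cast below is now a no-op, since acc is already float32.
    c = acc.to(tl.float32)

The per-element differences you saw are ordinary floating-point rounding: with a float64 accumulator the Triton kernel rounds differently from torch.bmm, so the two float32 results drift apart by roughly 1e-4 on values of magnitude ~70-80. Neither result is wrong, but accumulating in float32 keeps the Triton output close enough to torch.bmm to pass the atol=1e-4 check.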
Answer selected by Didi-cH