Skip to content

Commit

Permalink
vulkan: Add N/2 and N/4 optimized paths in coopmat2 shader
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbolznv committed Mar 10, 2025
1 parent 458c70a commit a768a65
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 31 deletions.
24 changes: 12 additions & 12 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1469,33 +1469,33 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t l_align, m_align, s_align;
if (device->coopmat2) {
// spec constants and tile sizes for non-quant matmul/matmul_id
l_warptile = { 256, 128, 256, 64 };
m_warptile = { 256, 128, 128, 64 };
s_warptile = { 128, 64, 64, 64 };
l_warptile = { 256, 128, 256, 64, 1 };
m_warptile = { 256, 128, 128, 64, 0 };
s_warptile = { 128, 64, 64, 64, 0 };
l_wg_denoms = {128, 256, 1 };
m_wg_denoms = {128, 128, 1 };
s_wg_denoms = { 64, 64, 1 };

// spec constants and tile sizes for quant matmul (non-Qi_K)
l_warptile_mmq = { 256, 64, 128, 64 };
m_warptile_mmq = { 256, 32, 64, 64 };
s_warptile_mmq = { 256, 32, 32, 128 };
l_warptile_mmq = { 256, 64, 128, 64, 1 };
m_warptile_mmq = { 256, 32, 64, 64, 0 };
s_warptile_mmq = { 256, 32, 32, 128, 0 };
l_mmq_wg_denoms = { 64, 128, 1 };
m_mmq_wg_denoms = { 32, 64, 1 };
s_mmq_wg_denoms = { 32, 32, 1 };

// spec constants and tile sizes for quant matmul (Qi_K)
l_warptile_mmq_k = { 256, 64, 128, 64 };
m_warptile_mmq_k = { 256, 32, 64, 64 };
s_warptile_mmq_k = { 256, 32, 32, 128 };
l_warptile_mmq_k = { 256, 64, 128, 64, 1 };
m_warptile_mmq_k = { 256, 32, 64, 64, 0 };
s_warptile_mmq_k = { 256, 32, 32, 128, 0 };
l_mmq_wg_denoms_k = { 64, 128, 1 };
m_mmq_wg_denoms_k = { 32, 64, 1 };
s_mmq_wg_denoms_k = { 32, 32, 1 };

// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 64, 16 };
m_warptile_mmqid = { 256, 128, 64, 16 };
s_warptile_mmqid = { 256, 128, 64, 16 };
l_warptile_mmqid = { 256, 128, 64, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_mmqid_wg_denoms = { 128, 64, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };
Expand Down
79 changes: 60 additions & 19 deletions ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ layout (constant_id = 1) const uint BM = 64;
layout (constant_id = 2) const uint BN = 64;
layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant

layout (constant_id = 4) const bool enable_smaller_matrices = false;
const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN;
const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN;

layout (push_constant) uniform parameter
{
uint M;
Expand Down Expand Up @@ -168,15 +172,13 @@ void main() {
const uint end_k = min(p.K, (ik + 1) * p.k_split);
#endif

coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

#ifdef MUL_MAT_ID
uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K;
uint pos_b = 0;
#else
uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K;
uint pos_b = batch_idx * p.batch_stride_b;
uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
#endif

uint stride_a = p.stride_a / QUANT_K;
Expand All @@ -197,6 +199,7 @@ void main() {
tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

#if QUANT_K > 1
tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K);
Expand Down Expand Up @@ -232,16 +235,54 @@ void main() {
tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1);

uint k_iters = (end_k - start_k + BK - 1) / BK;
if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;

coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose);

sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);

coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
return;
} else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose);

sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);

coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
return;
} else {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {

coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;

coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);

sum = coopMatMulAdd(mat_a, mat_b, sum);
}
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

sum = coopMatMulAdd(mat_a, mat_b, sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
return;
}
} else
#endif // !defined(MUL_MAT_ID)
Expand All @@ -254,6 +295,9 @@ void main() {

tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);

coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);

[[dont_unroll]]
for (uint block_k = start_k; block_k < end_k; block_k += BK) {

Expand Down Expand Up @@ -296,19 +340,16 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
}

// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);

#ifdef MUL_MAT_ID
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
#else
tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1);

uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z;
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
#endif
}
}

0 comments on commit a768a65

Please sign in to comment.