Skip to content

Commit

Permalink
Cmma: new strategy for num compute planes + many refactors (#150)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Sep 26, 2024
1 parent d509fbd commit d83afb2
Show file tree
Hide file tree
Showing 27 changed files with 452 additions and 295 deletions.
8 changes: 4 additions & 4 deletions crates/cubecl-cuda/src/runtime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,16 @@ fn register_wmma_features(features: &mut FeatureSet, arch: u32) {
b,
c,
m: 32,
k: 8,
n: 16,
k: 16,
n: 8,
});
features.register(Feature::Cmma {
a,
b,
c,
m: 8,
k: 32,
n: 16,
k: 16,
n: 32,
});
}
}
Expand Down
37 changes: 37 additions & 0 deletions crates/cubecl-linalg/src/matmul/cmma/availability.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use cubecl_core::{
client::ComputeClient,
ir::{Elem, FloatKind},
Feature, Runtime,
};

use crate::matmul::cmma::config::CmmaConfig;

use super::config::TileDimension;

/// Reasons for which the CMMA matmul cannot be used.
///
/// This is the error type returned by `check_cmma_availability`.
#[derive(Debug)]
pub enum UnavailabilityReason {
/// NOTE(review): presumably the input layout is too permuted for this
/// kernel; not produced anywhere in this file — confirm against callers.
HighlyPermutatedInput,
/// NOTE(review): presumably the configuration would need more shared
/// memory than is available; not produced in this file — confirm.
SharedMemoryLimitBusted,
/// The configuration is invalid. Carries a `String`, presumably a
/// human-readable description of the problem — confirm against callers.
InvalidConfig(String),
/// The client does not report the required `Feature::Cmma` as enabled.
CmmaInstructionsUnsupported,
}

/// Reports whether the CMMA matmul can run on `client` with `cmma_config`.
///
/// The tile dimension strategy of the config is converted to a
/// [`TileDimension`], and the client is asked whether a `Feature::Cmma`
/// with F16 `a`/`b` operands, an F32 `c` operand and that `m`/`k`/`n`
/// shape is enabled.
///
/// # Errors
///
/// Returns [`UnavailabilityReason::CmmaInstructionsUnsupported`] when the
/// client does not report that feature as enabled.
pub fn check_cmma_availability<R: Runtime>(
    client: &ComputeClient<R::Server, R::Channel>,
    cmma_config: &CmmaConfig,
) -> Result<(), UnavailabilityReason> {
    let dims: TileDimension = cmma_config.tile_dimension_strategy.into();

    // NOTE(review): the `as u8` casts silently truncate if a dimension does
    // not fit in a `u8` — confirm the value range of `TileDimension` fields.
    let required_feature = Feature::Cmma {
        a: Elem::Float(FloatKind::F16),
        b: Elem::Float(FloatKind::F16),
        c: Elem::Float(FloatKind::F32),
        m: dims.m as u8,
        k: dims.k as u8,
        n: dims.n as u8,
    };

    if client.features().enabled(required_feature) {
        Ok(())
    } else {
        Err(UnavailabilityReason::CmmaInstructionsUnsupported)
    }
}
4 changes: 2 additions & 2 deletions crates/cubecl-linalg/src/matmul/cmma/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ pub fn cmma_launch<F: Float, FC: Float>(
#[comptime] comptime_info: ComptimeCmmaInfo,
) {
match comptime_info.main_loop_strategy {
MainLoopStrategy::Standard(_) => {
MainLoopStrategy::Standard => {
cmma_build_step_1::<StandardMainLoop, F, FC>(lhs, rhs, out, comptime_info)
}
MainLoopStrategy::Split(_, _) => {
MainLoopStrategy::Split(_) => {
cmma_build_step_1::<SplitMainLoop, F, FC>(lhs, rhs, out, comptime_info)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::matmul::cmma::{
get_smem_position_lhs, get_smem_position_rhs, load_tile_into_fragment, ComputeLoop,
},
config::ComptimeCmmaInfo,
load_shared_memory::load_info::{LhsLoadInfo, RhsLoadInfo},
};

pub(crate) struct AccumulatorsFirstComputeLoop {}
Expand All @@ -21,30 +22,28 @@ impl ComputeLoop for AccumulatorsFirstComputeLoop {
#[comptime] comptime_info: ComptimeCmmaInfo,
) {
// Comptime values
let block_size_k = comptime_info.block_size_k;
let block_size_n = comptime_info.block_size_n;
let tile_size = comptime_info.tile_size;
let unroll = comptime_info.unroll;
let num_accumulators = comptime_info.num_accumulators;
let num_buffers = block_size_k / tile_size;
let num_coop_per_row = (block_size_n / tile_size) / num_accumulators;
let num_buffers = comptime_info.num_buffers;
let num_planes_per_row = (block_size_n / comptime_info.tile_size_n) / num_accumulators;

// Runtime values
let tile_row = ids.coop / num_coop_per_row;
let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators;
let tile_row = ids.plane / num_planes_per_row;
let tile_col_base = (ids.plane % num_planes_per_row) * num_accumulators;

#[unroll(unroll)]
for buffer_iter in 0..num_buffers {
#[unroll]
for accumulator_iter in 0..num_accumulators {
load_tile_into_fragment(
load_tile_into_fragment::<FC, LhsLoadInfo>(
get_smem_position_lhs::<F, FC>(tile_row, buffer_iter, comptime_info),
shared_memories.lhs,
&fragments.lhs,
comptime_info,
);

load_tile_into_fragment(
load_tile_into_fragment::<FC, RhsLoadInfo>(
get_smem_position_rhs::<F, FC>(
buffer_iter,
tile_col_base + accumulator_iter,
Expand Down Expand Up @@ -76,21 +75,19 @@ impl ComputeLoop for AccumulatorsFirstWithReuseComputeLoop {
#[comptime] comptime_info: ComptimeCmmaInfo,
) {
// Comptime values
let block_size_k = comptime_info.block_size_k;
let block_size_n = comptime_info.block_size_n;
let tile_size = comptime_info.tile_size;
let unroll = comptime_info.unroll;
let num_accumulators = comptime_info.num_accumulators;
let num_buffers = block_size_k / tile_size;
let num_coop_per_row = (block_size_n / tile_size) / num_accumulators;
let num_buffers = comptime_info.num_buffers;
let num_planes_per_row = (block_size_n / comptime_info.tile_size_n) / num_accumulators;

// Runtime values
let tile_row = ids.coop / num_coop_per_row;
let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators;
let tile_row = ids.plane / num_planes_per_row;
let tile_col_base = (ids.plane % num_planes_per_row) * num_accumulators;

#[unroll(unroll)]
for buffer_iter in 0..num_buffers {
load_tile_into_fragment(
load_tile_into_fragment::<FC, LhsLoadInfo>(
get_smem_position_lhs::<F, FC>(tile_row, buffer_iter, comptime_info),
shared_memories.lhs,
&fragments.lhs,
Expand All @@ -99,7 +96,7 @@ impl ComputeLoop for AccumulatorsFirstWithReuseComputeLoop {

#[unroll]
for accumulator_iter in 0..num_accumulators {
load_tile_into_fragment(
load_tile_into_fragment::<FC, RhsLoadInfo>(
get_smem_position_rhs::<F, FC>(
buffer_iter,
tile_col_base + accumulator_iter,
Expand Down
7 changes: 3 additions & 4 deletions crates/cubecl-linalg/src/matmul/cmma/compute_loop/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,17 @@ pub(crate) trait ComputeLoop {
}

#[cube]
pub(crate) fn load_tile_into_fragment<FC: Float>(
pub(crate) fn load_tile_into_fragment<FC: Float, I: LoadInfo>(
nth_tile: u32,
smem: SharedMemory<FC>,
fragment: &cmma::Matrix<FC>,
#[comptime] comptime_info: ComptimeCmmaInfo,
) {
let tile_size = comptime_info.tile_size;
let smem_stride = tile_size * tile_size;
let smem_stride = I::num_tile_elements(comptime_info);

let smem_pos = nth_tile * smem_stride;
let slice = smem.slice(smem_pos, smem_pos + smem_stride);
cmma::load::<FC>(fragment, slice, 16);
cmma::load::<FC>(fragment, slice, I::tile_width(comptime_info));
}

#[cube]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use crate::matmul::cmma::{
get_smem_position_lhs, get_smem_position_rhs, load_tile_into_fragment, ComputeLoop,
},
config::ComptimeCmmaInfo,
load_shared_memory::load_info::{LhsLoadInfo, RhsLoadInfo},
};

pub(crate) struct BuffersFirstComputeLoop {}
Expand All @@ -20,30 +21,28 @@ impl ComputeLoop for BuffersFirstComputeLoop {
#[comptime] comptime_info: ComptimeCmmaInfo,
) {
// Comptime values
let block_size_k = comptime_info.block_size_k;
let block_size_n = comptime_info.block_size_n;
let tile_size = comptime_info.tile_size;
let unroll = comptime_info.unroll;
let num_accumulators = comptime_info.num_accumulators;
let num_buffers = block_size_k / tile_size;
let num_coop_per_row = (block_size_n / tile_size) / num_accumulators;
let num_buffers = comptime_info.num_buffers;
let num_planes_per_row = (block_size_n / comptime_info.tile_size_n) / num_accumulators;

// Runtime values
let tile_row = compute_ids.coop / num_coop_per_row;
let tile_col_base = (compute_ids.coop % num_coop_per_row) * num_accumulators;
let tile_row = compute_ids.plane / num_planes_per_row;
let tile_col_base = (compute_ids.plane % num_planes_per_row) * num_accumulators;

#[unroll]
for accumulator_iter in 0..num_accumulators {
#[unroll(unroll)]
for buffer_iter in 0..num_buffers {
load_tile_into_fragment(
load_tile_into_fragment::<FC, LhsLoadInfo>(
get_smem_position_lhs::<F, FC>(tile_row, buffer_iter, comptime_info),
shared_memories.lhs,
&fragments.lhs,
comptime_info,
);

load_tile_into_fragment(
load_tile_into_fragment::<FC, RhsLoadInfo>(
get_smem_position_rhs::<F, FC>(
buffer_iter,
tile_col_base + accumulator_iter,
Expand Down
Loading

0 comments on commit d83afb2

Please sign in to comment.