Skip to content

Commit

Permalink
reuse lhs frag strategy (#132)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Sep 19, 2024
1 parent 2fdfb51 commit 4530e3b
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,33 @@ impl ComputeLoop for AllAccumulatorsFirstComputeLoop {
let num_accumulators = comptime_info.num_accumulators;
let num_buffers = block_size_k / tile_size;
let num_coop_per_row = (block_size_n / tile_size) / num_accumulators;
let reuse_lhs_fragment = comptime_info.reuse_lhs_fragment;

// Runtime values
let tile_row = ids.coop / num_coop_per_row;
let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators;

#[unroll(unroll)]
for buffer_iter in 0..num_buffers {
#[unroll]
for accumulator_iter in 0..num_accumulators {
if reuse_lhs_fragment {
load_into_fragment(
tile_row * num_buffers + buffer_iter,
shared_memories.lhs,
&fragments.lhs,
comptime_info,
);
}

#[unroll]
for accumulator_iter in 0..num_accumulators {
if !reuse_lhs_fragment {
load_into_fragment(
tile_row * num_buffers + buffer_iter,
shared_memories.lhs,
&fragments.lhs,
comptime_info,
);
}

load_into_fragment(
(tile_col_base + accumulator_iter) * num_buffers + buffer_iter,
Expand Down
20 changes: 14 additions & 6 deletions crates/cubecl-linalg/src/matmul/cmma/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,16 @@ pub enum ComputeLoopOrderStrategy {
/// Accumulators for one warp are put concurrently in a shared memory large enough to contain them all
AllBuffersFirst,
/// Accumulators for one warp are put sequentially in a shared memory with only one reusable spot
AllAccumulatorsFirst,
AllAccumulatorsFirst { reuse_lhs_fragment: bool },
}

impl From<ComputeLoopOrderStrategy> for u32 {
impl From<ComputeLoopOrderStrategy> for (u32, bool) {
fn from(value: ComputeLoopOrderStrategy) -> Self {
match value {
ComputeLoopOrderStrategy::AllBuffersFirst => 0,
ComputeLoopOrderStrategy::AllAccumulatorsFirst => 1,
ComputeLoopOrderStrategy::AllBuffersFirst => (0, false),
ComputeLoopOrderStrategy::AllAccumulatorsFirst { reuse_lhs_fragment } => {
(1, reuse_lhs_fragment)
}
}
}
}
Expand Down Expand Up @@ -116,6 +118,8 @@ impl CmmaConfig {

pub(crate) fn comptime_info(&self, m: usize, k: usize, n: usize) -> ComptimeCmmaInfo {
let num_coops = self.b_mn * self.b_k / (CMMA_TILE_SIZE * CMMA_TILE_SIZE);
let (compute_loop_order_strategy, reuse_lhs_fragment) =
self.compute_loop_order_strategy.into();

ComptimeCmmaInfo {
block_size_m: self.b_mn as u32,
Expand All @@ -131,7 +135,8 @@ impl CmmaConfig {
num_accumulators: (self.b_mn / self.b_k) as u32,
write_out_strategy: self.write_out_strategy.into(),
cube_dispatch_strategy: self.cube_dispatch_strategy.into(),
compute_loop_order_strategy: self.compute_loop_order_strategy.into(),
compute_loop_order_strategy,
reuse_lhs_fragment,
}
}

Expand Down Expand Up @@ -205,6 +210,9 @@ pub struct ComptimeCmmaInfo {
pub write_out_strategy: u32,
/// 0 = RowMajor, 1 = ColMajor, 2 = Swizzle
pub cube_dispatch_strategy: u32,
/// 0 = buffer inner, 1 = buffer outer
/// 0 = all buffers first, 1 = all accumulators first
pub compute_loop_order_strategy: u32,
/// Whether to reuse lhs fragment (true) or to reload it (false)
/// Available only with all accumulators first compute loop order
pub reuse_lhs_fragment: bool,
}

0 comments on commit 4530e3b

Please sign in to comment.