From fddcde4fca8b6d38ba7c5bc257f8758797a9fcc4 Mon Sep 17 00:00:00 2001 From: louisfd Date: Thu, 19 Sep 2024 09:44:54 -0400 Subject: [PATCH] reuse lhs frag strategy --- .../cmma/compute_loop/accumulators_first.rs | 16 +++++++++++++-- .../cubecl-linalg/src/matmul/cmma/config.rs | 20 +++++++++++++------ 2 files changed, 28 insertions(+), 8 deletions(-) diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs index 101b098b0..d037dfd9e 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs @@ -27,6 +27,7 @@ impl ComputeLoop for AllAccumulatorsFirstComputeLoop { let num_accumulators = comptime_info.num_accumulators; let num_buffers = block_size_k / tile_size; let num_coop_per_row = (block_size_n / tile_size) / num_accumulators; + let reuse_lhs_fragment = comptime_info.reuse_lhs_fragment; // Runtime values let tile_row = ids.coop / num_coop_per_row; @@ -34,14 +35,25 @@ impl ComputeLoop for AllAccumulatorsFirstComputeLoop { #[unroll(unroll)] for buffer_iter in 0..num_buffers { - #[unroll] - for accumulator_iter in 0..num_accumulators { + if reuse_lhs_fragment { load_into_fragment( tile_row * num_buffers + buffer_iter, shared_memories.lhs, &fragments.lhs, comptime_info, ); + } + + #[unroll] + for accumulator_iter in 0..num_accumulators { + if !reuse_lhs_fragment { + load_into_fragment( + tile_row * num_buffers + buffer_iter, + shared_memories.lhs, + &fragments.lhs, + comptime_info, + ); + } load_into_fragment( (tile_col_base + accumulator_iter) * num_buffers + buffer_iter, diff --git a/crates/cubecl-linalg/src/matmul/cmma/config.rs b/crates/cubecl-linalg/src/matmul/cmma/config.rs index 9af5cb4e7..c128b9a99 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/config.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/config.rs @@ -52,14 +52,16 @@ pub enum ComputeLoopOrderStrategy { /// Accumulators for one warp are put concurrently in a shared memory large enough to contain them all AllBuffersFirst, /// Accumulators for one warp are put sequentially in a shared memory with only one reusable spot - AllAccumulatorsFirst, + AllAccumulatorsFirst { reuse_lhs_fragment: bool }, } -impl From for u32 { +impl From for (u32, bool) { fn from(value: ComputeLoopOrderStrategy) -> Self { match value { - ComputeLoopOrderStrategy::AllBuffersFirst => 0, - ComputeLoopOrderStrategy::AllAccumulatorsFirst => 1, + ComputeLoopOrderStrategy::AllBuffersFirst => (0, false), + ComputeLoopOrderStrategy::AllAccumulatorsFirst { reuse_lhs_fragment } => { + (1, reuse_lhs_fragment) + } } } } @@ -116,6 +118,8 @@ impl CmmaConfig { pub(crate) fn comptime_info(&self, m: usize, k: usize, n: usize) -> ComptimeCmmaInfo { let num_coops = self.b_mn * self.b_k / (CMMA_TILE_SIZE * CMMA_TILE_SIZE); + let (compute_loop_order_strategy, reuse_lhs_fragment) = + self.compute_loop_order_strategy.into(); ComptimeCmmaInfo { block_size_m: self.b_mn as u32, @@ -131,7 +135,8 @@ impl CmmaConfig { num_accumulators: (self.b_mn / self.b_k) as u32, write_out_strategy: self.write_out_strategy.into(), cube_dispatch_strategy: self.cube_dispatch_strategy.into(), - compute_loop_order_strategy: self.compute_loop_order_strategy.into(), + compute_loop_order_strategy, + reuse_lhs_fragment, } } @@ -205,6 +210,9 @@ pub struct ComptimeCmmaInfo { pub write_out_strategy: u32, /// 0 = RowMajor, 1 = ColMajor, 2 = Swizzle pub cube_dispatch_strategy: u32, - /// 0 = buffer inner, 1 = buffer outer + /// 0 = all buffers first, 1 = all accumulators first pub compute_loop_order_strategy: u32, + /// Whether to reuse lhs fragment (true) or to reload it (false) + /// Available only with all accumulators first compute loop order + pub reuse_lhs_fragment: bool, }