From e8de90ef2d2b4f3d0547a9f8ea49f32952d98c4c Mon Sep 17 00:00:00 2001 From: sjfeng Date: Sat, 6 Jul 2024 22:28:53 +0800 Subject: [PATCH] Fix cute gemm dispatch-5 when C != D --- include/cute/algorithm/gemm.hpp | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp index 27c322168a..e8dce656a3 100644 --- a/include/cute/algorithm/gemm.hpp +++ b/include/cute/algorithm/gemm.hpp @@ -411,7 +411,13 @@ gemm(MMA_Atom const& mma, CUTE_UNROLL for (int k = 0; k < K; ++k) { - gemm(mma, D, A(_,_,k), B(_,_,k), C); + if (k == 0) { + // D = Ak x Bk + C + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } else { + // D = Ak x Bk + D + gemm(mma, D, A(_,_,k), B(_,_,k), D); + } } } @@ -493,7 +499,13 @@ gemm(MMA_Atom const& mma, copy(A(_,_,k), rA(_,_,k)); copy(B(_,_,k), rB(_,_,k)); // Thread-level register gemm for k - gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + if (k == 0) { + // D = Ak x Bk + C + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } else { + // D = Ak x Bk + D + gemm(mma, D, rA(_,_,k), rB(_,_,k), D); + } } }