Skip to content

Commit

Permalink
Merge pull request #1367 from daineAMD/trsm-hotfix
Browse files Browse the repository at this point in the history
Hotfix - Guarding trsm kernel launch with 0 blocks
  • Loading branch information
daineAMD authored Nov 8, 2023
2 parents 99b0a71 + c7403c4 commit a03bb21
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 33 deletions.
7 changes: 4 additions & 3 deletions clients/gtest/trsm_gtest.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,10 @@ Definitions:
- { M: 1024, N: 1, lda: 2000, ldb: 1024 }

- &medium_matrix_size_range
- { M: 192, N: 192, lda: 192, ldb: 192 }
- { M: 600, N: 500, lda: 600, ldb: 600 }
- { M: 800, N: 700, lda: 801, ldb: 701 }
- { M: 129, N: 129, lda: 129, ldb: 129 }
- { M: 192, N: 192, lda: 192, ldb: 192 }
- { M: 600, N: 500, lda: 600, ldb: 600 }
- { M: 800, N: 700, lda: 801, ldb: 701 }

# - &small_substitution_size_range
- { M: 2, N: 1, lda: 30, ldb: 30 }
Expand Down
69 changes: 39 additions & 30 deletions library/src/blas3/trtri_trsm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,21 +188,26 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle,
static constexpr size_t sub_blockSize = 128;
size_t tri_elements_to_zero = rocblas_num_non_tri_elements(NB) * sub_blocks;
size_t num_sub_blocks = (tri_elements_to_zero + sub_blockSize - 1) / sub_blockSize;
ROCBLAS_LAUNCH_KERNEL((rocblas_trtri_fill<sub_blockSize, T>),
dim3(num_sub_blocks, batch_count),
dim3(sub_blockSize),
0,
handle->get_stream(),
handle,
uplo == rocblas_fill_lower ? rocblas_fill_upper : rocblas_fill_lower,
NB,
rocblas_num_non_tri_elements(NB),
NB,
NB * NB,
invA,
offset_invAin,
stride_invA,
sub_blocks);

dim3 grid_fill(num_sub_blocks, batch_count);
dim3 threads_fill(sub_blockSize);
ROCBLAS_LAUNCH_KERNEL_GRID(grid_fill,
(rocblas_trtri_fill<sub_blockSize, T>),
grid_fill,
threads_fill,
0,
handle->get_stream(),
handle,
uplo == rocblas_fill_lower ? rocblas_fill_upper
: rocblas_fill_lower,
NB,
rocblas_num_non_tri_elements(NB),
NB,
NB * NB,
invA,
offset_invAin,
stride_invA,
sub_blocks);

constexpr rocblas_int JB = IB * 4;
rocblas_stride sub_stride_A = NB * size_t(lda) + NB;
Expand Down Expand Up @@ -309,21 +314,25 @@ rocblas_status rocblas_trtri_trsm_template(rocblas_handle handle,
size_t tri_elements_to_zero = rocblas_num_non_tri_elements(rem);
size_t num_sub_blocks = (tri_elements_to_zero + sub_blockSize - 1) / sub_blockSize;

ROCBLAS_LAUNCH_KERNEL((rocblas_trtri_fill<sub_blockSize, T>),
dim3(num_sub_blocks, batch_count),
dim3(sub_blockSize),
0,
handle->get_stream(),
handle,
uplo == rocblas_fill_lower ? rocblas_fill_upper : rocblas_fill_lower,
rem,
rocblas_num_non_tri_elements(rem),
NB,
0,
invA,
sub_blocks * NB * size_t(NB) + offset_invAin,
stride_invA,
1);
dim3 grid_fill(num_sub_blocks, batch_count);
dim3 threads_fill(sub_blockSize);
ROCBLAS_LAUNCH_KERNEL_GRID(grid_fill,
(rocblas_trtri_fill<sub_blockSize, T>),
grid_fill,
threads_fill,
0,
handle->get_stream(),
handle,
uplo == rocblas_fill_lower ? rocblas_fill_upper
: rocblas_fill_lower,
rem,
rocblas_num_non_tri_elements(rem),
NB,
0,
invA,
sub_blocks * NB * size_t(NB) + offset_invAin,
stride_invA,
1);

if constexpr(BATCHED)
status = rocblas_internal_trtri_batched_template(
Expand Down

0 comments on commit a03bb21

Please sign in to comment.