Skip to content

Commit

Permalink
getri fixes 2 (#128)
Browse files: browse the repository at this point in the history
* Corrections to trtri workspace size and getri block size

* Increased max number of threads for getri

* Restored original number of threads
  • Loading branch information
tfalders authored Aug 6, 2020
1 parent 35765bd commit dde7521
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 6 deletions.
2 changes: 1 addition & 1 deletion rocsolver/clients/gtest/getri_gtest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ const vector<vector<int>> matrix_size_range = {

// for daily_lapack tests
const vector<vector<int>> large_matrix_size_range = {
- {192, 192}, {640, 640}, {1000, 1024}, {1200, 1230}
+ {192, 192}, {500, 600}, {640, 640}, {1000, 1024}, {1200, 1230}
};


Expand Down
8 changes: 3 additions & 5 deletions rocsolver/library/src/auxiliary/rocauxiliary_trtri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ __global__ void trtri_kernel(const rocblas_diagonal diag, const rocblas_int n,
{
int b = hipBlockIdx_x;

- rocblas_stride strideW = (n <= TRTRI_SWITCHSIZE_MID ? n : TRTRI_BLOCKSIZE);
+ rocblas_stride strideW = n;
T* a = load_ptr_batch<T>(A,b,shiftA,strideA);
T* w = load_ptr_batch<T>(work,b,0,strideW);

Expand Down Expand Up @@ -170,10 +170,8 @@ void rocsolver_trtri_getMemorySize(const rocblas_int n, const rocblas_int batch_
*size_1 = sizeof(T)*3;

// for workspace
- if (n <= TRTRI_SWITCHSIZE_MID)
+ if (n <= TRTRI_SWITCHSIZE_LARGE)
*size_2 = n;
- else if (n <= TRTRI_SWITCHSIZE_LARGE)
- *size_2 = TRTRI_BLOCKSIZE;
else
*size_2 = n * TRTRI_BLOCKSIZE + 2 * ROCBLAS_TRMM_NB * ROCBLAS_TRMM_NB;
*size_2 *= sizeof(T)*batch_count;
Expand Down Expand Up @@ -267,4 +265,4 @@ rocblas_status rocsolver_trtri_template(rocblas_handle handle, const rocblas_fil
return rocblas_status_success;
}

- #endif /* ROCLAPACK_GETRI_H */
+ #endif /* ROCLAPACK_TRTRI_H */

0 comments on commit dde7521

Please sign in to comment.