Skip to content

Commit

Permalink
Merge pull request #763 from jmachado-amd/release/rocm-rel-6.2
Browse files Browse the repository at this point in the history
Two fixes for the Jacobi eigensolver
  • Loading branch information
mamaydeo authored Jul 22, 2024
2 parents 48ddb95 + d6c8e93 commit 7e1b068
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 42 deletions.
2 changes: 1 addition & 1 deletion library/src/lapack/roclapack_syevj_heevj.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ rocblas_status rocsolver_syevj_heevj_impl(rocblas_handle handle,
// memory workspace allocation
void *Acpy, *J, *norms, *top, *bottom, *completed;
rocblas_device_malloc mem(handle, size_Acpy, size_J, size_norms, size_top, size_bottom,
size_completed, size_norms);
size_completed);

if(!mem)
return rocblas_status_memory_error;
Expand Down
70 changes: 31 additions & 39 deletions library/src/lapack/roclapack_syevj_heevj.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,53 +1105,45 @@ template <typename T>
ROCSOLVER_KERNEL void
syevj_cycle_pairs(const rocblas_int half_blocks, rocblas_int* top, rocblas_int* bottom)
{
rocblas_int tix = hipThreadIdx_x;
rocblas_int i, j, k;
rocblas_int n = half_blocks - 1;

if(half_blocks <= hipBlockDim_x && tix < half_blocks)
auto cycle = [n = n](auto i) -> auto
{
if(tix == 0)
i = 0;
else if(tix == 1)
i = bottom[0];
else if(tix > 1)
i = top[tix - 1];
using I = decltype(i);
i = (i - 1) % (2 * n + 1) + 1;
I j{};

if(tix == half_blocks - 1)
j = top[half_blocks - 1];
if(i % 2 == 0)
{
j = i + 2;
if(j > 2 * n)
{
j = 2 * n + 1;
}
}
else
j = bottom[tix + 1];
__syncthreads();

top[tix] = i;
bottom[tix] = j;
}
else
{
// shared memory
extern __shared__ double lmem[];
rocblas_int* sh_top = reinterpret_cast<rocblas_int*>(lmem);
rocblas_int* sh_bottom = reinterpret_cast<rocblas_int*>(sh_top + half_blocks);

for(k = tix; k < half_blocks; k += hipBlockDim_x)
{
sh_top[k] = top[k];
sh_bottom[k] = bottom[k];
j = i - 2;
if(j < 1)
{
j = 2;
}
}
__syncthreads();

for(k = tix; k < half_blocks; k += hipBlockDim_x)
{
if(k == 1)
top[k] = sh_bottom[0];
else if(k > 1)
top[k] = sh_top[k - 1];
return j;
};

if(k == half_blocks - 1)
bottom[k] = sh_top[half_blocks - 1];
else
bottom[k] = sh_bottom[k + 1];
}
rocblas_int tidx = hipThreadIdx_x;
rocblas_int dimx = hipBlockDim_x;

if(tidx == 0)
{
bottom[0] = cycle(bottom[0]);
}
for(rocblas_int l = tidx + 1; l < half_blocks; l += dimx)
{
top[l] = cycle(top[l]);
bottom[l] = cycle(bottom[l]);
}
}

Expand Down
2 changes: 1 addition & 1 deletion library/src/lapack/roclapack_syevj_heevj_batched.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ rocblas_status rocsolver_syevj_heevj_batched_impl(rocblas_handle handle,
// memory workspace allocation
void *Acpy, *J, *norms, *top, *bottom, *completed;
rocblas_device_malloc mem(handle, size_Acpy, size_J, size_norms, size_top, size_bottom,
size_completed, size_norms);
size_completed);

if(!mem)
return rocblas_status_memory_error;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ rocblas_status rocsolver_syevj_heevj_strided_batched_impl(rocblas_handle handle,
// memory workspace allocation
void *Acpy, *J, *norms, *top, *bottom, *completed;
rocblas_device_malloc mem(handle, size_Acpy, size_J, size_norms, size_top, size_bottom,
size_completed, size_norms);
size_completed);

if(!mem)
return rocblas_status_memory_error;
Expand Down

0 comments on commit 7e1b068

Please sign in to comment.