Skip to content

Commit

Permalink
Potential improvement to set/restore_diag in GEQR2 (#826)
Browse files Browse the repository at this point in the history
* modify larfg with separate alpha/beta

* geqr2 restore diag collectively

* remove template param

* rename larfg general template
  • Loading branch information
AGonzales-amd authored Nov 7, 2024
1 parent d1be0ce commit a01a2cd
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 33 deletions.
86 changes: 75 additions & 11 deletions library/src/auxiliary/rocauxiliary_larfg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
ROCSOLVER_BEGIN_NAMESPACE

template <typename T, std::enable_if_t<!rocblas_is_complex<T>, int> = 0>
__device__ void run_set_taubeta(T* tau, T* norms, T* alpha)
__device__ void run_set_taubeta(T* tau, T* norms, T* alpha, T* beta)
{
const auto ignore_beta = (beta == nullptr);
if(norms[0] > 0)
{
T n = sqrt(norms[0] + alpha[0] * alpha[0]);
Expand All @@ -53,17 +54,32 @@ __device__ void run_set_taubeta(T* tau, T* norms, T* alpha)
tau[0] = (n - alpha[0]) / n;

// beta:
alpha[0] = n;
if(ignore_beta)
{
alpha[0] = n;
}
else
{
beta[0] = n;
alpha[0] = 1;
}
}
else
{
norms[0] = 1;
tau[0] = 0;

// beta:
if(!ignore_beta)
{
beta[0] = alpha[0];
alpha[0] = 1;
}
}
}

template <typename T, std::enable_if_t<rocblas_is_complex<T>, int> = 0>
__device__ void run_set_taubeta(T* tau, T* norms, T* alpha)
__device__ void run_set_taubeta(T* tau, T* norms, T* alpha, T* beta)
{
using S = decltype(std::real(T{}));
S r, rr, ri, ar, ai;
Expand All @@ -72,6 +88,7 @@ __device__ void run_set_taubeta(T* tau, T* norms, T* alpha)
ai = alpha[0].imag();
S m = ai * ai;

const auto ignore_beta = (beta == nullptr);
if(norms[0].real() > 0 || m > 0)
{
m += ar * ar;
Expand All @@ -92,30 +109,49 @@ __device__ void run_set_taubeta(T* tau, T* norms, T* alpha)
tau[0] = rocblas_complex_num<S>(rr, ri);

// beta:
alpha[0] = n;
if(ignore_beta)
{
alpha[0] = n;
}
else
{
beta[0] = n;
alpha[0] = 1;
}
}
else
{
norms[0] = 1;
tau[0] = 0;

// beta:
if(!ignore_beta)
{
beta[0] = alpha[0];
alpha[0] = 1;
}
}
}

template <typename T, typename I, typename U>
template <typename T, typename I, typename U, typename UB>
ROCSOLVER_KERNEL void set_taubeta(T* tauA,
const rocblas_stride strideP,
T* norms,
U alphaA,
const rocblas_stride shiftA,
const rocblas_stride strideA)
const rocblas_stride strideA,
UB betaA,
const rocblas_stride shiftb,
const rocblas_stride strideb)
{
I bid = hipBlockIdx_x;

// select batch instance
T* alpha = load_ptr_batch<T>(alphaA, bid, shiftA, strideA);
T* beta = betaA ? load_ptr_batch<T>(betaA, bid, shiftb, strideb) : nullptr;
T* tau = tauA + bid * strideP;

run_set_taubeta<T>(tau, norms + bid, alpha);
run_set_taubeta<T>(tau, norms + bid, alpha, beta);
}

template <typename T, typename I>
Expand Down Expand Up @@ -183,11 +219,14 @@ rocblas_status
return rocblas_status_continue;
}

template <typename T, typename I, typename U, bool COMPLEX = rocblas_is_complex<T>>
template <typename T, typename I, typename U, typename UB, bool COMPLEX = rocblas_is_complex<T>>
rocblas_status rocsolver_larfg_template(rocblas_handle handle,
const I n,
U alpha,
const rocblas_stride shifta,
UB beta,
const rocblas_stride shiftb,
const rocblas_stride strideb,
U x,
const rocblas_stride shiftx,
const I incx,
Expand All @@ -211,11 +250,17 @@ rocblas_status rocsolver_larfg_template(rocblas_handle handle,

// if n==1 return tau=0
dim3 gridReset(1, batch_count, 1);
dim3 setDiag(batch_count, 1, 1);
dim3 threads(1, 1, 1);
if(n == 1 && !COMPLEX)
{
ROCSOLVER_LAUNCH_KERNEL(reset_batch_info<T>, gridReset, threads, 0, stream, tau, strideP, 1,
0);
if(beta != nullptr)
{
ROCSOLVER_LAUNCH_KERNEL((set_diag<T>), setDiag, threads, 0, stream, beta, shiftb,
strideb, alpha, shifta, n, stridex, (I)1, true);
}
return rocblas_status_success;
}

Expand All @@ -229,8 +274,8 @@ rocblas_status rocsolver_larfg_template(rocblas_handle handle,
HIP_CHECK(hipGetDeviceProperties(&deviceProperties, device));
if(deviceProperties.warpSize >= 64)
{
return larfg_run_small(handle, n, alpha, shifta, stridex, x, shiftx, incx, stridex, tau,
strideP, batch_count);
return larfg_run_small(handle, n, alpha, shifta, stridex, beta, shiftb, strideb, x,
shiftx, incx, stridex, tau, strideP, batch_count);
}
}

Expand All @@ -246,7 +291,7 @@ rocblas_status rocsolver_larfg_template(rocblas_handle handle,
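// set value of tau and beta and scaling factor for vector x
// norms <- scaling; alpha <- beta when no separate beta buffer is given,
// otherwise the beta buffer receives the value and alpha <- 1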
ROCSOLVER_LAUNCH_KERNEL((set_taubeta<T, I>), dim3(batch_count), dim3(1), 0, stream, tau,
strideP, norms, alpha, shifta, stridex);
strideP, norms, alpha, shifta, stridex, beta, shiftb, strideb);

// compute vector v=x*norms
rocblasCall_scal<T>(handle, n - 1, norms, 1, x, shiftx, incx, stridex, batch_count);
Expand All @@ -255,4 +300,23 @@ rocblas_status rocsolver_larfg_template(rocblas_handle handle,
return rocblas_status_success;
}

// Overload of rocsolver_larfg_template for callers that do not supply a
// separate beta output buffer. It forwards every argument unchanged to the
// overload that accepts beta, passing a null beta pointer together with a
// zero shift and a zero stride for it.
template <typename T, typename I, typename U, bool COMPLEX = rocblas_is_complex<T>>
rocblas_status rocsolver_larfg_template(rocblas_handle handle,
                                        const I n,
                                        U alpha,
                                        const rocblas_stride shifta,
                                        U x,
                                        const rocblas_stride shiftx,
                                        const I incx,
                                        const rocblas_stride stridex,
                                        T* tau,
                                        const rocblas_stride strideP,
                                        const I batch_count,
                                        T* work,
                                        T* norms)
{
    // A null beta selects the "no separate beta" path in the callee; the
    // accompanying shift and stride are placeholders and are passed as zero.
    T* const no_beta = nullptr;
    const rocblas_stride no_beta_shift = 0;
    const rocblas_stride no_beta_stride = 0;

    const rocblas_status status
        = rocsolver_larfg_template(handle, n, alpha, shifta, no_beta, no_beta_shift,
                                   no_beta_stride, x, shiftx, incx, stridex, tau, strideP,
                                   batch_count, work, norms);
    return status;
}

ROCSOLVER_END_NAMESPACE
5 changes: 4 additions & 1 deletion library/src/include/rocsolver_run_specialized_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,15 @@ rocblas_status larf_run_small(rocblas_handle handle,
const I batch_count);

// larfg
template <typename T, typename I, typename U>
template <typename T, typename I, typename U, typename UB>
rocblas_status larfg_run_small(rocblas_handle handle,
const I n,
U alpha,
const rocblas_stride shiftA,
const rocblas_stride strideA,
UB beta,
const rocblas_stride shiftB,
const rocblas_stride strideB,
U x,
const rocblas_stride shiftX,
const I incX,
Expand Down
25 changes: 11 additions & 14 deletions library/src/lapack/roclapack_geqr2.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ void rocsolver_geqr2_getMemorySize(const I m,
*size_Abyx_norms = std::max(s1, s2);

// size of array to store temporary diagonal values
*size_diag = sizeof(T) * batch_count;
*size_diag = sizeof(T) * std::min(m, n) * batch_count;
}

template <typename T, typename I, typename U>
Expand Down Expand Up @@ -131,18 +131,13 @@ rocblas_status rocsolver_geqr2_template(rocblas_handle handle,
for(I j = 0; j < dim; ++j)
{
// generate Householder reflector to work on column j
rocsolver_larfg_template(handle, m - j, A, shiftA + idx2D(j, j, lda), A,
shiftA + idx2D(std::min(j + 1, m - 1), j, lda), (I)1, strideA,
(ipiv + j), strideP, batch_count, (T*)work_workArr, Abyx_norms);
rocsolver_larfg_template<T>(handle, m - j, A, shiftA + idx2D(j, j, lda), diag, j, dim, A,
shiftA + idx2D(std::min(j + 1, m - 1), j, lda), (I)1, strideA,
(ipiv + j), strideP, batch_count, (T*)work_workArr, Abyx_norms);

// Apply Householder reflector to the rest of matrix from the left
if(j < n - 1)
{
// insert one in A(j,j) to build/apply the Householder matrix
ROCSOLVER_LAUNCH_KERNEL((set_diag<T, I>), dim3(batch_count, 1, 1), dim3(1, 1, 1), 0,
stream, diag, 0, 1, A, shiftA + idx2D(j, j, lda), lda, strideA,
(I)1, true);

// conjugate tau
if(COMPLEX)
rocsolver_lacgv_template<T>(handle, (I)1, ipiv, j, (I)1, strideP, batch_count);
Expand All @@ -152,17 +147,19 @@ rocblas_status rocsolver_geqr2_template(rocblas_handle handle,
A, shiftA + idx2D(j, j + 1, lda), lda, strideA, batch_count,
scalars, Abyx_norms, (T**)work_workArr);

// restore original value of A(j,j)
ROCSOLVER_LAUNCH_KERNEL((restore_diag<T, I>), dim3(batch_count, 1, 1), dim3(1, 1, 1), 0,
stream, diag, 0, 1, A, shiftA + idx2D(j, j, lda), lda, strideA,
(I)1);

// restore tau
if(COMPLEX)
rocsolver_lacgv_template<T>(handle, (I)1, ipiv, j, (I)1, strideP, batch_count);
}
}

// restore diagonal values of A
constexpr int DIAG_NTHREADS = 64;
I blocks = (dim - 1) / DIAG_NTHREADS + 1;
ROCSOLVER_LAUNCH_KERNEL((restore_diag<T, I>), dim3(batch_count, blocks, 1),
dim3(1, DIAG_NTHREADS, 1), 0, stream, diag, 0, dim, A, shiftA, lda,
strideA, dim);

return rocblas_status_success;
}

Expand Down
23 changes: 16 additions & 7 deletions library/src/specialized/rocauxiliary_larfg_specialized_kernels.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,15 @@ ROCSOLVER_BEGIN_NAMESPACE
the library size.
*************************************************************/

template <typename T, typename I, typename U>
template <typename T, typename I, typename U, typename UB>
ROCSOLVER_KERNEL void __launch_bounds__(LARFG_SSKER_THREADS)
larfg_kernel_small(const I n,
U alpha,
const rocblas_stride shiftA,
const rocblas_stride strideA,
UB beta,
const rocblas_stride shiftB,
const rocblas_stride strideB,
U xx,
const rocblas_stride shiftX,
const I incX,
Expand All @@ -65,6 +68,8 @@ ROCSOLVER_KERNEL void __launch_bounds__(LARFG_SSKER_THREADS)
T* x = load_ptr_batch<T>(xx, bid, shiftX, strideX);
T* tau = load_ptr_batch<T>(tauA, bid, 0, strideP);

T* b = beta ? load_ptr_batch<T>(beta, bid, shiftB, strideB) : nullptr;

// shared variables
__shared__ T sval[LARFG_SSKER_THREADS];
__shared__ T sh_x[LARFG_SSKER_MAX_N];
Expand All @@ -78,7 +83,7 @@ ROCSOLVER_KERNEL void __launch_bounds__(LARFG_SSKER_THREADS)

// set tau, beta, and put scaling factor into sval[0]
if(tid == 0)
run_set_taubeta<T>(tau, sval, a);
run_set_taubeta<T>(tau, sval, a, b);
__syncthreads();

// scale x by scaling factor
Expand All @@ -90,12 +95,15 @@ ROCSOLVER_KERNEL void __launch_bounds__(LARFG_SSKER_THREADS)
Launchers of specialized kernels
*************************************************************/

template <typename T, typename I, typename U>
template <typename T, typename I, typename U, typename UB>
rocblas_status larfg_run_small(rocblas_handle handle,
const I n,
U alpha,
const rocblas_stride shiftA,
const rocblas_stride strideA,
UB beta,
const rocblas_stride shiftB,
const rocblas_stride strideB,
U x,
const rocblas_stride shiftX,
const I incX,
Expand All @@ -110,8 +118,8 @@ rocblas_status larfg_run_small(rocblas_handle handle,
hipStream_t stream;
rocblas_get_stream(handle, &stream);

ROCSOLVER_LAUNCH_KERNEL(larfg_kernel_small<T>, grid, block, 0, stream, n, alpha, shiftA,
strideA, x, shiftX, incX, strideX, tau, strideP);
ROCSOLVER_LAUNCH_KERNEL((larfg_kernel_small<T>), grid, block, 0, stream, n, alpha, shiftA,
strideA, beta, shiftB, strideB, x, shiftX, incX, strideX, tau, strideP);

return rocblas_status_success;
}
Expand All @@ -121,9 +129,10 @@ rocblas_status larfg_run_small(rocblas_handle handle,
*************************************************************/

#define INSTANTIATE_LARFG_SMALL(T, I, U) \
template rocblas_status larfg_run_small<T, I, U>( \
template rocblas_status larfg_run_small<T, I, U, T*>( \
rocblas_handle handle, const I n, U alpha, const rocblas_stride shiftA, \
const rocblas_stride strideA, U x, const rocblas_stride shiftX, const I incX, \
const rocblas_stride strideA, T* beta, const rocblas_stride shiftB, \
const rocblas_stride strideB, U x, const rocblas_stride shiftX, const I incX, \
const rocblas_stride strideX, T* tau, const rocblas_stride strideP, const I batch_count)

ROCSOLVER_END_NAMESPACE

0 comments on commit a01a2cd

Please sign in to comment.