Skip to content

Commit

Permalink
Enable hashing output for sygvdx (#854)
Browse files Browse the repository at this point in the history
* enable hash checking for sygvdx

* address feedback

* add support for getrf, potrf, syevx, sygvx

* amend comments, update sygvdx
  • Loading branch information
qjojo authored Nov 20, 2024
1 parent 7958727 commit 38c7a3a
Show file tree
Hide file tree
Showing 9 changed files with 251 additions and 69 deletions.
6 changes: 6 additions & 0 deletions clients/benchmarks/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ try
" This will additionally print the relative error of the computations.\n"
" ")

("hash",
value<rocblas_int>(&argus.hash_check)->default_value(0),
"Print hash of GPU results? 0 = No, 1 = Yes.\n"
" Meant for checking reproducibility of computations.\n"
" ")

// size options
("k",
value<rocblas_int>(),
Expand Down
9 changes: 9 additions & 0 deletions clients/common/containers/host_strided_batch_vector.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,15 @@ class host_strided_batch_vector
return ((bool)*this) ? hipSuccess : hipErrorOutOfMemory;
}

//!
//! @brief Get the element count recorded for this vector.
//! @return The value of the m_nmemb member, as a size_t.
//! NOTE(review): presumably this is the total number of elements across
//! all batches rather than the per-batch length -- confirm against the
//! constructor that sets m_nmemb.
//!
size_t size() const
{
return this->m_nmemb;
}

private:
storage m_storage{storage::block};
int64_t m_n{};
Expand Down
44 changes: 31 additions & 13 deletions clients/common/lapack/testing_getf2_getrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -226,12 +226,18 @@ void getf2_getrf_getError(const rocblas_handle handle,
Uh& hInfo,
Ih& hInfoRes,
double* max_err,
const bool singular)
const bool singular,
size_t& hashA,
size_t& hashARes,
size_t& hashIpivRes)
{
// input data initialization
getf2_getrf_initData<true, true, T>(handle, m, n, dA, lda, stA, dIpiv, stP, dInfo, bc, hA,
hIpiv, singular);

// compute input hashes
hashA = deterministic_hash(hA, bc);

// execute computations
// GPU lapack
CHECK_ROCBLAS_ERROR(rocsolver_getf2_getrf(STRIDED, GETRF, handle, m, n, dA.data(), lda, stA,
Expand All @@ -247,6 +253,10 @@ void getf2_getrf_getError(const rocblas_handle handle,
: cpu_getf2(m, n, hA[b], lda, hIpiv[b], hInfo[b]);
}

// compute output hashes
hashARes = deterministic_hash(hARes, bc);
hashIpivRes = deterministic_hash(hIpivRes);

// expecting original matrix to be non-singular
// error is ||hA - hARes|| / ||hA|| (ideally ||LU - Lres Ures|| / ||LU||)
// (THIS DOES NOT ACCOUNT FOR NUMERICAL REPRODUCIBILITY ISSUES.
Expand Down Expand Up @@ -373,8 +383,8 @@ void testing_getf2_getrf(Arguments& argus)
I bc = argus.batch_count;
int hot_calls = argus.iters;

rocblas_stride stARes = (argus.unit_check || argus.norm_check) ? stA : 0;
rocblas_stride stPRes = (argus.unit_check || argus.norm_check) ? stP : 0;
rocblas_stride stARes = (argus.unit_check || argus.norm_check || argus.hash_check) ? stA : 0;
rocblas_stride stPRes = (argus.unit_check || argus.norm_check || argus.hash_check) ? stP : 0;

// check non-supported values
// N/A
Expand All @@ -383,9 +393,10 @@ void testing_getf2_getrf(Arguments& argus)
size_t size_A = size_t(lda) * n;
size_t size_P = size_t(min(m, n));
double max_error = 0, gpu_time_used = 0, cpu_time_used = 0;
size_t hashA = 0, hashARes = 0, hashIpivRes = 0;

size_t size_ARes = (argus.unit_check || argus.norm_check) ? size_A : 0;
size_t size_PRes = (argus.unit_check || argus.norm_check) ? size_P : 0;
size_t size_ARes = (argus.unit_check || argus.norm_check || argus.hash_check) ? size_A : 0;
size_t size_PRes = (argus.unit_check || argus.norm_check || argus.hash_check) ? size_P : 0;

// check invalid sizes
bool invalid_size = (m < 0 || n < 0 || lda < m || bc < 0);
Expand Down Expand Up @@ -460,10 +471,10 @@ void testing_getf2_getrf(Arguments& argus)
}

// check computations
if(argus.unit_check || argus.norm_check)
getf2_getrf_getError<STRIDED, GETRF, T>(handle, m, n, dA, lda, stA, dIpiv, stP, dInfo,
bc, hA, hARes, hIpiv, hIpivRes, hInfo, hInfoRes,
&max_error, argus.singular);
if(argus.unit_check || argus.norm_check || argus.hash_check)
getf2_getrf_getError<STRIDED, GETRF, T>(
handle, m, n, dA, lda, stA, dIpiv, stP, dInfo, bc, hA, hARes, hIpiv, hIpivRes,
hInfo, hInfoRes, &max_error, argus.singular, hashA, hashARes, hashIpivRes);

// collect performance data
if(argus.timing)
Expand Down Expand Up @@ -504,10 +515,10 @@ void testing_getf2_getrf(Arguments& argus)
}

// check computations
if(argus.unit_check || argus.norm_check)
getf2_getrf_getError<STRIDED, GETRF, T>(handle, m, n, dA, lda, stA, dIpiv, stP, dInfo,
bc, hA, hARes, hIpiv, hIpivRes, hInfo, hInfoRes,
&max_error, argus.singular);
if(argus.unit_check || argus.norm_check || argus.hash_check)
getf2_getrf_getError<STRIDED, GETRF, T>(
handle, m, n, dA, lda, stA, dIpiv, stP, dInfo, bc, hA, hARes, hIpiv, hIpivRes,
hInfo, hInfoRes, &max_error, argus.singular, hashA, hashARes, hashIpivRes);

// collect performance data
if(argus.timing)
Expand Down Expand Up @@ -555,6 +566,13 @@ void testing_getf2_getrf(Arguments& argus)
rocsolver_bench_output(cpu_time_used, gpu_time_used);
}
rocsolver_bench_endl();
if(argus.hash_check)
{
rocsolver_bench_output("hash(A)", "hash(ARes)", "hash(ipivRes)");
rocsolver_bench_output(ROCSOLVER_FORMAT_HASH(hashA), ROCSOLVER_FORMAT_HASH(hashARes),
ROCSOLVER_FORMAT_HASH(hashIpivRes));
rocsolver_bench_endl();
}
}
else
{
Expand Down
29 changes: 22 additions & 7 deletions clients/common/lapack/testing_potf2_potrf.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,19 +185,27 @@ void potf2_potrf_getError(const rocblas_handle handle,
Uh& hInfo,
Ih& hInfoRes,
double* max_err,
const bool singular)
const bool singular,
size_t& hashA,
size_t& hashARes)
{
// input data initialization
potf2_potrf_initData<true, true, T>(handle, uplo, n, dA, lda, stA, dInfo, bc, hA, hInfo,
singular);

// hash input
hashA = deterministic_hash(hA, bc);

// execute computations
// GPU lapack
CHECK_ROCBLAS_ERROR(rocsolver_potf2_potrf(STRIDED, POTRF, handle, uplo, n, dA.data(), lda, stA,
dInfo.data(), bc));
CHECK_HIP_ERROR(hARes.transfer_from(dA));
CHECK_HIP_ERROR(hInfoRes.transfer_from(dInfo));

// hash output
hashARes = deterministic_hash(hARes, bc);

// CPU lapack
for(I b = 0; b < bc; ++b)
{
Expand Down Expand Up @@ -321,7 +329,7 @@ void testing_potf2_potrf(Arguments& argus)
I bc = argus.batch_count;
rocblas_int hot_calls = argus.iters;

rocblas_stride stARes = (argus.unit_check || argus.norm_check) ? stA : 0;
rocblas_stride stARes = (argus.unit_check || argus.norm_check || argus.hash_check) ? stA : 0;

// check non-supported values
if(uplo != rocblas_fill_upper && uplo != rocblas_fill_lower)
Expand All @@ -344,8 +352,9 @@ void testing_potf2_potrf(Arguments& argus)
// determine sizes
size_t size_A = size_t(lda) * n;
double max_error = 0, gpu_time_used = 0, cpu_time_used = 0;
size_t hashA = 0, hashARes = 0;

size_t size_ARes = (argus.unit_check || argus.norm_check) ? size_A : 0;
size_t size_ARes = (argus.unit_check || argus.norm_check || argus.hash_check) ? size_A : 0;

// check invalid sizes
bool invalid_size = (n < 0 || lda < n || bc < 0);
Expand Down Expand Up @@ -414,10 +423,10 @@ void testing_potf2_potrf(Arguments& argus)
}

// check computations
if(argus.unit_check || argus.norm_check)
if(argus.unit_check || argus.norm_check || argus.hash_check)
potf2_potrf_getError<STRIDED, POTRF, T>(handle, uplo, n, dA, lda, stA, dInfo, bc, hA,
hARes, hInfo, hInfoRes, &max_error,
argus.singular);
argus.singular, hashA, hashARes);

// collect performance data
if(argus.timing)
Expand Down Expand Up @@ -452,10 +461,10 @@ void testing_potf2_potrf(Arguments& argus)
}

// check computations
if(argus.unit_check || argus.norm_check)
if(argus.unit_check || argus.norm_check || argus.hash_check)
potf2_potrf_getError<STRIDED, POTRF, T>(handle, uplo, n, dA, lda, stA, dInfo, bc, hA,
hARes, hInfo, hInfoRes, &max_error,
argus.singular);
argus.singular, hashA, hashARes);

// collect performance data
if(argus.timing)
Expand Down Expand Up @@ -502,6 +511,12 @@ void testing_potf2_potrf(Arguments& argus)
rocsolver_bench_output(cpu_time_used, gpu_time_used);
}
rocsolver_bench_endl();
if(argus.hash_check)
{
rocsolver_bench_output("hash(A)", "hash(ARes)");
rocsolver_bench_output(ROCSOLVER_FORMAT_HASH(hashA), ROCSOLVER_FORMAT_HASH(hashARes));
rocsolver_bench_endl();
}
}
else
{
Expand Down
48 changes: 34 additions & 14 deletions clients/common/lapack/testing_syevx_heevx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,10 @@ void syevx_heevx_getError(const rocblas_handle handle,
Ih& hIfailRes,
Ih& hinfo,
Ih& hinfoRes,
double* max_err)
double* max_err,
size_t& hashA,
size_t& hashW,
size_t& hashZ)
{
using HMat = HostMatrix<T, rocblas_int>;
using BDesc = typename HMat::BlockDescriptor;
Expand All @@ -331,6 +334,9 @@ void syevx_heevx_getError(const rocblas_handle handle,
// input data initialization
syevx_heevx_initData<true, true, T>(handle, evect, n, dA, lda, bc, hA, A);

// hash inputs
hashA = deterministic_hash(hA, bc);

// execute computations
// GPU lapack
CHECK_ROCBLAS_ERROR(rocsolver_syevx_heevx(
Expand All @@ -346,6 +352,11 @@ void syevx_heevx_getError(const rocblas_handle handle,
CHECK_HIP_ERROR(hIfailRes.transfer_from(dIfail));
}

// hash outputs
hashW = deterministic_hash(hWRes, bc);
if(evect == rocblas_evect_original)
hashZ = deterministic_hash(hZRes, bc);

// CPU lapack
// abstol = 0 ensures max accuracy in rocsolver; for lapack we should use 2*safemin
S atol = (abstol == 0) ? 2 * get_safemin<S>() : abstol;
Expand Down Expand Up @@ -619,11 +630,13 @@ void testing_syevx_heevx(Arguments& argus)
size_t size_W = n;
size_t size_Z = size_t(ldz) * n;
size_t size_ifail = n;
size_t size_WRes = (argus.unit_check || argus.norm_check) ? size_W : 0;
size_t size_ZRes = (argus.unit_check || argus.norm_check) ? size_Z : 0;
size_t size_ifailRes = (argus.unit_check || argus.norm_check) ? size_ifail : 0;
size_t size_WRes = (argus.unit_check || argus.norm_check || argus.hash_check) ? size_W : 0;
size_t size_ZRes = (argus.unit_check || argus.norm_check || argus.hash_check) ? size_Z : 0;
size_t size_ifailRes
= (argus.unit_check || argus.norm_check || argus.hash_check) ? size_ifail : 0;

double max_error = 0, gpu_time_used = 0, cpu_time_used = 0;
size_t hashA = 0, hashW = 0, hashZ = 0;

// check invalid sizes
bool invalid_size = (n < 0 || lda < n || (evect != rocblas_evect_none && ldz < n) || bc < 0
Expand Down Expand Up @@ -729,12 +742,12 @@ void testing_syevx_heevx(Arguments& argus)
}

// check computations
if(argus.unit_check || argus.norm_check)
if(argus.unit_check || argus.norm_check || argus.hash_check)
{
syevx_heevx_getError<STRIDED, T>(handle, evect, erange, uplo, n, dA, lda, stA, vl, vu,
il, iu, abstol, dNev, dW, stW, dZ, ldz, stZ, dIfail,
stF, dinfo, bc, hA, hNev, hNevRes, hW, hWres, hZ,
hZRes, hIfail, hIfailRes, hinfo, hinfoRes, &max_error);
syevx_heevx_getError<STRIDED, T>(
handle, evect, erange, uplo, n, dA, lda, stA, vl, vu, il, iu, abstol, dNev, dW, stW,
dZ, ldz, stZ, dIfail, stF, dinfo, bc, hA, hNev, hNevRes, hW, hWres, hZ, hZRes,
hIfail, hIfailRes, hinfo, hinfoRes, &max_error, hashA, hashW, hashZ);
}

// collect performance data
Expand Down Expand Up @@ -776,12 +789,12 @@ void testing_syevx_heevx(Arguments& argus)
}

// check computations
if(argus.unit_check || argus.norm_check)
if(argus.unit_check || argus.norm_check || argus.hash_check)
{
syevx_heevx_getError<STRIDED, T>(handle, evect, erange, uplo, n, dA, lda, stA, vl, vu,
il, iu, abstol, dNev, dW, stW, dZ, ldz, stZ, dIfail,
stF, dinfo, bc, hA, hNev, hNevRes, hW, hWres, hZ,
hZRes, hIfail, hIfailRes, hinfo, hinfoRes, &max_error);
syevx_heevx_getError<STRIDED, T>(
handle, evect, erange, uplo, n, dA, lda, stA, vl, vu, il, iu, abstol, dNev, dW, stW,
dZ, ldz, stZ, dIfail, stF, dinfo, bc, hA, hNev, hNevRes, hW, hWres, hZ, hZRes,
hIfail, hIfailRes, hinfo, hinfoRes, &max_error, hashA, hashW, hashZ);
}

// collect performance data
Expand Down Expand Up @@ -839,6 +852,13 @@ void testing_syevx_heevx(Arguments& argus)
rocsolver_bench_output(cpu_time_used, gpu_time_used);
}
rocsolver_bench_endl();
if(argus.hash_check)
{
rocsolver_bench_output("hash(A)", "hash(W)", "hash(Z)");
rocsolver_bench_output(ROCSOLVER_FORMAT_HASH(hashA), ROCSOLVER_FORMAT_HASH(hashW),
ROCSOLVER_FORMAT_HASH(hashZ));
rocsolver_bench_endl();
}
}
else
{
Expand Down
Loading

0 comments on commit 38c7a3a

Please sign in to comment.