diff --git a/test/rocprim/test_device_radix_sort.hpp b/test/rocprim/test_device_radix_sort.hpp index 834575baa..6dbd9b143 100644 --- a/test/rocprim/test_device_radix_sort.hpp +++ b/test/rocprim/test_device_radix_sort.hpp @@ -716,6 +716,7 @@ inline void sort_keys_over_4g() key_type* d_keys_input_output{}; size_t key_type_storage_bytes = size * sizeof(key_type); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_keys_input_output, key_type_storage_bytes)); HIP_CHECK(hipMemcpy(d_keys_input_output, keys_input.data(), @@ -738,12 +739,12 @@ inline void sort_keys_over_4g() hipDeviceProp_t prop; HIP_CHECK(hipGetDeviceProperties(&prop, device_id)); - size_t total_storage_bytes = key_type_storage_bytes + temporary_storage_bytes; - if (total_storage_bytes > prop.totalGlobalMem) { + size_t total_storage_bytes = key_type_storage_bytes + temporary_storage_bytes; + if (total_storage_bytes > (static_cast(prop.totalGlobalMem * 0.90))) { HIP_CHECK(hipFree(d_keys_input_output)); GTEST_SKIP() << "Test case device memory requirement (" << total_storage_bytes << " bytes) exceeds available memory on current device (" << prop.totalGlobalMem << " bytes). Skipping test"; - } + } void* d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes));