From 5f9f13e09bd8dacb08f7feed12decbfb906c91f9 Mon Sep 17 00:00:00 2001 From: Wayne Franz Date: Tue, 21 Nov 2023 22:29:54 -0500 Subject: [PATCH] Add unit testing to verify that algorithms work with hipGraphs (#478) * Basic hipGraph tests * Add basic tests for graph creation, instantiation, and execution using: * stream capture * manual construction * hipGraph test for device_reduce algorithms * Added new unit tests for device_reduce, device_reduce_by_key algorithms to verify basic support for hipGraphs (no synchronous API functions are called within the algorithms). * Fixed up CMakeLists compile issue for tests in the test/hipgraph folder * Updated code documentation * Add hipGraph unit tests for device level algorithms * Added unit tests that run the following algorithms inside of a graph (in isolation): - device_adjacent_difference - device_binary_search - device_histogram - device_merge - device_merge_sort - device_partition - device_radix_sort - device_scan - device_segmented_reduce - device_segmented_scan - device_select - device_transform * Updated existing tests for: - device_reduce - device_reduce_by_key * Moved graph test helper functions to a separate file * Add hipGraph unit tests * Added remaining device level hipGraph unit tests * Note: currently, there are two device level algorithms that do not work with hipGraphs because they contain synchronization barriers. No hipGraph unit tests have been added for these algorithms: * device_run_length_encode * device_segmented_radix_sort * Added a functional integration test for hipGraphs, which runs several algorithms back-to-back within a graph. * Refactored test helper code to remove unnecessary parameter * Set hipgraph test pointers to nullptr * Set key_type device pointers to nullptr when they are declared, for safety. 
* Several minor fixes for hipGraph tests * Fixed up spelling error in comments * Moved call to hipGetLastError to a more appropriate position * Removed old commented test code * Minor fixes for hipgraph unit tests * Moved several synchronization barriers so they are now outside of graph capture blocks in the test_device_partition source * Changed several loop counters to unsigned type * Updated hipgraph cmake files - removed test/hipgraph directory's CMakeLists.txt * Additional test and bugfix for hipgraph tests * Removed synchronization barrier in test_device_scan * Added basic test to exercise atomic function within a hipgraph * Rebased and resolved merge conflicts --- test/CMakeLists.txt | 4 + test/hipgraph/test_hipgraph_algs.cpp | 229 ++++++++++ test/hipgraph/test_hipgraph_basic.cpp | 240 +++++++++++ .../test_device_adjacent_difference.cpp | 99 ++++- test/rocprim/test_device_binary_search.cpp | 83 +++- test/rocprim/test_device_histogram.cpp | 238 +++++++++-- test/rocprim/test_device_merge.cpp | 108 ++++- test/rocprim/test_device_merge_sort.cpp | 62 ++- test/rocprim/test_device_partition.cpp | 247 ++++++++++- test/rocprim/test_device_radix_sort.cpp.in | 7 +- test/rocprim/test_device_radix_sort.hpp | 143 ++++++- test/rocprim/test_device_reduce.cpp | 396 ++++++++++++----- test/rocprim/test_device_reduce_by_key.cpp | 101 ++++- test/rocprim/test_device_scan.cpp | 402 +++++++++++++++--- test/rocprim/test_device_segmented_reduce.cpp | 33 +- test/rocprim/test_device_segmented_scan.cpp | 111 ++++- test/rocprim/test_device_select.cpp | 244 ++++++++++- test/rocprim/test_device_transform.cpp | 90 +++- test/rocprim/test_utils.hpp | 1 + test/rocprim/test_utils_hipgraphs.hpp | 88 ++++ 20 files changed, 2629 insertions(+), 297 deletions(-) create mode 100644 test/hipgraph/test_hipgraph_algs.cpp create mode 100644 test/hipgraph/test_hipgraph_basic.cpp create mode 100644 test/rocprim/test_utils_hipgraphs.hpp diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 
0c7017387..b7cc9b83d 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -126,6 +126,10 @@ add_hip_test("hip.async_copy" hip/test_hip_async_copy.cpp) # rocPRIM test add_subdirectory(rocprim) +# hipGraph tests +add_hip_test("hipgraph.basic" hipgraph/test_hipgraph_basic.cpp) +add_hip_test("hipgraph.algs" hipgraph/test_hipgraph_algs.cpp) + rocm_install( FILES "${INSTALL_TEST_FILE}" DESTINATION "${CMAKE_INSTALL_BINDIR}/${PROJECT_NAME}" diff --git a/test/hipgraph/test_hipgraph_algs.cpp b/test/hipgraph/test_hipgraph_algs.cpp new file mode 100644 index 000000000..a0811fd89 --- /dev/null +++ b/test/hipgraph/test_hipgraph_algs.cpp @@ -0,0 +1,229 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. 
+ +// required test headers +#include "../common_test_header.hpp" + +// required rocprim headers +#include +#include "../rocprim/test_seed.hpp" +#include "../rocprim/test_utils.hpp" + +// required STL headers +#include + +template +void generate_needles(const std::vector& input, std::vector& output, const size_t search_needle_size, const std::pair& bounds, const seed_type& seed_value) +{ + // Pick 50% of the needles from the input vector + std::vector indices = test_utils::get_random_data(search_needle_size / 2, 0, input.size() - 1, seed_value); + + // Do selection on the in-bounds indices and write the results to the output vector + std::transform(indices.begin(), indices.end(), output.begin(), [&input](const KeyType& index) + { + return input[index]; + }); + + // Generate the other 50% from outside the input vector + const KeyType max_val = std::get<1>(bounds); + std::vector out_of_bounds_vals = test_utils::get_random_data(search_needle_size - search_needle_size / 2, max_val, max_val * 2, seed_value); + + // Append the out-of-bounds values + for (size_t i = 0; i < out_of_bounds_vals.size(); i++) + output[indices.size() + i] = out_of_bounds_vals[i]; + + // Mix up the in-bounds and out-of-bounds values to make the test a bit more robust + std::random_shuffle(output.begin(), output.end()); +} + +template +void computeExpectedSortAndSearchResult(std::vector& sort_input, const std::vector& search_needles, std::vector& expected_search_output, BinaryFunction compare_op) +{ + // Sort + std::stable_sort(sort_input.begin(), sort_input.end(), compare_op); + + // Search + for (size_t i = 0; i < search_needles.size(); i++) + expected_search_output[i] = std::binary_search(sort_input.begin(), sort_input.end(), search_needles[i], compare_op); +} + +// This test creates a graph that performs a device-wide merge_sort followed by a device-wide binary_search. +// After the graph is created, it is launched multiple times, each time using different data. 
+TEST(TestHipGraphAlgs, SortAndSearch) +{ + // Test case params + using key_type = int; + using compare_fcn_type = typename ::rocprim::less; + compare_fcn_type compare_op; + const size_t sort_data_size = 4096; + const size_t search_needle_size = 100; + const size_t num_trials = 5; + std::pair bounds = std::make_pair(-10000, 10000); // generated data will fall in this range + const bool debug_synchronous = false; + + // Set the device + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + // Generate data on the host + const seed_type seed_value = seeds[random_seeds_count - 1]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + SCOPED_TRACE(testing::Message() << "with sort_data_size = " << sort_data_size); + SCOPED_TRACE(testing::Message() << "with search_needle_size = " << search_needle_size); + + // Allocate device buffers and copy data into them + key_type* d_sort_input = nullptr; + key_type* d_sort_output = nullptr; // also used as search_input + key_type* d_search_output = nullptr; + key_type* d_search_needles = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_sort_input, sort_data_size * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_sort_output, sort_data_size * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_search_output, search_needle_size * sizeof(key_type))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_search_needles, search_needle_size * sizeof(key_type))); + + // Default stream does not support hipGraph stream capture, so create a non-blocking one + hipStream_t stream = 0; + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + + // Begin graph capture + hipGraph_t graph = test_utils::createGraphHelper(stream); + + // Get temporary storage size required for merge_sort. 
+ // Note: doing this inside a graph doesn't gain us any benefit, + // since these calls run entirely on the host - however, it is + // important to validate that they work inside a graph capture block. + size_t sort_temp_storage_bytes = 0; + HIP_CHECK(rocprim::merge_sort(nullptr, + sort_temp_storage_bytes, + d_sort_input, + d_sort_output, + sort_data_size, + compare_op, + stream, + debug_synchronous + )); + + // Get size of temporary storage required for binary_search + size_t search_temp_storage_bytes = 0; + HIP_CHECK(rocprim::binary_search(nullptr, + search_temp_storage_bytes, + d_sort_output, + d_search_needles, + d_search_output, + sort_data_size, + search_needle_size, + compare_op, + stream, + debug_synchronous + )); + + // End graph capture (since we can't malloc the temp storage inside the graph) + // and execute the graph (to get the temp storage size) + hipGraphExec_t graph_instance; + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + + // Allocate the temp storage + // Note: a single store will be used for both the sort and search algorithms + size_t temp_storage_bytes = std::max(sort_temp_storage_bytes, search_temp_storage_bytes); + // temp_storage_size_bytes must be > 0 + ASSERT_GT(temp_storage_bytes, 0); + + void* d_temp_storage = nullptr; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + // Re-start graph capture + test_utils::resetGraphHelper(graph, graph_instance, stream); + + // Launch merge_sort + HIP_CHECK( + rocprim::merge_sort(d_temp_storage, + sort_temp_storage_bytes, + d_sort_input, + d_sort_output, + sort_data_size, + compare_op, + stream, + false + ) + ); + + // Launch binary_search + HIP_CHECK( + rocprim::binary_search(d_temp_storage, + search_temp_storage_bytes, + d_sort_output, + d_search_needles, + d_search_output, + sort_data_size, + search_needle_size, + compare_op, + stream, + false + ); + ); + + // End graph capture, but do 
not execute the graph yet. + graph_instance = test_utils::endCaptureGraphHelper(graph, stream); + + std::vector sort_input; + std::vector search_needles(search_needle_size); + std::vector expected_search_output(search_needle_size); + std::vector device_output(search_needle_size); + + // We'll launch the graph multiple times with different data. + for (size_t i = 0; i < num_trials; i++) + { + // Generate the test data + sort_input = test_utils::get_random_data(sort_data_size, std::get<0>(bounds), std::get<1>(bounds), seed_value); + generate_needles(sort_input, search_needles, search_needle_size, bounds, seed_value); + + // Compute the expected result on the host + computeExpectedSortAndSearchResult(sort_input, search_needles, expected_search_output, compare_op); + + // Copy input data to the device + HIP_CHECK(hipMemcpy(d_sort_input, sort_input.data(), sort_data_size * sizeof(key_type), hipMemcpyHostToDevice)); + HIP_CHECK(hipMemcpy(d_search_needles, search_needles.data(), search_needle_size * sizeof(key_type), hipMemcpyHostToDevice)); + + // Launch the graph + test_utils::launchGraphHelper(graph_instance, stream, true); + + // Copy output back to host + HIP_CHECK(hipMemcpy(device_output.data(), d_search_output, search_needle_size * sizeof(key_type), hipMemcpyDeviceToHost)); + + // Validate the results + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(device_output, expected_search_output)); + } + + // Clean up + HIP_CHECK(hipFree(d_sort_input)); + HIP_CHECK(hipFree(d_sort_output)); + HIP_CHECK(hipFree(d_search_output)); + HIP_CHECK(hipFree(d_search_needles)); + HIP_CHECK(hipFree(d_temp_storage)); + + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); +} + + diff --git a/test/hipgraph/test_hipgraph_basic.cpp b/test/hipgraph/test_hipgraph_basic.cpp new file mode 100644 index 000000000..12d2c8739 --- /dev/null +++ b/test/hipgraph/test_hipgraph_basic.cpp @@ -0,0 +1,240 @@ +// MIT License +// +// Copyright (c) 2017-2023 Advanced 
Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. + +// required test headers +#include "../common_test_header.hpp" + +// required rocprim headers +#include +#include "../rocprim/test_seed.hpp" +#include "../rocprim/test_utils.hpp" + +// Basic test functions that can be used to check if HIP API functions +// work inside graphs. +// To test an API call, you can: +// - call it inside the graph stream capture zone of testStreamCapture() +// - add the corresponding type of graph node in the testManualConstruction() function +// +// HIP API functions that do not currently work: +// - hipMallocAsync + +// Simple test kernel that increments a value using a single thread. 
+__global__ __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void increment(int* data) +{ + const int gid = blockIdx.x * blockDim.x + threadIdx.x; + if (!gid) + data[gid]++; +} + +// Another simple kernel that can be used to test atomics inside a graph. +__global__ __launch_bounds__(ROCPRIM_DEFAULT_MAX_BLOCK_SIZE) void atomicIncrement(int* data) +{ + atomicAdd(data, 1); +} + +void testStreamCapture() +{ + // The default stream does not support HipGraph stream capture, so create our own. + hipStream_t stream; + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + + // Allocate a counter variable on the device and set it to 0. + // We will use this to record the number of times the graph is launched. + int* d_data = nullptr; + int h_data = 0; + + // Create a new graph + hipGraph_t graph; + HIP_CHECK(hipGraphCreate(&graph, 0)); + + // Note: currently, calls to hipMallocAsync do not work inside the stream capture section + HIP_CHECK(hipMallocAsync(&d_data, sizeof(int), stream)); + + // ** Begin stream capture ** + HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + + // Transfer the host value + HIP_CHECK(hipMemcpyAsync(d_data, &h_data, sizeof(int), hipMemcpyHostToDevice, stream)); + + // Launch kernel + hipLaunchKernelGGL(increment, dim3(1), dim3(1), 0, stream, d_data); + + // Transfer result back to host + HIP_CHECK(hipMemcpyAsync(&h_data, d_data, sizeof(int), hipMemcpyDeviceToHost, stream)); + + // ** End stream capture ** + HIP_CHECK(hipStreamEndCapture(stream, &graph)); + + // Instantiate the graph + hipGraphExec_t instance; + HIP_CHECK(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0)); + + // Launch it + const int num_launches = 3; + for (int i = 0; i < num_launches; i++) + { + HIP_CHECK(hipGraphLaunch(instance, stream)); + } + HIP_CHECK(hipStreamSynchronize(stream)); + + // Counter value should match the number of graph launches + ASSERT_EQ(h_data, num_launches); + + // Clean up + HIP_CHECK(hipGraphDestroy(graph)); 
+ HIP_CHECK(hipGraphExecDestroy(instance)); + HIP_CHECK(hipFree(d_data)); + HIP_CHECK(hipStreamDestroy(stream)); +} + +void testManualConstruction() +{ + // The default stream does not support HipGraph stream capture, so create our own. + hipStream_t stream; + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + + // Allocate a counter variable on the device and set it to 0. + // We will use this to record the number of times the graph is launched. + int* d_data = nullptr; + int h_data = 0; + HIP_CHECK(hipMallocAsync(&d_data, sizeof(int), stream)); + + // Create a new graph + hipGraph_t graph; + HIP_CHECK(hipGraphCreate(&graph, 0)); + + // Transfer the counter value from host to device using a graph memcpy node + hipGraphNode_t hostToDevMemcpyNode; + HIP_CHECK(hipGraphAddMemcpyNode1D(&hostToDevMemcpyNode, graph, nullptr, 0, d_data, &h_data, sizeof(int), hipMemcpyHostToDevice)); + + // Launch the kernel + hipGraphNode_t kernelNode; + void* kernelArgs[1] = {(void*) &d_data}; + hipKernelNodeParams kernelNodeParams{}; + kernelNodeParams.func = (void*) increment; + kernelNodeParams.gridDim = dim3(1); + kernelNodeParams.blockDim = dim3(1); + kernelNodeParams.sharedMemBytes = 0; + kernelNodeParams.kernelParams = (void**) (kernelArgs); + kernelNodeParams.extra = nullptr; + + // Add the kernel node to the graph, listing the memcpyNode as a dependency + HIP_CHECK(hipGraphAddKernelNode(&kernelNode, graph, &hostToDevMemcpyNode, 1, &kernelNodeParams)); + + // Transfer result back to the device + hipGraphNode_t devToHostMemcpyNode; + HIP_CHECK(hipGraphAddMemcpyNode1D(&devToHostMemcpyNode, graph, &kernelNode, 1, &h_data, d_data, sizeof(int), hipMemcpyDeviceToHost)); + + // Instantiate the graph + hipGraphExec_t instance; + HIP_CHECK(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0)); + + // Launch it + const int num_launches = 3; + for (int i = 0; i < num_launches; i++) + { + HIP_CHECK(hipGraphLaunch(instance, stream)); + } + 
HIP_CHECK(hipStreamSynchronize(stream)); + + // The counter value should match the number of times we launched the graph + ASSERT_EQ(h_data, num_launches); + + // Clean up + HIP_CHECK(hipGraphDestroy(graph)); + HIP_CHECK(hipGraphExecDestroy(instance)); + HIP_CHECK(hipFree(d_data)); + HIP_CHECK(hipStreamDestroy(stream)); +} + +void testStreamCaptureWithAtomics() +{ + // The default stream does not support HipGraph stream capture, so create our own. + hipStream_t stream; + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + + // Allocate a counter variable on the device. + // We will have each thread atomically increment it. + int* d_data = nullptr; + int h_data = 0; + const int num_blocks = 2; + const int num_threads = 33; + + // Create a new graph + hipGraph_t graph; + HIP_CHECK(hipGraphCreate(&graph, 0)); + + // Note: currently, calls to hipMallocAsync do not work inside the stream capture section + HIP_CHECK(hipMallocAsync(&d_data, sizeof(int), stream)); + + // ** Begin stream capture ** + HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal)); + + // Transfer the host value + HIP_CHECK(hipMemcpyAsync(d_data, &h_data, sizeof(int), hipMemcpyHostToDevice, stream)); + + // Launch kernel + hipLaunchKernelGGL(atomicIncrement, dim3(num_blocks), dim3(num_threads), 0, stream, d_data); + + // Transfer result back to host + HIP_CHECK(hipMemcpyAsync(&h_data, d_data, sizeof(int), hipMemcpyDeviceToHost, stream)); + + // ** End stream capture ** + HIP_CHECK(hipStreamEndCapture(stream, &graph)); + + // Instantiate the graph + hipGraphExec_t instance; + HIP_CHECK(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0)); + + // Launch it + const int num_launches = 3; + for (int i = 0; i < num_launches; i++) + { + HIP_CHECK(hipGraphLaunch(instance, stream)); + } + HIP_CHECK(hipStreamSynchronize(stream)); + + // Counter value should match the number of graph launches multiplied by + // the number of threads that were launched. 
+ ASSERT_EQ(h_data, num_launches * num_blocks * num_threads); + + // Clean up + HIP_CHECK(hipGraphDestroy(graph)); + HIP_CHECK(hipGraphExecDestroy(instance)); + HIP_CHECK(hipFree(d_data)); + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST(TestHipGraphBasic, CaptureFromStream) +{ + testStreamCapture(); +} + +TEST(TestHipGraphBasic, ManualConstruction) +{ + testManualConstruction(); +} + +TEST(TestHipGraphBasic, StreamCaptureAtomics) +{ + testStreamCaptureWithAtomics(); +} diff --git a/test/rocprim/test_device_adjacent_difference.cpp b/test/rocprim/test_device_adjacent_difference.cpp index 784b7b950..82b2f4ead 100644 --- a/test/rocprim/test_device_adjacent_difference.cpp +++ b/test/rocprim/test_device_adjacent_difference.cpp @@ -129,8 +129,8 @@ template - + class Config = rocprim::default_config, + bool UseGraphs = false> struct DeviceAdjacentDifferenceParams { using input_type = InputType; @@ -139,6 +139,7 @@ struct DeviceAdjacentDifferenceParams static constexpr bool in_place = InPlace; static constexpr bool use_identity_iterator = UseIdentityIterator; using config = Config; + static constexpr bool use_graphs = UseGraphs; }; template @@ -152,6 +153,7 @@ class RocprimDeviceAdjacentDifferenceTests : public ::testing::Test static constexpr bool use_identity_iterator = Params::use_identity_iterator; static constexpr bool debug_synchronous = false; using config = typename Params::config; + static constexpr bool use_graphs = Params::use_graphs; }; using custom_double2 = test_utils::custom_test_type; @@ -182,7 +184,8 @@ using RocprimDeviceAdjacentDifferenceTestsParams = ::testing::Types< // Tests for different size_limits DeviceAdjacentDifferenceParams>, DeviceAdjacentDifferenceParams>, - DeviceAdjacentDifferenceParams>>; + DeviceAdjacentDifferenceParams>, + DeviceAdjacentDifferenceParams>; TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceTests, RocprimDeviceAdjacentDifferenceTestsParams); @@ -210,8 +213,13 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, 
AdjacentDifference) for(auto size : test_utils::get_sizes(seed_value)) { - static constexpr hipStream_t stream = 0; // default - + hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + SCOPED_TRACE(testing::Message() << "with size = " << size); // Generate data @@ -241,6 +249,11 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) const auto output_it = test_utils::wrap_in_identity_iterator(d_output); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // Allocate temporary storage std::size_t temp_storage_size; void* d_temp_storage = nullptr; @@ -253,12 +266,18 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) size, rocprim::minus<> {}, stream, - debug_synchronous)); + TestFixture::debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temp_storage_size, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(dispatch_adjacent_difference(left_tag, in_place_tag, @@ -269,9 +288,12 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) size, rocprim::minus<> {}, stream, - debug_synchronous)); + TestFixture::debug_synchronous)); HIP_CHECK(hipGetLastError()); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // Copy output to host HIP_CHECK( hipMemcpy(output.data(), @@ -291,16 +313,23 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceTests, AdjacentDifference) hipFree(d_output); } hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + 
test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } } // Params for tests -template +template struct DeviceAdjacentDifferenceLargeParams { - static constexpr bool left = Left; - static constexpr bool in_place = InPlace; + static constexpr bool left = Left; + static constexpr bool in_place = InPlace; + static constexpr bool use_graphs = UseGraphs; }; template @@ -310,6 +339,7 @@ class RocprimDeviceAdjacentDifferenceLargeTests : public ::testing::Test static constexpr bool left = Params::left; static constexpr bool in_place = Params::in_place; static constexpr bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; }; template @@ -422,7 +452,8 @@ class check_output_iterator using RocprimDeviceAdjacentDifferenceLargeTestsParams = ::testing::Types, - DeviceAdjacentDifferenceLargeParams>; + DeviceAdjacentDifferenceLargeParams, + DeviceAdjacentDifferenceLargeParams>; TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceLargeTests, RocprimDeviceAdjacentDifferenceLargeTestsParams); @@ -430,13 +461,26 @@ TYPED_TEST_SUITE(RocprimDeviceAdjacentDifferenceLargeTests, TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) { const int device_id = test_common_utils::obtain_device_from_ctest(); + + if (TestFixture::use_graphs) + { + // Skip this test on gfx1030 on Windows, since check_output_iterator does not appear to work there. 
+ hipDeviceProp_t props; + HIP_CHECK(hipGetDeviceProperties(&props, device_id)); + std::string deviceName = std::string(props.gcnArchName); + if(deviceName.rfind("gfx1030", 0) == 0) + { + // This is a gfx1030 device, so skip this test + GTEST_SKIP() << "Temporarily skipping test on Windows for on gfx1030"; + } + } + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); using T = size_t; static constexpr bool is_left = TestFixture::left; static constexpr bool is_in_place = TestFixture::in_place; - const bool debug_synchronous = TestFixture::debug_synchronous; static constexpr unsigned int sampling_rate = 10000; using OutputIterator = check_output_iterator; using flag_type = OutputIterator::flag_type; @@ -444,7 +488,12 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) SCOPED_TRACE(testing::Message() << "is_left = " << is_left << ", is_in_place = " << is_in_place); - static constexpr hipStream_t stream = 0; // default + hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(std::size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -476,6 +525,11 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) static constexpr auto left_tag = rocprim::detail::bool_constant{}; static constexpr auto in_place_tag = rocprim::detail::bool_constant{}; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // Allocate temporary storage std::size_t temp_storage_size; void* d_temp_storage = nullptr; @@ -488,12 +542,18 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) size, op, stream, - debug_synchronous)); + TestFixture::debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temp_storage_size, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(dispatch_adjacent_difference(left_tag, in_place_tag, @@ -504,7 +564,10 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) size, op, stream, - debug_synchronous)); + TestFixture::debug_synchronous)); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); // Copy output to host flag_type incorrect_flag; @@ -521,6 +584,12 @@ TYPED_TEST(RocprimDeviceAdjacentDifferenceLargeTests, LargeIndices) hipFree(d_temp_storage); hipFree(d_incorrect_flag); hipFree(d_counter); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } diff --git a/test/rocprim/test_device_binary_search.cpp b/test/rocprim/test_device_binary_search.cpp index e7370ab61..332be50f8 100644 --- a/test/rocprim/test_device_binary_search.cpp +++ b/test/rocprim/test_device_binary_search.cpp @@ -33,7 +33,8 @@ template, - class Config = rocprim::default_config> + class Config = rocprim::default_config, + bool UseGraphs = false> struct params { using haystack_type = Haystack; @@ -41,6 +42,7 @@ struct params using output_type = Output; using compare_op_type = CompareFunction; using config = Config; + static constexpr bool use_graphs = UseGraphs; }; template @@ -69,7 +71,8 @@ typedef ::testing::Types< rocprim::less, use_custom_config>, params, - params>> + params>, + params, rocprim::default_config, true>> Params; TYPED_TEST_SUITE(RocprimDeviceBinarySearch, Params); @@ -90,6 +93,11 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) typename TestFixture::params::config>; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // 
Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -149,6 +157,11 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) haystack.begin(); } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + void * d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::lower_bound(d_temporary_storage, @@ -162,10 +175,16 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::lower_bound(d_temporary_storage, temporary_storage_bytes, d_haystack, @@ -177,6 +196,9 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector output(needles_size); HIP_CHECK( hipMemcpy( @@ -191,11 +213,15 @@ TYPED_TEST(RocprimDeviceBinarySearch, LowerBound) HIP_CHECK(hipFree(d_needles)); HIP_CHECK(hipFree(d_output)); + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); } } - + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) @@ -214,6 +240,11 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) typename TestFixture::params::config>; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does 
not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -272,6 +303,11 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) haystack.begin(); } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + void * d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::upper_bound(d_temporary_storage, @@ -284,11 +320,16 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) compare_op, stream, debug_synchronous)); - + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::upper_bound(d_temporary_storage, temporary_storage_bytes, d_haystack, @@ -300,6 +341,9 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector output(needles_size); HIP_CHECK( hipMemcpy( @@ -314,11 +358,15 @@ TYPED_TEST(RocprimDeviceBinarySearch, UpperBound) HIP_CHECK(hipFree(d_needles)); HIP_CHECK(hipFree(d_output)); + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); } } - + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) @@ -337,6 +385,11 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) typename TestFixture::params::config>; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does 
not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -394,6 +447,11 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) expected[i] = std::binary_search(haystack.begin(), haystack.end(), needles[i], compare_op); } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + void * d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::binary_search(d_temporary_storage, @@ -407,10 +465,16 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::binary_search(d_temporary_storage, temporary_storage_bytes, d_haystack, @@ -422,6 +486,9 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector output(needles_size); HIP_CHECK( hipMemcpy( @@ -436,7 +503,13 @@ TYPED_TEST(RocprimDeviceBinarySearch, BinarySearch) HIP_CHECK(hipFree(d_needles)); HIP_CHECK(hipFree(d_output)); + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(output, expected)); } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } diff --git a/test/rocprim/test_device_histogram.cpp b/test/rocprim/test_device_histogram.cpp index 20c9c92c1..1c9f38131 100644 --- a/test/rocprim/test_device_histogram.cpp +++ 
b/test/rocprim/test_device_histogram.cpp @@ -106,7 +106,8 @@ template + class Config = rocprim::default_config, + bool UseGraphs = false> struct params1 { using sample_type = SampleType; @@ -116,6 +117,7 @@ struct params1 using level_type = LevelType; using counter_type = CounterType; using config = Config; + static constexpr bool use_graphs = UseGraphs; }; template @@ -135,29 +137,63 @@ typedef ::testing::Types, params1, params1, - params1> + params1, + params1> Params1; TYPED_TEST_SUITE(RocprimDeviceHistogramEven, Params1); -TEST(RocprimDeviceHistogramEven, IncorrectInput) +template +void testHistogramEvenIncorrectInput() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); + hipStream_t stream = 0; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + graph = test_utils::createGraphHelper(stream); + } + size_t temporary_storage_bytes = 0; int * d_input = nullptr; int * d_histogram = nullptr; + + hipError_t result = rocprim::histogram_even( + nullptr, temporary_storage_bytes, + d_input, 123, + d_histogram, + 1, 1, 2, stream + ); + + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_EQ( - rocprim::histogram_even( - nullptr, temporary_storage_bytes, - d_input, 123, - d_histogram, - 1, 1, 2 - ), - hipErrorInvalidValue + result, + hipErrorInvalidValue ); + + if (UseGraphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } +} + +TEST(RocprimDeviceHistogramEven, IncorrectInput) +{ + testHistogramEvenIncorrectInput(); +} + +TEST(RocprimDeviceHistogramEven, IncorrectInputWithGraphs) +{ + testHistogramEvenIncorrectInput(); } TYPED_TEST(RocprimDeviceHistogramEven, Even) @@ -174,6 
+210,11 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) constexpr level_type upper_level = TestFixture::params::upper_level; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -231,6 +272,11 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) using config = typename TestFixture::params::config; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -257,11 +303,17 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) ); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(rows == 1) { HIP_CHECK( @@ -287,6 +339,9 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) ); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector histogram(bins); HIP_CHECK( hipMemcpy( @@ -304,9 +359,14 @@ TYPED_TEST(RocprimDeviceHistogramEven, Even) { ASSERT_EQ(histogram[i], histogram_expected[i]); } - } + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } template + class Config = rocprim::default_config, + bool UseGraphs = false> struct params2 { using sample_type = SampleType; @@ -327,6 +388,7 @@ struct params2 using level_type = LevelType; using counter_type = CounterType; using config = Config; + static constexpr 
bool use_graphs = UseGraphs; }; template @@ -345,32 +407,67 @@ typedef ::testing::Types< params2, params2, - params2> + params2, + params2> Params2; TYPED_TEST_SUITE(RocprimDeviceHistogramRange, Params2); -TEST(RocprimDeviceHistogramRange, IncorrectInput) +template +void testHistogramRangeIncorrectInput() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); + hipStream_t stream = 0; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + graph = test_utils::createGraphHelper(stream); + } + size_t temporary_storage_bytes = 0; int * d_input = nullptr; int * d_histogram = nullptr; int * d_levels = nullptr; + + hipError_t result = rocprim::histogram_range( + nullptr, temporary_storage_bytes, + d_input, 123, + d_histogram, + 1, d_levels, stream + ); + + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_EQ( - rocprim::histogram_range( - nullptr, temporary_storage_bytes, - d_input, 123, - d_histogram, - 1, d_levels - ), - hipErrorInvalidValue - ); + result, + hipErrorInvalidValue + ); + + if (UseGraphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } +} + +TEST(RocprimDeviceHistogramRange, RangeIncorrectInput) +{ + testHistogramRangeIncorrectInput(); } +TEST(RocprimDeviceHistogramRange, RangeIncorrectInputWithGraphs) +{ + testHistogramRangeIncorrectInput(); +} + + TYPED_TEST(RocprimDeviceHistogramRange, Range) { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -382,7 +479,12 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) using level_type = typename TestFixture::params::level_type; constexpr unsigned int bins = TestFixture::params::bins; - hipStream_t stream = 0; + 
hipStream_t stream = 0; // default + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -470,6 +572,11 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) using config = typename TestFixture::params::config; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -498,11 +605,17 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) debug_synchronous)); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(rows == 1) { HIP_CHECK(rocprim::histogram_range(d_temporary_storage, @@ -530,6 +643,9 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) debug_synchronous)); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector histogram(bins); HIP_CHECK( hipMemcpy( @@ -548,12 +664,17 @@ TYPED_TEST(RocprimDeviceHistogramRange, Range) { ASSERT_EQ(histogram[i], histogram_expected[i]); } - } - + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } + template + class Config = rocprim::default_config, + bool UseGraphs = false> struct params3 { using sample_type = SampleType; @@ -574,6 +696,7 @@ struct params3 using level_type = LevelType; using counter_type = CounterType; using config = Config; + static constexpr bool use_graphs = 
UseGraphs; }; template @@ -595,7 +718,8 @@ typedef ::testing::Types< params3, params3, - params3> + params3, + params3> Params3; TYPED_TEST_SUITE(RocprimDeviceHistogramMultiEven, Params3); @@ -629,6 +753,11 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) } hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -727,6 +856,11 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) using config = typename TestFixture::params::config; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -759,11 +893,17 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) debug_synchronous))); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(rows == 1) { HIP_CHECK((rocprim::multi_histogram_even( @@ -795,6 +935,9 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) debug_synchronous))); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector histogram[active_channels]; for(unsigned int channel = 0; channel < active_channels; channel++) { @@ -821,9 +964,14 @@ TYPED_TEST(RocprimDeviceHistogramMultiEven, MultiEven) ASSERT_EQ(histogram[channel][i], histogram_expected[channel][i]); } } - } + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } } + + if 
(TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } template + class Config = rocprim::default_config, + bool UseGraphs = false> struct params4 { using sample_type = SampleType; @@ -848,6 +997,7 @@ struct params4 using level_type = LevelType; using counter_type = CounterType; using config = Config; + static constexpr bool use_graphs = UseGraphs; }; template @@ -866,7 +1016,8 @@ typedef ::testing::Types< params4, params4, - params4> + params4, + params4> Params4; TYPED_TEST_SUITE(RocprimDeviceHistogramMultiRange, Params4); @@ -884,6 +1035,11 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) constexpr unsigned int active_channels = TestFixture::params::active_channels; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -1020,6 +1176,11 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) using config = typename TestFixture::params::config; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes = 0; if(rows == 1) { @@ -1046,11 +1207,17 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) )); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0U); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(rows == 1) { HIP_CHECK(( @@ -1076,6 +1243,9 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) )); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); 
+ std::vector histogram[active_channels]; for(unsigned int channel = 0; channel < active_channels; channel++) { @@ -1094,6 +1264,9 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_input)); + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + for(unsigned int channel = 0; channel < active_channels; channel++) { SCOPED_TRACE(testing::Message() << "with channel = " << channel); @@ -1104,7 +1277,8 @@ TYPED_TEST(RocprimDeviceHistogramMultiRange, MultiRange) } } } - - } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } diff --git a/test/rocprim/test_device_merge.cpp b/test/rocprim/test_device_merge.cpp index cb2c1acc0..2e608d500 100644 --- a/test/rocprim/test_device_merge.cpp +++ b/test/rocprim/test_device_merge.cpp @@ -42,13 +42,15 @@ template< class KeyType, class ValueType, - class CompareOp = ::rocprim::less + class CompareOp = ::rocprim::less, + bool UseGraphs = false > struct DeviceMergeParams { using key_type = KeyType; using value_type = ValueType; using compare_op_type = CompareOp; + static constexpr bool use_graphs = UseGraphs; }; template @@ -59,6 +61,7 @@ class RocprimDeviceMergeTests : public ::testing::Test using value_type = typename Params::value_type; using compare_op_type = typename Params::compare_op_type; const bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; }; using custom_int2 = test_utils::custom_test_type; @@ -74,7 +77,8 @@ typedef ::testing::Types< DeviceMergeParams>, DeviceMergeParams>, DeviceMergeParams>, - DeviceMergeParams> + DeviceMergeParams, + DeviceMergeParams, true>> RocprimDeviceMergeTestsParams; // size1, size2 @@ -111,6 +115,11 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support 
hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(auto sizes : get_sizes()) { @@ -182,6 +191,11 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) size1 + size2 ); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -196,12 +210,18 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::merge( @@ -212,6 +232,10 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) compare_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -233,10 +257,14 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKey) hipFree(d_keys_input2); hipFree(d_keys_output); hipFree(d_temp_storage); - } - + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) @@ -248,11 +276,15 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) using key_type = typename TestFixture::key_type; using value_type = typename TestFixture::value_type; using compare_op_type = typename TestFixture::compare_op_type; - const bool debug_synchronous = TestFixture::debug_synchronous; using key_value = std::pair; hipStream_t stream = 0; // default + if 
(TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(auto sizes : get_sizes()) { @@ -366,6 +398,11 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) size1 + size2 ); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -378,16 +415,22 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) d_values_input1, d_values_input2, d_values_checking_output, keys_input1.size(), keys_input2.size(), - compare_op, stream, debug_synchronous + compare_op, stream, TestFixture::debug_synchronous ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::merge( @@ -397,9 +440,13 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) d_values_input1, d_values_input2, d_values_checking_output, keys_input1.size(), keys_input2.size(), - compare_op, stream, debug_synchronous + compare_op, stream, TestFixture::debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -438,12 +485,18 @@ TYPED_TEST(RocprimDeviceMergeTests, MergeKeyValue) hipFree(d_values_input2); hipFree(d_values_output); hipFree(d_temp_storage); - } + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); 
} -TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) +template +void testMergeMismatchedIteratorTypes() { const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -481,6 +534,16 @@ TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) static constexpr bool debug_synchronous = false; + hipStream_t stream = 0; // default + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + graph = test_utils::createGraphHelper(stream); + } + size_t temp_storage_size_bytes = 0; HIP_CHECK(rocprim::merge(nullptr, temp_storage_size_bytes, @@ -490,14 +553,20 @@ TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) keys_input1.size(), keys_input1.size(), rocprim::less{}, - hipStreamDefault, + stream, debug_synchronous)); + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temp_storage_size_bytes, 0); void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (UseGraphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::merge(d_temp_storage, temp_storage_size_bytes, d_keys_input1, @@ -509,6 +578,10 @@ TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) - hipStreamDefault, + // Use the capture stream (not hipStreamDefault) so this call is part of the graph + stream, debug_synchronous)); + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector keys_output(expected_keys_output.size()); HIP_CHECK(hipMemcpy(keys_output.data(), d_keys_output, @@ -520,4 +593,20 @@ TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) HIP_CHECK(hipFree(d_temp_storage)); HIP_CHECK(hipFree(d_keys_output)); HIP_CHECK(hipFree(d_keys_input1)); + + if (UseGraphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); 
+ HIP_CHECK(hipStreamDestroy(stream)); + } +} + +TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypes) +{ + testMergeMismatchedIteratorTypes(); +} + +TEST(RocprimDeviceMergeTests, MergeMismatchedIteratorTypesWithGraphs) +{ + testMergeMismatchedIteratorTypes(); } diff --git a/test/rocprim/test_device_merge_sort.cpp b/test/rocprim/test_device_merge_sort.cpp index c9d837fd9..75e5b7932 100644 --- a/test/rocprim/test_device_merge_sort.cpp +++ b/test/rocprim/test_device_merge_sort.cpp @@ -34,13 +34,15 @@ template< class KeyType, class ValueType = KeyType, - class CompareFunction = ::rocprim::less + class CompareFunction = ::rocprim::less, + bool UseGraphs = false > struct DeviceSortParams { using key_type = KeyType; using value_type = ValueType; using compare_function = CompareFunction; + static constexpr bool use_graphs = UseGraphs; }; // --------------------------------------------------------- @@ -55,6 +57,7 @@ class RocprimDeviceSortTests : public ::testing::Test using value_type = typename Params::value_type; using compare_function = typename Params::compare_function; const bool debug_synchronous = false; + static constexpr bool use_graphs = Params::use_graphs; }; using RocprimDeviceSortTestsParams = ::testing::Types< @@ -74,7 +77,8 @@ using RocprimDeviceSortTestsParams = ::testing::Types< DeviceSortParams>, DeviceSortParams, test_utils::custom_test_type>, DeviceSortParams, - DeviceSortParams>>; + DeviceSortParams>, + DeviceSortParams, true>>; static_assert(std::is_trivially_copyable::value, "Type must be trivially copyable to cover merge sort specialized kernel"); @@ -101,6 +105,11 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) for(size_t size : test_utils::get_sizes(seed_value)) { hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -141,6 +150,11 
@@ TYPED_TEST(RocprimDeviceSortTests, SortKey) compare_op ); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -153,6 +167,9 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -160,6 +177,9 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::merge_sort( @@ -168,6 +188,10 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) compare_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -190,9 +214,14 @@ TYPED_TEST(RocprimDeviceSortTests, SortKey) hipFree(d_output); } hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) @@ -216,6 +245,11 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) for(size_t size : test_utils::get_sizes(seed_value)) { hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -286,6 +320,11 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) [compare_op](const key_value& a, const key_value& b) { return 
compare_op(a.first, b.first); } ); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -299,6 +338,9 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -306,6 +348,9 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::merge_sort( @@ -315,6 +360,10 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) compare_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -355,7 +404,12 @@ TYPED_TEST(RocprimDeviceSortTests, SortKeyValue) hipFree(d_values_output); } hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } diff --git a/test/rocprim/test_device_partition.cpp b/test/rocprim/test_device_partition.cpp index 659ece0d0..fdfb0ac77 100644 --- a/test/rocprim/test_device_partition.cpp +++ b/test/rocprim/test_device_partition.cpp @@ -36,7 +36,8 @@ template< class InputType, class OutputType = InputType, class FlagType = unsigned int, - bool UseIdentityIterator = false + bool UseIdentityIterator = false, + bool UseGraphs = false > struct DevicePartitionParams { @@ -44,6 +45,7 @@ struct DevicePartitionParams using output_type = OutputType; using flag_type = FlagType; static constexpr bool 
use_identity_iterator = UseIdentityIterator; + static constexpr bool use_graphs = UseGraphs; }; template @@ -55,6 +57,7 @@ class RocprimDevicePartitionTests : public ::testing::Test using flag_type = typename Params::flag_type; const bool debug_synchronous = false; static constexpr bool use_identity_iterator = Params::use_identity_iterator; + static constexpr bool use_graphs = Params::use_graphs; }; typedef ::testing::Types< @@ -65,7 +68,8 @@ typedef ::testing::Types< DevicePartitionParams, DevicePartitionParams, DevicePartitionParams, - DevicePartitionParams> + DevicePartitionParams>, + DevicePartitionParams > RocprimDevicePartitionTestsParams; TYPED_TEST_SUITE(RocprimDevicePartitionTests, RocprimDevicePartitionTestsParams); @@ -83,6 +87,11 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -128,6 +137,11 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) } std::reverse(expected_rejected.begin(), expected_rejected.end()); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -142,6 +156,9 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -149,6 +166,9 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, 
temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(rocprim::partition( d_temp_storage, @@ -161,6 +181,9 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // Check if number of selected value is as expected_selected unsigned int selected_count_output = 0; HIP_CHECK(hipMemcpy(&selected_count_output, @@ -190,9 +213,14 @@ TYPED_TEST(RocprimDevicePartitionTests, Flagged) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) @@ -206,6 +234,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } auto select_op = [] __host__ __device__ (const T& value) -> bool { @@ -230,6 +263,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) 0 ); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -243,10 +281,16 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // allocate temporary storage void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, 
temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(rocprim::partition(d_temp_storage, temp_storage_size_bytes, @@ -257,8 +301,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) select_op, stream, debug_synchronous)); - HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + + HIP_CHECK(hipDeviceSynchronize()); ASSERT_FALSE(out_of_bounds.get()); // Check if number of selected value is 0 @@ -274,6 +321,12 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateEmptyInput) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDevicePartitionTests, Predicate) @@ -288,6 +341,11 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } auto select_op = [] __host__ __device__ (const T& value) -> bool { @@ -334,6 +392,11 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) } std::reverse(expected_rejected.begin(), expected_rejected.end()); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -348,6 +411,9 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 
ASSERT_GT(temp_storage_size_bytes, 0); @@ -355,6 +421,9 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(rocprim::partition( d_temp_storage, @@ -366,6 +435,10 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) select_op, stream, debug_synchronous)); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected_selected @@ -375,7 +448,7 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) sizeof(unsigned int), hipMemcpyDeviceToHost)); ASSERT_EQ(selected_count_output, expected_selected.size()); - + // Check if output values are as expected_selected std::vector output(input.size()); HIP_CHECK(hipMemcpy(output.data(), @@ -396,8 +469,14 @@ TYPED_TEST(RocprimDevicePartitionTests, Predicate) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) @@ -412,6 +491,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } auto select_op = [] __host__ __device__(const T& value) -> bool { @@ -464,6 +548,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) } } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = 
test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -478,6 +567,9 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -485,6 +577,9 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) void* d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(rocprim::partition_two_way( d_temp_storage, @@ -497,6 +592,10 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) select_op, stream, debug_synchronous)); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -529,8 +628,14 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateTwoWay) hipFree(d_rejected); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } namespace { @@ -556,7 +661,13 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; const bool debug_synchronous = TestFixture::debug_synchronous; - const hipStream_t stream = 0; // default stream + hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + const std::vector> limit_pairs{ { static_cast(30), 
static_cast(60) }, // all sections may contain items { static_cast(0), static_cast(60) }, // first section is empty @@ -623,6 +734,11 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) return result; }(); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -641,6 +757,9 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -648,6 +767,9 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) void* d_temp_storage = nullptr; HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(rocprim::partition_three_way( d_temp_storage, @@ -663,6 +785,10 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) second_op, stream, debug_synchronous)); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected_selected @@ -711,9 +837,15 @@ TYPED_TEST(RocprimDevicePartitionTests, PredicateThreeWay) hipFree(d_unselected_output); hipFree(d_selected_counts); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } namespace @@ -949,24 +1081,34 @@ struct modulo_predicate } // namespace -struct RocprimDevicePartitionLargeInputTests : public ::testing::TestWithParam +struct RocprimDevicePartitionLargeInputTests : public ::testing::TestWithParam> {}; 
INSTANTIATE_TEST_SUITE_P(RocprimDevicePartitionLargeInputTest, RocprimDevicePartitionLargeInputTests, - ::testing::Values(2, 2048, 38713)); + ::testing::Values(std::make_pair(2, false), // params: size, use_graphs + std::make_pair(2048, false), + std::make_pair(38713, false), + std::make_pair(38713, true))); TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) { static constexpr bool debug_synchronous = false; - static constexpr hipStream_t stream = 0; + auto param = GetParam(); + const size_t modulo = std::get<0>(param); + const bool use_graphs = std::get<1>(param); + + hipStream_t stream = 0; // default + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - const auto modulo = GetParam(); - for(const auto size : test_utils::get_large_sizes(std::random_device{}())) { // limit the running time of the test @@ -988,6 +1130,11 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) size_t* d_count_output{}; HIP_CHECK(test_common_utils::hipMallocHelper(&d_count_output, sizeof(*d_count_output))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + void* d_temporary_storage{}; size_t temporary_storage_size{}; HIP_CHECK(rocprim::partition(d_temporary_storage, @@ -1000,9 +1147,15 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_NE(0, temporary_storage_size); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_size)); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + 
HIP_CHECK(rocprim::partition(d_temporary_storage, temporary_storage_size, input_iterator, @@ -1013,6 +1166,9 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + size_t count_output{}; HIP_CHECK(hipMemcpyWithStream(&count_output, d_count_output, @@ -1035,20 +1191,33 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartition) HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_count_output)); HIP_CHECK(hipFree(d_incorrect_flag)); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) { static constexpr bool debug_synchronous = false; - static constexpr hipStream_t stream = 0; + auto param = GetParam(); + const size_t modulo = std::get<0>(param); + const bool use_graphs = std::get<1>(param); + + hipStream_t stream = 0; // default + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - const auto modulo = GetParam(); - for(const auto size : test_utils::get_large_sizes(std::random_device{}())) { // limit the running time of the test @@ -1079,6 +1248,11 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) size_t* d_count_output{}; HIP_CHECK(test_common_utils::hipMallocHelper(&d_count_output, sizeof(*d_count_output))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + void* d_temporary_storage{}; size_t temporary_storage_size{}; 
HIP_CHECK(rocprim::partition_two_way(d_temporary_storage, @@ -1092,9 +1266,15 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_NE(0, temporary_storage_size); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_size)); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::partition_two_way(d_temporary_storage, temporary_storage_size, input_iterator, @@ -1106,6 +1286,9 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + size_t count_output{}; HIP_CHECK(hipMemcpyWithStream(&count_output, d_count_output, @@ -1136,19 +1319,33 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionTwoWay) HIP_CHECK(hipFree(d_count_output)); HIP_CHECK(hipFree(d_incorrect_select_flag)); HIP_CHECK(hipFree(d_incorrect_reject_flag)); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) { static constexpr bool debug_synchronous = false; - static constexpr hipStream_t stream = 0; + auto param = GetParam(); + const bool use_graphs = std::get<1>(param); + + hipStream_t stream = 0; // default + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - const auto modulo_a = GetParam(); + const auto modulo_a = std::get<0>(param); const auto modulo_b = 
modulo_a + 1; for(const auto size : test_utils::get_large_sizes(std::random_device{}())) @@ -1176,6 +1373,11 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) size_t* d_count_output{}; HIP_CHECK(test_common_utils::hipMallocHelper(&d_count_output, 2 * sizeof(*d_count_output))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + void* d_temporary_storage{}; size_t temporary_storage_size{}; HIP_CHECK(rocprim::partition_three_way(d_temporary_storage, @@ -1191,9 +1393,15 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_NE(0, temporary_storage_size); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_size)); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::partition_three_way(d_temporary_storage, temporary_storage_size, input_iterator, @@ -1207,6 +1415,9 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + size_t count_output[2]{}; HIP_CHECK(hipMemcpyWithStream(&count_output, d_count_output, @@ -1233,5 +1444,11 @@ TEST_P(RocprimDevicePartitionLargeInputTests, LargeInputPartitionThreeWay) HIP_CHECK(hipFree(d_temporary_storage)); HIP_CHECK(hipFree(d_count_output)); HIP_CHECK(hipFree(d_incorrect_flag)); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } diff --git a/test/rocprim/test_device_radix_sort.cpp.in b/test/rocprim/test_device_radix_sort.cpp.in index b56e85889..4d0931246 100644 --- a/test/rocprim/test_device_radix_sort.cpp.in +++ b/test/rocprim/test_device_radix_sort.cpp.in @@ 
-31,7 +31,7 @@ #cmakedefine ROCPRIM_TEST_SLICE @ROCPRIM_TEST_SLICE@ #if ROCPRIM_TEST_SUITE_SLICE == 0 - TYPED_TEST_P(SUITE, SortKeys ) { sort_keys(); } + TYPED_TEST_P(SUITE, SortKeys ) { sort_keys(); } REGISTER_TYPED_TEST_SUITE_P(SUITE, SortKeys); #elif ROCPRIM_TEST_SUITE_SLICE == 1 TYPED_TEST_P(SUITE, SortPairs ) { sort_pairs(); } @@ -46,6 +46,7 @@ #if ROCPRIM_TEST_SLICE == 0 TEST(SUITE, SortKeysOver4G) { sort_keys_over_4g(); } + TEST(SUITE, SortKeysOver4GWithGraphs) { sort_keys_over_4g(); } #endif #if ROCPRIM_TEST_TYPE_SLICE == 0 @@ -75,6 +76,7 @@ INSTANTIATE(params) INSTANTIATE(params) INSTANTIATE(params) + #elif ROCPRIM_TEST_TYPE_SLICE == 1 INSTANTIATE(params) INSTANTIATE(params) @@ -99,4 +101,7 @@ INSTANTIATE(params) INSTANTIATE(params) INSTANTIATE(params) + + // test with graphs + INSTANTIATE(params) #endif diff --git a/test/rocprim/test_device_radix_sort.hpp b/test/rocprim/test_device_radix_sort.hpp index 0b7e81620..668c16532 100644 --- a/test/rocprim/test_device_radix_sort.hpp +++ b/test/rocprim/test_device_radix_sort.hpp @@ -38,7 +38,8 @@ template + bool CheckLargeSizes = false, + bool UseGraphs = false> struct params { using key_type = Key; @@ -47,6 +48,7 @@ struct params static constexpr unsigned int start_bit = StartBit; static constexpr unsigned int end_bit = EndBit; static constexpr bool check_large_sizes = CheckLargeSizes; + static constexpr bool use_graphs = UseGraphs; }; template @@ -72,6 +74,11 @@ inline void sort_keys() constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -144,6 +151,11 @@ inline void sort_keys() rocprim::default_config, 1024 * 512>; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = 
test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; HIP_CHECK(rocprim::radix_sort_keys(nullptr, temporary_storage_bytes, @@ -153,12 +165,18 @@ inline void sort_keys() start_bit, end_bit)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(descending) { HIP_CHECK(rocprim::radix_sort_keys_desc(d_temporary_storage, @@ -184,6 +202,9 @@ inline void sort_keys() debug_synchronous)); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector keys_output(size); HIP_CHECK(hipMemcpy(keys_output.data(), d_keys_output, @@ -197,9 +218,15 @@ inline void sort_keys() HIP_CHECK(hipFree(d_keys_output)); } + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected)); } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } template @@ -217,6 +244,11 @@ inline void sort_pairs() constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool debug_synchronous = false; @@ -327,6 +359,11 @@ inline void sort_pairs() 4>, 1024 * 512>; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + void* d_temporary_storage = nullptr; size_t temporary_storage_bytes; 
HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage, @@ -339,11 +376,17 @@ inline void sort_pairs() start_bit, end_bit)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(descending) { HIP_CHECK(rocprim::radix_sort_pairs_desc(d_temporary_storage, @@ -373,6 +416,9 @@ inline void sort_pairs() debug_synchronous)); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector keys_output(size); HIP_CHECK(hipMemcpy(keys_output.data(), d_keys_output, @@ -394,10 +440,16 @@ inline void sort_pairs() HIP_CHECK(hipFree(d_values_output)); } + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected)); ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } template @@ -414,7 +466,12 @@ inline void sort_keys_double_buffer() constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; - + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + const bool debug_synchronous = false; for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) @@ -470,6 +527,11 @@ inline void sort_keys_double_buffer() rocprim::double_buffer d_keys(d_keys_input, d_keys_output); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = 
test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; HIP_CHECK(rocprim::radix_sort_keys(nullptr, temporary_storage_bytes, @@ -478,12 +540,18 @@ inline void sort_keys_double_buffer() start_bit, end_bit)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(descending) { HIP_CHECK(rocprim::radix_sort_keys_desc(d_temporary_storage, @@ -507,6 +575,9 @@ inline void sort_keys_double_buffer() debug_synchronous)); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipFree(d_temporary_storage)); std::vector keys_output(size); @@ -518,9 +589,15 @@ inline void sort_keys_double_buffer() HIP_CHECK(hipFree(d_keys_input)); HIP_CHECK(hipFree(d_keys_output)); + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, expected)); } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } template @@ -538,7 +615,12 @@ inline void sort_pairs_double_buffer() constexpr bool check_large_sizes = TestFixture::params::check_large_sizes; hipStream_t stream = 0; - + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + const bool debug_synchronous = false; for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) @@ -623,6 +705,11 @@ inline void sort_pairs_double_buffer() rocprim::double_buffer d_keys(d_keys_input, d_keys_output); rocprim::double_buffer 
d_values(d_values_input, d_values_output); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + void* d_temporary_storage = nullptr; size_t temporary_storage_bytes; HIP_CHECK(rocprim::radix_sort_pairs(d_temporary_storage, @@ -633,11 +720,17 @@ inline void sort_pairs_double_buffer() start_bit, end_bit)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + if(descending) { HIP_CHECK(rocprim::radix_sort_pairs_desc(d_temporary_storage, @@ -663,6 +756,9 @@ inline void sort_pairs_double_buffer() debug_synchronous)); } + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipFree(d_temporary_storage)); std::vector keys_output(size); @@ -682,21 +778,34 @@ inline void sort_pairs_double_buffer() HIP_CHECK(hipFree(d_values_input)); HIP_CHECK(hipFree(d_values_output)); + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(keys_output, keys_expected)); ASSERT_NO_FATAL_FAILURE(test_utils::assert_bit_eq(values_output, values_expected)); } } + + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } +template inline void sort_keys_over_4g() { using key_type = uint8_t; constexpr unsigned int start_bit = 0; - constexpr unsigned int end_bit = 8ull * sizeof(key_type); - constexpr hipStream_t stream = 0; + constexpr unsigned int end_bit = 8ull * sizeof(key_type); constexpr bool debug_synchronous = false; constexpr size_t size = (1ull << 32) + 32; constexpr size_t number_of_possible_keys = 1ull << (8ull * 
sizeof(key_type)); + hipStream_t stream = 0; + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + assert(std::is_unsigned::value); std::vector histogram(number_of_possible_keys, 0); const int seed_value = rand(); @@ -723,6 +832,11 @@ inline void sort_keys_over_4g() key_type_storage_bytes, hipMemcpyHostToDevice)); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; HIP_CHECK(rocprim::radix_sort_keys(nullptr, temporary_storage_bytes, @@ -734,10 +848,13 @@ inline void sort_keys_over_4g() stream, debug_synchronous)); + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); - hipDeviceProp_t prop; - HIP_CHECK(hipGetDeviceProperties(&prop, device_id)); + hipDeviceProp_t prop; + HIP_CHECK(hipGetDeviceProperties(&prop, device_id)); size_t total_storage_bytes = key_type_storage_bytes + temporary_storage_bytes; if (total_storage_bytes > (static_cast(prop.totalGlobalMem * 0.90))) { @@ -749,6 +866,9 @@ inline void sort_keys_over_4g() void* d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (UseGraphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::radix_sort_keys(d_temporary_storage, temporary_storage_bytes, d_keys_input_output, @@ -759,6 +879,9 @@ inline void sort_keys_over_4g() stream, debug_synchronous)); + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + std::vector output(keys_input.size()); HIP_CHECK(hipMemcpy(output.data(), d_keys_input_output, @@ -778,6 +901,12 @@ inline void sort_keys_over_4g() HIP_CHECK(hipFree(d_keys_input_output)); HIP_CHECK(hipFree(d_temporary_storage)); + + if (UseGraphs) + { + 
test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } #endif // TEST_DEVICE_RADIX_SORT_HPP_ diff --git a/test/rocprim/test_device_reduce.cpp b/test/rocprim/test_device_reduce.cpp index 324629f88..3cb7458ea 100644 --- a/test/rocprim/test_device_reduce.cpp +++ b/test/rocprim/test_device_reduce.cpp @@ -38,7 +38,8 @@ template + bra Algo = bra::default_algorithm, + bool UseGraphs = false> struct DeviceReduceParams { static constexpr bra algo = Algo; @@ -47,6 +48,7 @@ struct DeviceReduceParams // Tests output iterator with void value_type (OutputIterator concept) static constexpr bool use_identity_iterator = UseIdentityIterator; static constexpr size_t size_limit = SizeLimit; + static constexpr bool use_graphs = UseGraphs; }; // clang-format off @@ -83,6 +85,7 @@ class RocprimDeviceReduceTests : public ::testing::Test const bool debug_synchronous = false; static constexpr bool use_identity_iterator = Params::use_identity_iterator; static constexpr size_t size_limit = Params::size_limit; + const bool use_graphs = Params::use_graphs; }; template @@ -104,7 +107,8 @@ typedef ::testing::Types< // DeviceReduceParams, DeviceReduceParams, DeviceReduceParams, test_utils::custom_test_type>, - DeviceReduceParams, test_utils::custom_test_type>> + DeviceReduceParams, test_utils::custom_test_type>, + DeviceReduceParams> RocprimDeviceReduceTestsParams; typedef ::testing::Types< @@ -130,12 +134,22 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) using Config = size_limit_config_t; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } U * d_output; HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(U))); const U initial_value = U(1234); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = 
test_utils::createGraphHelper(stream); + size_t temp_storage_size_bytes; // Get size of d_temp_storage HIP_CHECK( @@ -148,9 +162,15 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + void * d_temp_storage = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::reduce( @@ -161,6 +181,10 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) 0, rocprim::minimum(), stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipDeviceSynchronize()); U output; @@ -175,6 +199,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceEmptyInput) hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) @@ -207,6 +237,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) } hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -232,6 +268,12 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) // fix for custom_test_type case with size == 0 if(size == 0) expected = U(); + + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -245,6 +287,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) ) ); + if (TestFixture::use_graphs) + graph_instance 
= test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -252,6 +297,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::reduce( @@ -261,6 +309,10 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) input.size(), rocprim::plus(), stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -281,120 +333,14 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceSum) hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); - } - } - -} - -TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) -{ - int device_id = test_common_utils::obtain_device_from_ctest(); - SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); - HIP_CHECK(hipSetDevice(device_id)); - - using T = typename TestFixture::input_type; - using U = typename TestFixture::output_type; - - using binary_op_type = rocprim::minimum; - - const bool debug_synchronous = TestFixture::debug_synchronous; - static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; - using Config = size_limit_config_t; - - for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) - { - unsigned int seed_value = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; - SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); - - for(auto size : test_utils::get_sizes(seed_value)) - { - hipStream_t stream = 0; // default - - SCOPED_TRACE(testing::Message() << "with size = " << size); - - // Generate data - std::vector input = test_utils::get_random_data(size, 1, 100, seed_value); - std::vector output(1, U(0)); - - T * d_input; - U * d_output; - HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(T))); - HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, output.size() * sizeof(U))); - HIP_CHECK( - hipMemcpy( - d_input, input.data(), - input.size() * sizeof(T), - hipMemcpyHostToDevice - ) - ); - HIP_CHECK(hipDeviceSynchronize()); - - // reduce function - binary_op_type min_op; - - // Calculate expected results on host - U expected = U(test_utils::numeric_limits::max()); - for(unsigned int i = 0; i < input.size(); i++) + + if (TestFixture::use_graphs) { - expected = min_op(expected, input[i]); + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); } - - // temp storage - size_t temp_storage_size_bytes; - void * d_temp_storage = nullptr; - // Get size of d_temp_storage - HIP_CHECK( - rocprim::reduce( - d_temp_storage, temp_storage_size_bytes, - d_input, - test_utils::wrap_in_identity_iterator(d_output), - test_utils::numeric_limits::max(), input.size(), rocprim::minimum(), stream, debug_synchronous - ) - ); - - // temp_storage_size_bytes must be >0 - ASSERT_GT(temp_storage_size_bytes, 0); - - // allocate temporary storage - HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); - HIP_CHECK(hipDeviceSynchronize()); - - // Run - HIP_CHECK( - rocprim::reduce( - d_temp_storage, temp_storage_size_bytes, - d_input, - test_utils::wrap_in_identity_iterator(d_output), - test_utils::numeric_limits::max(), input.size(), rocprim::minimum(), stream, debug_synchronous - ) - ); - 
HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); - - // Copy output to host - HIP_CHECK( - hipMemcpy( - output.data(), d_output, - output.size() * sizeof(U), - hipMemcpyDeviceToHost - ) - ); - HIP_CHECK(hipDeviceSynchronize()); - - // Check if output values are as expected - ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( - output[0], - expected, - std::is_same::value - ? 0 - : std::max(test_utils::precision, test_utils::precision))); - - hipFree(d_input); - hipFree(d_output); - hipFree(d_temp_storage); } } - } template< @@ -437,6 +383,11 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) for(auto size : test_utils::get_sizes(seed_value)) { hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -472,6 +423,11 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) expected = reduce_op(expected, input[i]); } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -485,6 +441,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -492,6 +451,9 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::reduce( @@ -501,6 +463,10 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) max, input.size(), reduce_op, 
stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -521,12 +487,18 @@ TYPED_TEST(RocprimDeviceReduceTests, ReduceArgMinimum) hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } -TEST(RocprimDeviceReduceTests, LargeIndices) +template +void testLargeIndices() { const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -536,7 +508,12 @@ TEST(RocprimDeviceReduceTests, LargeIndices) using Iterator = rocprim::counting_iterator; const bool debug_synchronous = false; - const hipStream_t stream = 0; // default + hipStream_t stream = 0; // default + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -553,6 +530,11 @@ TEST(RocprimDeviceReduceTests, LargeIndices) T* d_output = nullptr; HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(T))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes = 0; void* d_temp_storage = nullptr; @@ -566,10 +548,16 @@ TEST(RocprimDeviceReduceTests, LargeIndices) stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, 
stream); + // Run HIP_CHECK(rocprim::reduce(d_temp_storage, temp_storage_size_bytes, @@ -579,6 +567,10 @@ TEST(RocprimDeviceReduceTests, LargeIndices) rocprim::plus {}, stream, debug_synchronous)); + + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -595,8 +587,24 @@ TEST(RocprimDeviceReduceTests, LargeIndices) hipFree(d_temp_storage); hipFree(d_output); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST(RocprimDeviceReduceTests, LargeIndices) +{ + testLargeIndices<>(); +} + +TEST(RocprimDeviceReduceTests, LargeIndicesWithGraphs) +{ + testLargeIndices(); } TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunction) @@ -628,6 +636,11 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio } hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -653,6 +666,11 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio // Calculate expected results on host mathematically (instead of using reduce on host) U expected = static_cast(static_cast(size) * static_cast(lowest)); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void* d_temp_storage = nullptr; @@ -667,6 +685,9 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be 
>0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -674,6 +695,9 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK(rocprim::reduce( d_temp_storage, @@ -684,6 +708,10 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio rocprim::plus(), stream, debug_synchronous)); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -698,5 +726,143 @@ TYPED_TEST(RocprimDeviceReducePrecisionTests, ReduceSumInputEqualExponentFunctio hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + } +} + +TYPED_TEST(RocprimDeviceReduceTests, ReduceMinimum) +{ + int device_id = test_common_utils::obtain_device_from_ctest(); + SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); + HIP_CHECK(hipSetDevice(device_id)); + + using T = typename TestFixture::input_type; + using U = typename TestFixture::output_type; + + using binary_op_type = rocprim::minimum; + + static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; + using Config = size_limit_config_t; + + for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) + { + unsigned int seed_value = seed_index < random_seeds_count ? 
rand() : seeds[seed_index - random_seeds_count]; + SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); + + for(auto size : test_utils::get_sizes(seed_value)) + { + hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } + + SCOPED_TRACE(testing::Message() << "with size = " << size); + + // Generate data + std::vector input = test_utils::get_random_data(size, 1, 100, seed_value); + std::vector output(1, U(0)); + + T * d_input; + U * d_output; + HIP_CHECK(test_common_utils::hipMallocHelper(&d_input, input.size() * sizeof(T))); + HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, output.size() * sizeof(U))); + HIP_CHECK( + hipMemcpy( + d_input, input.data(), + input.size() * sizeof(T), + hipMemcpyHostToDevice + ) + ); + HIP_CHECK(hipDeviceSynchronize()); + + // reduce function + binary_op_type min_op; + + // Calculate expected results on host + U expected = U(test_utils::numeric_limits::max()); + for(unsigned int i = 0; i < input.size(); i++) + { + expected = min_op(expected, input[i]); + } + + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + + // Get size of d_temp_storage + size_t temp_storage_size_bytes; + void * d_temp_storage = nullptr; + HIP_CHECK( + rocprim::reduce( + d_temp_storage, temp_storage_size_bytes, + d_input, + test_utils::wrap_in_identity_iterator(d_output), + test_utils::numeric_limits::max(), input.size(), rocprim::minimum(), stream, TestFixture::debug_synchronous + ) + ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + + // temp_storage_size_bytes must be >0 + ASSERT_GT(temp_storage_size_bytes, 0); + + // allocate temporary storage + HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, 
temp_storage_size_bytes)); + HIP_CHECK(hipDeviceSynchronize()); + + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + + // Run + HIP_CHECK( + rocprim::reduce( + d_temp_storage, temp_storage_size_bytes, + d_input, + test_utils::wrap_in_identity_iterator(d_output), + test_utils::numeric_limits::max(), input.size(), rocprim::minimum(), stream, TestFixture::debug_synchronous + ) + ); + + // Copy output to host + HIP_CHECK( + hipMemcpyAsync( + output.data(), d_output, + output.size() * sizeof(U), + hipMemcpyDeviceToHost, stream + ) + ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + + HIP_CHECK(hipDeviceSynchronize()); + + // Check if output values are as expected + ASSERT_NO_FATAL_FAILURE(test_utils::assert_near( + output[0], + expected, + std::is_same::value + ? 0 + : std::max(test_utils::precision, test_utils::precision))); + + hipFree(d_input); + hipFree(d_output); + hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + } } } diff --git a/test/rocprim/test_device_reduce_by_key.cpp b/test/rocprim/test_device_reduce_by_key.cpp index c0a9d8285..9c0f4dd43 100644 --- a/test/rocprim/test_device_reduce_by_key.cpp +++ b/test/rocprim/test_device_reduce_by_key.cpp @@ -42,7 +42,8 @@ template< class Aggregate = Value, class KeyCompareFunction = ::rocprim::equal_to, // Tests output iterator with void value_type (OutputIterator concept) - bool UseIdentityIterator = false + bool UseIdentityIterator = false, + bool UseGraphs = false > struct params { @@ -54,6 +55,7 @@ struct params using aggregate_type = Aggregate; using key_compare_op = KeyCompareFunction; static constexpr bool use_identity_iterator = UseIdentityIterator; + static constexpr bool use_graphs = UseGraphs; }; template @@ -107,7 +109,8 @@ typedef ::testing::Types< params, 1000, 10000, long long>, 
params, 1000, 50000>, params, 100000, 100000>, - params, unsigned long, rocprim::plus<>, 69, 420> + params, unsigned long, rocprim::plus<>, 69, 420>, + params, 1, 10, int, ::rocprim::equal_to, false, true> > Params; // clang-format on @@ -158,6 +161,11 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) SCOPED_TRACE(testing::Message() << "with size = " << size); hipStream_t stream = 0; // default + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const bool use_unique_keys = bool(test_utils::get_random_value(0, 1, seed_value)); @@ -248,6 +256,11 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) size_t temporary_storage_bytes; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + HIP_CHECK( rocprim::reduce_by_key( nullptr, temporary_storage_bytes, @@ -260,11 +273,17 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) ) ); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK( rocprim::reduce_by_key( d_temporary_storage, temporary_storage_bytes, @@ -276,6 +295,9 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) ) ); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipFree(d_temporary_storage)); std::vector unique_output(unique_count_expected); @@ -309,16 +331,21 @@ TYPED_TEST(RocprimDeviceReduceByKey, ReduceByKey) HIP_CHECK(hipFree(d_aggregates_output)); HIP_CHECK(hipFree(d_unique_count_output)); + if 
(TestFixture::params::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + ASSERT_EQ(unique_count_output[0], unique_count_expected); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(unique_output, unique_expected)); ASSERT_NO_FATAL_FAILURE(test_utils::assert_eq(aggregates_output, aggregates_expected)); } } - } -template +template void large_indices_reduce_by_key() { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -334,6 +361,11 @@ void large_indices_reduce_by_key() ::rocprim::equal_to key_compare_op; hipStream_t stream = 0; // default + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(size_t size : test_utils::get_large_sizes(42)) { @@ -369,6 +401,11 @@ void large_indices_reduce_by_key() size_t temporary_storage_bytes; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + HIP_CHECK(rocprim::reduce_by_key(nullptr, temporary_storage_bytes, d_keys_input, @@ -382,12 +419,18 @@ void large_indices_reduce_by_key() stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::reduce_by_key(d_temporary_storage, temporary_storage_bytes, d_keys_input, @@ -401,6 +444,9 @@ void large_indices_reduce_by_key() stream, debug_synchronous)); + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipFree(d_temporary_storage)); std::vector unique_output(unique_count_expected); @@ -423,6 +469,9 @@ void large_indices_reduce_by_key() 
HIP_CHECK(hipFree(d_aggregates_output)); HIP_CHECK(hipFree(d_unique_count_output)); + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + ASSERT_EQ(unique_count_output[0], unique_count_expected); size_t total_size = 0; @@ -438,6 +487,9 @@ void large_indices_reduce_by_key() ASSERT_EQ(last_idx, unique_output[last_idx]); ASSERT_EQ(value_type(size - total_size), aggregates_output[last_idx]); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TEST(RocprimDeviceReduceByKey, LargeIndicesReduceByKeySmallValueType) @@ -451,7 +503,13 @@ TEST(RocprimDeviceReduceByKey, LargeIndicesReduceByKeyLargeValueType) large_indices_reduce_by_key>(); } -template +TEST(RocprimDeviceReduceByKey, LargeIndicesReduceByKeyLargeValueTypeWithGraphs) +{ + // large value type to test TilesPerBlock > 1 + large_indices_reduce_by_key, true>(); +} + +template void large_segment_count_reduce_by_key() { int device_id = test_common_utils::obtain_device_from_ctest(); @@ -466,6 +524,11 @@ void large_segment_count_reduce_by_key() ::rocprim::equal_to key_compare_op; hipStream_t stream = 0; // default + if (use_graphs) + { + // Default stream does not support hipGraphs + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(size_t size : test_utils::get_large_sizes(42)) { @@ -485,8 +548,12 @@ void large_segment_count_reduce_by_key() HIP_CHECK(test_common_utils::hipMallocHelper(&d_unique_count_output, sizeof(*d_unique_count_output))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; - HIP_CHECK(rocprim::reduce_by_key(nullptr, temporary_storage_bytes, d_keys_input, @@ -499,13 +566,18 @@ void large_segment_count_reduce_by_key() key_compare_op, stream, debug_synchronous)); - + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( 
test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::reduce_by_key(d_temporary_storage, temporary_storage_bytes, d_keys_input, @@ -518,7 +590,9 @@ void large_segment_count_reduce_by_key() key_compare_op, stream, debug_synchronous)); - + if (use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipFree(d_temporary_storage)); size_t unique_count_output; @@ -530,7 +604,13 @@ void large_segment_count_reduce_by_key() HIP_CHECK(hipFree(d_unique_count_output)); ASSERT_EQ(unique_count_output, unique_count_expected); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TEST(RocprimDeviceReduceByKey, LargeSegmentCountReduceByKeySmallValueType) @@ -544,6 +624,11 @@ TEST(RocprimDeviceReduceByKey, LargeSegmentCountReduceByKeyLargeValueType) large_segment_count_reduce_by_key>(); } +TEST(RocprimDeviceReduceByKey, GraphReduceByKey) +{ + large_segment_count_reduce_by_key(); +} + TEST(RocprimDeviceReduceByKey, ReduceByNonEqualKeys) { int device_id = test_common_utils::obtain_device_from_ctest(); diff --git a/test/rocprim/test_device_scan.cpp b/test/rocprim/test_device_scan.cpp index 4277b1cd5..3dc6ab7cd 100644 --- a/test/rocprim/test_device_scan.cpp +++ b/test/rocprim/test_device_scan.cpp @@ -70,7 +70,8 @@ template + typename ConfigHelper = default_config_helper, + bool UseGraphs = false> struct DeviceScanParams { using input_type = InputType; @@ -78,6 +79,7 @@ struct DeviceScanParams using scan_op_type = ScanOp; static constexpr bool use_identity_iterator = UseIdentityIteratorIfSupported; using config_helper = ConfigHelper; + static constexpr bool use_graphs = UseGraphs; }; // --------------------------------------------------------- @@ -94,6 +96,7 @@ class RocprimDeviceScanTests : public ::testing::Test const bool 
debug_synchronous = false; static constexpr bool use_identity_iterator = Params::use_identity_iterator; using config_helper = typename Params::config_helper; + bool use_graphs = Params::use_graphs; }; typedef ::testing::Types< @@ -129,7 +132,9 @@ typedef ::testing::Types< true>, DeviceScanParams>, DeviceScanParams>, - DeviceScanParams>> + DeviceScanParams>, + // With graphs + DeviceScanParams, false, default_config_helper, true>> RocprimDeviceScanTestsParams; // use float for accumulation of bfloat16 and half inputs if operator is plus @@ -158,6 +163,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) HIP_CHECK(hipSetDevice(device_id)); hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } U * d_output; HIP_CHECK(test_common_utils::hipMallocHelper(&d_output, sizeof(U))); @@ -176,6 +186,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) rocprim::make_constant_iterator(T(345)), [] (T in) { return static_cast(in); }); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -188,9 +203,15 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::inclusive_scan( @@ -199,6 +220,10 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) 0, scan_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -206,6 +231,12 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanEmptyInput) hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) @@ -228,7 +259,6 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) // therefore the only source of error is precision of operation itself constexpr float single_op_precision = is_plus_op::value ? test_utils::precision : 0; - const bool debug_synchronous = TestFixture::debug_synchronous; static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; using Config = typename TestFixture::config_helper::template type; @@ -252,6 +282,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) break; } hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -285,6 +320,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) auto input_iterator = rocprim::make_transform_iterator( d_input, [] (T in) { return static_cast(in); }); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -293,10 +333,13 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) rocprim::inclusive_scan( d_temp_storage, temp_storage_size_bytes, input_iterator, test_utils::wrap_in_identity_iterator(d_output), - input.size(), scan_op, stream, debug_synchronous + input.size(), scan_op, stream, TestFixture::debug_synchronous ) ); + if (TestFixture::use_graphs) + 
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -304,14 +347,21 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::inclusive_scan( d_temp_storage, temp_storage_size_bytes, input_iterator, test_utils::wrap_in_identity_iterator(d_output), - input.size(), scan_op, stream, debug_synchronous + input.size(), scan_op, stream, TestFixture::debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -332,9 +382,14 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScan) hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) @@ -382,6 +437,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) break; } hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -417,6 +477,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) auto input_iterator = rocprim::make_transform_iterator( d_input, [] (T in) { return static_cast(in); }); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -429,6 +494,9 
@@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -436,6 +504,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::exclusive_scan( @@ -444,6 +515,10 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) initial_value, input.size(), scan_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -464,9 +539,14 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScan) hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) @@ -507,6 +587,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) break; } hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -564,6 +649,11 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) auto input_iterator = rocprim::make_transform_iterator( d_input, [] (T in) { return static_cast(in); }); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; 
@@ -575,6 +665,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -582,6 +675,9 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::inclusive_scan_by_key( @@ -589,6 +685,10 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) d_output, input.size(), scan_op, keys_compare_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -610,9 +710,14 @@ TYPED_TEST(RocprimDeviceScanTests, InclusiveScanByKey) hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) @@ -654,6 +759,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) break; } hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -713,6 +823,11 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) auto input_iterator = rocprim::make_transform_iterator( d_input, [] (T in) { return static_cast(in); }); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t 
temp_storage_size_bytes; void * d_temp_storage = nullptr; @@ -724,6 +839,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) ) ); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -731,6 +849,9 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::exclusive_scan_by_key( @@ -738,6 +859,10 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) d_output, initial_value, input.size(), scan_op, keys_compare_op, stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -759,6 +884,12 @@ TYPED_TEST(RocprimDeviceScanTests, ExclusiveScanByKey) hipFree(d_input); hipFree(d_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } } @@ -830,7 +961,8 @@ class single_index_iterator { // clang-format on }; -TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) +template +void testLargeIndicesInclusiveScan() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -841,7 +973,12 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) using OutputIterator = single_index_iterator; const bool debug_synchronous = false; - const hipStream_t stream = 0; // default + hipStream_t stream = 0; // default + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + 
HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -869,15 +1006,23 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + graph = test_utils::createGraphHelper(stream); + // Get temporary array size HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, - input_begin, output_it, size, - ::rocprim::plus(), - stream, debug_synchronous - ) - ); + rocprim::inclusive_scan( + d_temp_storage, temp_storage_size_bytes, + input_begin, output_it, size, + ::rocprim::plus(), + stream, debug_synchronous + ) + ); + + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -886,15 +1031,22 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (UseGraphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( - rocprim::inclusive_scan( - d_temp_storage, temp_storage_size_bytes, - input_begin, output_it, size, - ::rocprim::plus(), - stream, debug_synchronous - ) - ); + rocprim::inclusive_scan( + d_temp_storage, temp_storage_size_bytes, + input_begin, output_it, size, + ::rocprim::plus(), + stream, debug_synchronous + ) + ); + + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -907,17 +1059,34 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) const T multiplicand_1 = size; const T multiplicand_2 = 2 * (*input_begin) + size - 1; const T expected_output = (multiplicand_1 % 2 == 0) ? 
multiplicand_1 / 2 * multiplicand_2 - : multiplicand_1 * (multiplicand_2 / 2); + : multiplicand_1 * (multiplicand_2 / 2); ASSERT_EQ(output, expected_output); hipFree(d_temp_storage); hipFree(d_output); + + if (UseGraphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (UseGraphs) + HIP_CHECK(hipStreamDestroy(stream)); } -TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) +TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScan) +{ + testLargeIndicesInclusiveScan(); +} + +TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanWithGraphs) +{ + testLargeIndicesInclusiveScan(); +} + +template +void testLargeIndicesExclusiveScan() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -928,7 +1097,12 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) using OutputIterator = single_index_iterator; const bool debug_synchronous = false; - const hipStream_t stream = 0; // default + hipStream_t stream = 0; // default + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -958,16 +1132,24 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) size_t temp_storage_size_bytes; void * d_temp_storage = nullptr; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + graph = test_utils::createGraphHelper(stream); + // Get temporary array size HIP_CHECK( - rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, - input_begin, output_it, - initial_value, size, - ::rocprim::plus(), - stream, debug_synchronous - ) - ); + rocprim::exclusive_scan( + d_temp_storage, temp_storage_size_bytes, + input_begin, output_it, + initial_value, size, + ::rocprim::plus(), + stream, debug_synchronous + ) + ); + + if (UseGraphs) + graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); // temp_storage_size_bytes must be >0 ASSERT_GT(temp_storage_size_bytes, 0); @@ -976,16 +1158,23 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (UseGraphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( - rocprim::exclusive_scan( - d_temp_storage, temp_storage_size_bytes, - input_begin, output_it, - initial_value, size, - ::rocprim::plus(), - stream, debug_synchronous - ) - ); + rocprim::exclusive_scan( + d_temp_storage, temp_storage_size_bytes, + input_begin, output_it, + initial_value, size, + ::rocprim::plus(), + stream, debug_synchronous + ) + ); + + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -999,7 +1188,7 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) const T multiplicand_2 = 2 * (*input_begin) + size - 2; const T product = (multiplicand_1 % 2 == 0) ? multiplicand_1 / 2 * multiplicand_2 - : multiplicand_1 * (multiplicand_2 / 2); + : multiplicand_1 * (multiplicand_2 / 2); const T expected_output = initial_value + product; @@ -1007,8 +1196,24 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) hipFree(d_temp_storage); hipFree(d_output); + + if (UseGraphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (UseGraphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScan) +{ + testLargeIndicesExclusiveScan(); +} + +TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanWithGraphs) +{ + testLargeIndicesExclusiveScan(); } /// \brief This iterator keeps track of the current index. 
Upon dereference, a \p CheckValue object @@ -1145,7 +1350,7 @@ using check_run_exclusive_iterator /// \p brief Provides a skeleton to both the inclusive and exclusive scan large indices tests. /// The call to the appropriate scan function must be implemented in \p scan_by_key_fun. -template +template void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) { const int device_id = test_common_utils::obtain_device_from_ctest(); @@ -1153,7 +1358,12 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) HIP_CHECK(hipSetDevice(device_id)); constexpr bool debug_synchronous = false; - constexpr hipStream_t stream = 0; + hipStream_t stream = 0; + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const int seed_value = rand(); SCOPED_TRACE(testing::Message() << "with seed = " << seed_value); @@ -1173,6 +1383,11 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) { return value / run_length; }); const auto values_input = rocprim::counting_iterator(0); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + graph = test_utils::createGraphHelper(stream); + size_t temp_storage_size_bytes; void* d_temp_storage = nullptr; HIP_CHECK(scan_by_key_fun(d_temp_storage, @@ -1185,8 +1400,15 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) stream, debug_synchronous, seed_value)); + if (UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temp_storage_size_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + + if (UseGraphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(scan_by_key_fun(d_temp_storage, temp_storage_size_bytes, keys_input, @@ -1197,6 +1419,10 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) stream, debug_synchronous, seed_value)); + + if 
(UseGraphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipGetLastError()); unsigned int incorrect_flag; @@ -1209,9 +1435,16 @@ void large_indices_scan_by_key_test(ScanByKeyFun scan_by_key_fun) HIP_CHECK(hipFree(d_temp_storage)); HIP_CHECK(hipFree(d_incorrect_flag)); + + if (UseGraphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } -TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanByKey) +template +void testLargeIndicesInclusiveScanByKey() { auto inclusive_scan_by_key = [](void* d_temp_storage, size_t& temp_storage_size_bytes, @@ -1225,7 +1458,7 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanByKey) int /*seed_value*/) -> hipError_t { const check_run_inclusive_iterator output_it( - rocprim::make_tuple(run_length, d_incorrect_flag)); + rocprim::make_tuple(run_length, d_incorrect_flag)); return rocprim::inclusive_scan_by_key(d_temp_storage, temp_storage_size_bytes, @@ -1238,10 +1471,21 @@ TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanByKey) stream, debug_synchronous); }; - large_indices_scan_by_key_test(inclusive_scan_by_key); + large_indices_scan_by_key_test(inclusive_scan_by_key); } -TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanByKey) +TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanByKey) +{ + testLargeIndicesInclusiveScanByKey(); +} + +TEST(RocprimDeviceScanTests, LargeIndicesInclusiveScanByKeyWithGraphs) +{ + testLargeIndicesInclusiveScanByKey(); +} + +template +void testLargeIndicesExclusiveScanByKey() { auto exclusive_scan_by_key = [](void* d_temp_storage, size_t& temp_storage_size_bytes, @@ -1256,7 +1500,7 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanByKey) { const size_t initial_value = test_utils::get_random_value(0, 10000, seed_value); const check_run_exclusive_iterator output_it( - rocprim::make_tuple(run_length, initial_value, d_incorrect_flag)); + rocprim::make_tuple(run_length, initial_value, 
d_incorrect_flag)); return rocprim::exclusive_scan_by_key(d_temp_storage, temp_storage_size_bytes, keys_input, @@ -1269,7 +1513,17 @@ TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanByKey) stream, debug_synchronous); }; - large_indices_scan_by_key_test(exclusive_scan_by_key); + large_indices_scan_by_key_test(exclusive_scan_by_key); +} + +TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanByKey) +{ + testLargeIndicesExclusiveScanByKey(); +} + +TEST(RocprimDeviceScanTests, LargeIndicesExclusiveScanByKeyWithGraphs) +{ + testLargeIndicesExclusiveScanByKey(); } using RocprimDeviceScanFutureTestsParams @@ -1278,7 +1532,8 @@ using RocprimDeviceScanFutureTestsParams DeviceScanParams>, DeviceScanParams, true>, DeviceScanParams>, - DeviceScanParams>>; + DeviceScanParams>, + DeviceScanParams, false, default_config_helper, true>>; template class RocprimDeviceScanFutureTests : public RocprimDeviceScanTests @@ -1326,7 +1581,13 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) << std::endl; break; } - const hipStream_t stream = 0; // default + + hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -1372,6 +1633,11 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) auto input_iterator = rocprim::make_transform_iterator( d_input, [] (T in) { return static_cast(in); }); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; char* d_temp_storage = nullptr; @@ -1385,24 +1651,38 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) stream, debug_synchronous)); + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + // temp_storage_size_bytes must 
be >0 ASSERT_GT(temp_storage_size_bytes, 0); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + size_t temp_storage_reduce = 0; HIP_CHECK(rocprim::reduce( - nullptr, temp_storage_reduce, d_future_input, d_initial_value, 2048)); + nullptr, temp_storage_reduce, d_future_input, d_initial_value, 2048, rocprim::plus(), stream)); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); // allocate temporary storage HIP_CHECK(test_common_utils::hipMallocHelper( &d_temp_storage, temp_storage_size_bytes + temp_storage_reduce)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Fill initial value on the device HIP_CHECK(rocprim::reduce(d_temp_storage + temp_storage_size_bytes, temp_storage_reduce, d_future_input, d_initial_value, - 2048)); + 2048, + rocprim::plus(), + stream)); // Run HIP_CHECK(rocprim::exclusive_scan( @@ -1414,7 +1694,9 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) stream, debug_synchronous)); HIP_CHECK(hipGetLastError()); - HIP_CHECK(hipDeviceSynchronize()); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); // Copy output to host HIP_CHECK(hipMemcpy( @@ -1429,6 +1711,12 @@ TYPED_TEST(RocprimDeviceScanFutureTests, ExclusiveScan) hipFree(d_future_input); hipFree(d_initial_value); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } } diff --git a/test/rocprim/test_device_segmented_reduce.cpp b/test/rocprim/test_device_segmented_reduce.cpp index 4712b6dac..2f2404435 100644 --- a/test/rocprim/test_device_segmented_reduce.cpp +++ b/test/rocprim/test_device_segmented_reduce.cpp @@ -39,7 +39,8 @@ template + bool UseDefaultConfig = false, + bool UseGraphs = false> struct SegmentedReduceParams { 
using input_type = Input; @@ -51,6 +52,7 @@ struct SegmentedReduceParams static constexpr bool use_identity_iterator = UseIdentityIterator; static constexpr bra algo = Algo; static constexpr bool use_default_config = UseDefaultConfig; + static constexpr bool use_graphs = UseGraphs; }; // clang-format off @@ -118,7 +120,9 @@ typedef ::testing::Types< SegmentedReduceParamsList(unsigned char, unsigned int, plus, 0, 0, 1000, false), SegmentedReduceParamsList(unsigned char, long long, plus, 10, 3000, 4000, true), SegmentedReduceParamsList(half, float, plus, 0, 10, 300, false), - SegmentedReduceParamsList(bfloat16, float, plus, 0, 10, 300, false)> + SegmentedReduceParamsList(bfloat16, float, plus, 0, 10, 300, false), + // Test with graphs + SegmentedReduceParams, 0, 0, 1000, false, bra::default_algorithm, false, true>> Params; #undef plus @@ -165,6 +169,11 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) SCOPED_TRACE(testing::Message() << "with size = " << size); hipStream_t stream = 0; // default + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } // Generate data and calculate expected results std::vector aggregates_expected; @@ -232,6 +241,11 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) HIP_CHECK(test_common_utils::hipMallocHelper(&d_aggregates_output, segments_count * sizeof(output_type))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; HIP_CHECK(rocprim::segmented_reduce(nullptr, @@ -246,12 +260,18 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void* d_temporary_storage; HIP_CHECK( 
test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::segmented_reduce( d_temporary_storage, temporary_storage_bytes, @@ -265,6 +285,9 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) stream, debug_synchronous)); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipFree(d_temporary_storage)); std::vector aggregates_output(segments_count); @@ -277,6 +300,12 @@ TYPED_TEST(RocprimDeviceSegmentedReduce, Reduce) HIP_CHECK(hipFree(d_offsets)); HIP_CHECK(hipFree(d_aggregates_output)); + if (TestFixture::params::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } + ASSERT_NO_FATAL_FAILURE( test_utils::assert_near(aggregates_output, aggregates_expected, precision)); } diff --git a/test/rocprim/test_device_segmented_scan.cpp b/test/rocprim/test_device_segmented_scan.cpp index 8fc62ffb5..f8a06c247 100644 --- a/test/rocprim/test_device_segmented_scan.cpp +++ b/test/rocprim/test_device_segmented_scan.cpp @@ -40,7 +40,8 @@ template< // Tests output iterator with void value_type (OutputIterator concept) // Segmented scan primitives which use head flags do not support this kind // of output iterators. 
- bool UseIdentityIterator = false + bool UseIdentityIterator = false, + bool UseGraphs = false > struct params { @@ -51,6 +52,7 @@ struct params static constexpr unsigned int min_segment_length = MinSegmentLength; static constexpr unsigned int max_segment_length = MaxSegmentLength; static constexpr bool use_identity_iterator = UseIdentityIterator; + static constexpr bool use_graphs = UseGraphs; }; template @@ -76,7 +78,8 @@ typedef ::testing::Types< params, 0, 1000, 30000>, params, 0, 10, 200, true>, params, 0, 1000, 30000>, - params, 10, 3000, 4000>> + params, 10, 3000, 4000>, + params, 0, 0, 1000, false, true>> Params; TYPED_TEST_SUITE(RocprimDeviceSegmentedScan, Params); @@ -107,6 +110,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) ); hipStream_t stream = 0; // default stream + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -185,6 +193,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) ); HIP_CHECK(hipDeviceSynchronize()); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; HIP_CHECK( rocprim::segmented_inclusive_scan( @@ -198,10 +211,16 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) ) ); + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temporary_storage_bytes, 0); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK( rocprim::segmented_inclusive_scan( d_temporary_storage, temporary_storage_bytes, @@ -213,6 +232,10 
@@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) stream, debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); std::vector values_output(size); @@ -232,9 +255,14 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScan) HIP_CHECK(hipFree(d_values_input)); HIP_CHECK(hipFree(d_offsets)); HIP_CHECK(hipFree(d_values_output)); + + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) @@ -265,6 +293,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) ); hipStream_t stream = 0; // default stream + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -343,6 +376,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) ); HIP_CHECK(hipDeviceSynchronize()); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temporary_storage_bytes; HIP_CHECK( rocprim::segmented_exclusive_scan( @@ -355,12 +393,19 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) stream, debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipDeviceSynchronize()); ASSERT_GT(temporary_storage_bytes, 0); void * d_temporary_storage; HIP_CHECK(test_common_utils::hipMallocHelper(&d_temporary_storage, temporary_storage_bytes)); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK( 
rocprim::segmented_exclusive_scan( d_temporary_storage, temporary_storage_bytes, @@ -372,6 +417,10 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) stream, debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); std::vector values_output(size); @@ -391,9 +440,14 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScan) HIP_CHECK(hipFree(d_values_input)); HIP_CHECK(hipFree(d_offsets)); HIP_CHECK(hipFree(d_values_output)); + + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) @@ -415,6 +469,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) scan_op_type scan_op; hipStream_t stream = 0; // default stream + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -501,6 +560,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) expected.begin(), scan_op); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -512,6 +576,10 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -522,6 +590,9 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) 
HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::segmented_inclusive_scan( @@ -531,6 +602,10 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if output values are as expected @@ -550,9 +625,14 @@ TYPED_TEST(RocprimDeviceSegmentedScan, InclusiveScanUsingHeadFlags) HIP_CHECK(hipFree(d_input)); HIP_CHECK(hipFree(d_flags)); HIP_CHECK(hipFree(d_output)); + + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) @@ -576,6 +656,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) scan_op_type scan_op; hipStream_t stream = 0; // default stream + if (TestFixture::params::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -663,6 +748,11 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) scan_op, init); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::params::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -673,6 +763,10 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) input.size(), scan_op, stream, debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = 
test_utils::endCaptureGraphHelper(graph, stream, true, true); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -683,6 +777,9 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::params::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::segmented_exclusive_scan( @@ -691,6 +788,10 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) input.size(), scan_op, stream, debug_synchronous ) ); + + if (TestFixture::params::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if output values are as expected @@ -702,6 +803,10 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) hipMemcpyDeviceToHost ) ); + + if (TestFixture::params::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipDeviceSynchronize()); ASSERT_NO_FATAL_FAILURE(test_utils::assert_near(output, expected, precision)); @@ -713,4 +818,6 @@ TYPED_TEST(RocprimDeviceSegmentedScan, ExclusiveScanUsingHeadFlags) } } + if (TestFixture::params::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } diff --git a/test/rocprim/test_device_select.cpp b/test/rocprim/test_device_select.cpp index b1f36bd0a..b903d5b4d 100644 --- a/test/rocprim/test_device_select.cpp +++ b/test/rocprim/test_device_select.cpp @@ -37,7 +37,8 @@ template< class InputType, class OutputType = InputType, class FlagType = unsigned int, - bool UseIdentityIterator = false + bool UseIdentityIterator = false, + bool UseGraphs = false > struct DeviceSelectParams { @@ -45,6 +46,7 @@ struct DeviceSelectParams using output_type = OutputType; using flag_type = FlagType; static constexpr bool use_identity_iterator = UseIdentityIterator; + static constexpr bool use_graphs = 
UseGraphs; }; template @@ -56,6 +58,7 @@ class RocprimDeviceSelectTests : public ::testing::Test using flag_type = typename Params::flag_type; const bool debug_synchronous = false; static constexpr bool use_identity_iterator = Params::use_identity_iterator; + static constexpr bool use_graphs = Params::use_graphs; }; typedef ::testing::Types< @@ -66,7 +69,8 @@ typedef ::testing::Types< DeviceSelectParams, DeviceSelectParams, DeviceSelectParams, - DeviceSelectParams, test_utils::custom_test_type, int, true> + DeviceSelectParams, test_utils::custom_test_type, int, true>, + DeviceSelectParams > RocprimDeviceSelectTestsParams; TYPED_TEST_SUITE(RocprimDeviceSelectTests, RocprimDeviceSelectTestsParams); @@ -81,9 +85,13 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) using U = typename TestFixture::output_type; using F = typename TestFixture::flag_type; static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; - const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -133,6 +141,11 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) } } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -146,9 +159,13 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) d_selected_count_output, input.size(), stream, - debug_synchronous + TestFixture::debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -159,6 +176,9 @@ 
TYPED_TEST(RocprimDeviceSelectTests, Flagged) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::select( @@ -170,9 +190,13 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) d_selected_count_output, input.size(), stream, - debug_synchronous + TestFixture::debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -204,11 +228,17 @@ TYPED_TEST(RocprimDeviceSelectTests, Flagged) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } + template struct select_op { @@ -231,6 +261,11 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -270,6 +305,11 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) } } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -286,6 +326,10 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + 
HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -296,6 +340,9 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::select( @@ -310,6 +357,10 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -340,9 +391,14 @@ TYPED_TEST(RocprimDeviceSelectTests, SelectOp) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } std::vector get_discontinuity_probabilities() @@ -369,6 +425,11 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -423,6 +484,11 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) } } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -439,6 +505,10 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, 
true, false); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -449,6 +519,9 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::unique( @@ -463,6 +536,10 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -493,9 +570,15 @@ TYPED_TEST(RocprimDeviceSelectTests, Unique) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } // The operator must be only called, when we have valid element in a block @@ -520,7 +603,8 @@ struct element_equal_operator } }; -TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) +template +void testUniqueGuardedOperator() { int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -534,6 +618,11 @@ TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) const bool debug_synchronous = false; hipStream_t stream = 0; // default stream + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -601,6 +690,11 @@ TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) } } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + graph = 
test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -617,6 +711,10 @@ TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) debug_synchronous ) ); + + if (UseGraphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -627,6 +725,9 @@ TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) HIP_CHECK(hipMalloc(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (UseGraphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::unique( @@ -641,6 +742,10 @@ TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) debug_synchronous ) ); + + if (UseGraphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -672,9 +777,25 @@ TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (UseGraphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } } + + if (UseGraphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST(RocprimDeviceSelectTests, UniqueGuardedOperator) +{ + testUniqueGuardedOperator(); +} + +TEST(RocprimDeviceSelectTests, UniqueGuardedOperatorWithGraphs) +{ + testUniqueGuardedOperator(); } // Params for tests @@ -683,7 +804,8 @@ template< typename ValueType, typename OutputKeyType = KeyType, typename OutputValueType = ValueType, - bool UseIdentityIterator = false + bool UseIdentityIterator = false, + bool UseGraphs = false > struct DeviceUniqueByKeyParams { @@ -692,6 +814,7 @@ struct DeviceUniqueByKeyParams using output_key_type = OutputKeyType; using output_value_type = OutputValueType; static constexpr bool use_identity_iterator = UseIdentityIterator; + static constexpr bool 
use_graphs = UseGraphs; }; template @@ -704,6 +827,7 @@ class RocprimDeviceUniqueByKeyTests : public ::testing::Test using output_value_type = typename Params::output_value_type; const bool debug_synchronous = false; static constexpr bool use_identity_iterator = Params::use_identity_iterator; + const bool use_graphs = Params::use_graphs; }; typedef ::testing::Types< @@ -714,7 +838,8 @@ typedef ::testing::Types< DeviceUniqueByKeyParams, DeviceUniqueByKeyParams, DeviceUniqueByKeyParams, - DeviceUniqueByKeyParams, test_utils::custom_test_type> + DeviceUniqueByKeyParams, test_utils::custom_test_type>, + DeviceUniqueByKeyParams > RocprimDeviceUniqueByKeyTestParams; TYPED_TEST_SUITE(RocprimDeviceUniqueByKeyTests, RocprimDeviceUniqueByKeyTestParams); @@ -737,6 +862,11 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) const bool debug_synchronous = TestFixture::debug_synchronous; hipStream_t stream = 0; // default stream + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -808,6 +938,11 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) } } + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // temp storage size_t temp_storage_size_bytes; // Get size of d_temp_storage @@ -826,6 +961,10 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -836,6 +975,9 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); 
HIP_CHECK(hipDeviceSynchronize()); + if (TestFixture::use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::unique_by_key( @@ -852,6 +994,10 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -893,18 +1039,28 @@ TYPED_TEST(RocprimDeviceUniqueByKeyTests, UniqueByKey) hipFree(d_values_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (TestFixture::use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } } + + if (TestFixture::use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } -class RocprimDeviceSelectLargeInputTests : public ::testing::TestWithParam { +class RocprimDeviceSelectLargeInputTests : public ::testing::TestWithParam> { public: const bool debug_synchronous = false; }; -INSTANTIATE_TEST_SUITE_P(RocprimDeviceSelectLargeInputFlaggedTest, RocprimDeviceSelectLargeInputTests, ::testing::Values( - 2048, 9643, 32768, 38713 +INSTANTIATE_TEST_SUITE_P(RocprimDeviceSelectLargeInputFlaggedTest, RocprimDeviceSelectLargeInputTests, + ::testing::Values(std::make_pair(2048, false), // params: flag_selector/segment_length, use_graphs + std::make_pair(9643, false), + std::make_pair(32768, false), + std::make_pair(38713, false), + std::make_pair(38713, true) )); TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) @@ -913,13 +1069,20 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - unsigned int flag_selector = GetParam(); + auto param = GetParam(); + unsigned int flag_selector = std::get<0>(param); + const bool use_graphs = std::get<1>(param); using InputIterator = typename rocprim::counting_iterator; 
const bool debug_synchronous = RocprimDeviceSelectLargeInputTests::debug_synchronous; hipStream_t stream = 0; // default stream + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(auto size : test_utils::get_large_sizes(0)) { @@ -960,6 +1123,11 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) size_t temp_storage_size_bytes; void *d_temp_storage = nullptr; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + // Get size of d_temp_storage HIP_CHECK( rocprim::select( @@ -974,6 +1142,10 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) debug_synchronous ) ); + + if (use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // temp_storage_size_bytes must be >0 @@ -983,6 +1155,9 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); HIP_CHECK(hipDeviceSynchronize()); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + // Run HIP_CHECK( rocprim::select( @@ -997,6 +1172,10 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) debug_synchronous ) ); + + if (use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipDeviceSynchronize()); // Check if number of selected value is as expected @@ -1025,20 +1204,34 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputFlagged) hipFree(d_output); hipFree(d_selected_count_output); hipFree(d_temp_storage); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) { static 
constexpr bool debug_synchronous = false; - static constexpr hipStream_t stream = 0; + + auto param = GetParam(); + const unsigned int segment_length = std::get<0>(param); + const bool use_graphs = std::get<1>(param); + + hipStream_t stream = 0; + if (use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); HIP_CHECK(hipSetDevice(device_id)); - const unsigned int segment_length = GetParam(); - for(const auto size : test_utils::get_large_sizes(0)) { // otherwise test is too long @@ -1061,6 +1254,11 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) HIP_CHECK(test_common_utils::hipMallocHelper(&d_unique_count_output, sizeof(*d_unique_count_output))); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (use_graphs) + graph = test_utils::createGraphHelper(stream); + size_t temp_storage_size_bytes{}; void* d_temp_storage{}; HIP_CHECK(rocprim::unique(d_temp_storage, @@ -1072,9 +1270,16 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) rocprim::equal_to{}, stream, debug_synchronous)); + + if (use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + ASSERT_GT(temp_storage_size_bytes, 0); HIP_CHECK(test_common_utils::hipMallocHelper(&d_temp_storage, temp_storage_size_bytes)); + if (use_graphs) + test_utils::resetGraphHelper(graph, graph_instance, stream); + HIP_CHECK(rocprim::unique(d_temp_storage, temp_storage_size_bytes, input_it, @@ -1085,6 +1290,9 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) stream, debug_synchronous)); + if (use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, true); + size_t unique_count_output{}; HIP_CHECK(hipMemcpyWithStream(&unique_count_output, 
d_unique_count_output, @@ -1106,5 +1314,11 @@ TEST_P(RocprimDeviceSelectLargeInputTests, LargeInputUnique) HIP_CHECK(hipFree(d_output)); HIP_CHECK(hipFree(d_unique_count_output)); HIP_CHECK(hipFree(d_temp_storage)); + + if (use_graphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } + + if (use_graphs) + HIP_CHECK(hipStreamDestroy(stream)); } diff --git a/test/rocprim/test_device_transform.cpp b/test/rocprim/test_device_transform.cpp index b8ae28f3d..7fee5604e 100644 --- a/test/rocprim/test_device_transform.cpp +++ b/test/rocprim/test_device_transform.cpp @@ -35,7 +35,8 @@ template< class InputType, class OutputType = InputType, bool UseIdentityIterator = false, - unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT + unsigned int SizeLimit = ROCPRIM_GRID_SIZE_LIMIT, + bool UseGraphs = false > struct DeviceTransformParams { @@ -43,6 +44,7 @@ struct DeviceTransformParams using output_type = OutputType; static constexpr bool use_identity_iterator = UseIdentityIterator; static constexpr size_t size_limit = SizeLimit; + static constexpr bool use_graphs = UseGraphs; }; // --------------------------------------------------------- @@ -58,6 +60,7 @@ class RocprimDeviceTransformTests : public ::testing::Test static constexpr bool use_identity_iterator = Params::use_identity_iterator; static constexpr bool debug_synchronous = false; static constexpr size_t size_limit = Params::size_limit; + static constexpr bool use_graphs = Params::use_graphs; }; using custom_short2 = test_utils::custom_test_type; @@ -79,7 +82,8 @@ typedef ::testing::Types< DeviceTransformParams, DeviceTransformParams, DeviceTransformParams, - DeviceTransformParams + DeviceTransformParams, + DeviceTransformParams > RocprimDeviceTransformTestsParams; template @@ -116,7 +120,6 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) using T = typename TestFixture::input_type; using U = typename TestFixture::output_type; static constexpr bool use_identity_iterator = TestFixture::use_identity_iterator; - 
const bool debug_synchronous = TestFixture::debug_synchronous; using Config = size_limit_config_t; for (size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) @@ -127,6 +130,11 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) for(auto size : test_utils::get_sizes(seed_value)) { hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << size); @@ -151,14 +159,23 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) std::vector expected(input.size()); std::transform(input.begin(), input.end(), expected.begin(), transform()); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // Run HIP_CHECK( rocprim::transform( d_input, test_utils::wrap_in_identity_iterator(d_output), - input.size(), transform(), stream, debug_synchronous + input.size(), transform(), stream, TestFixture::debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -178,9 +195,14 @@ TYPED_TEST(RocprimDeviceTransformTests, Transform) hipFree(d_input); hipFree(d_output); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } template @@ -214,6 +236,11 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) for(auto size : test_utils::get_sizes(seed_value)) { hipStream_t stream = 0; // default + if (TestFixture::use_graphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } SCOPED_TRACE(testing::Message() << "with size = " << 
size); @@ -251,6 +278,11 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) expected.begin(), binary_transform() ); + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (TestFixture::use_graphs) + graph = test_utils::createGraphHelper(stream); + // Run HIP_CHECK( rocprim::transform( @@ -259,6 +291,10 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) input1.size(), binary_transform(), stream, debug_synchronous ) ); + + if (TestFixture::use_graphs) + graph_instance = graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -279,12 +315,18 @@ TYPED_TEST(RocprimDeviceTransformTests, BinaryTransform) hipFree(d_input1); hipFree(d_input2); hipFree(d_output); + + if (TestFixture::use_graphs) + { + test_utils::cleanupGraphHelper(graph, graph_instance); + HIP_CHECK(hipStreamDestroy(stream)); + } } } - } -TEST(RocprimDeviceTransformTests, LargeIndices) +template +void testLargeIndices() { const int device_id = test_common_utils::obtain_device_from_ctest(); SCOPED_TRACE(testing::Message() << "with device_id = " << device_id); @@ -295,7 +337,12 @@ TEST(RocprimDeviceTransformTests, LargeIndices) using OutputIterator = rocprim::discard_iterator; const bool debug_synchronous = false; - const hipStream_t stream = 0; // default + hipStream_t stream = 0; // default + if (UseGraphs) + { + // Default stream does not support hipGraph stream capture, so create one + HIP_CHECK(hipStreamCreateWithFlags(&stream, hipStreamNonBlocking)); + } for(size_t seed_index = 0; seed_index < random_seeds_count + seed_size; seed_index++) { @@ -336,9 +383,18 @@ TEST(RocprimDeviceTransformTests, LargeIndices) return 0; }; + hipGraph_t graph; + hipGraphExec_t graph_instance; + if (UseGraphs) + graph = test_utils::createGraphHelper(stream); + // Run HIP_CHECK( rocprim::transform(input, output, size, flag_expected, stream, debug_synchronous)); + + if (UseGraphs) + graph_instance = 
graph_instance = test_utils::endCaptureGraphHelper(graph, stream, true, false); + HIP_CHECK(hipGetLastError()); HIP_CHECK(hipDeviceSynchronize()); @@ -348,8 +404,24 @@ TEST(RocprimDeviceTransformTests, LargeIndices) ASSERT_TRUE(flags[0]); ASSERT_TRUE(flags[1]); + + HIP_CHECK(hipFree(d_flag)); - hipFree(d_flag); + if (UseGraphs) + test_utils::cleanupGraphHelper(graph, graph_instance); } } + + if (UseGraphs) + HIP_CHECK(hipStreamDestroy(stream)); +} + +TEST(RocprimDeviceTransformTests, LargeIndices) +{ + testLargeIndices(); +} + +TEST(RocprimDeviceTransformTests, LargeIndicesWithGraphs) +{ + testLargeIndices(); } diff --git a/test/rocprim/test_utils.hpp b/test/rocprim/test_utils.hpp index 573b5cbb9..13dea15b6 100644 --- a/test/rocprim/test_utils.hpp +++ b/test/rocprim/test_utils.hpp @@ -40,6 +40,7 @@ #include "test_utils_custom_test_types.hpp" #include "test_utils_data_generation.hpp" #include "test_utils_assertions.hpp" +#include "test_utils_hipgraphs.hpp" // Helper macros to disable warnings in clang #ifdef __clang__ diff --git a/test/rocprim/test_utils_hipgraphs.hpp b/test/rocprim/test_utils_hipgraphs.hpp new file mode 100644 index 000000000..a108256ba --- /dev/null +++ b/test/rocprim/test_utils_hipgraphs.hpp @@ -0,0 +1,88 @@ +// Copyright (c) 2021-2023 Advanced Micro Devices, Inc. All rights reserved. +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. 
+//
+// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+// THE SOFTWARE.
+
+#ifndef ROCPRIM_TEST_UTILS_HIPGRAPHS_HPP
+#define ROCPRIM_TEST_UTILS_HIPGRAPHS_HPP
+
+// Helper functions for testing with hipGraph stream capture.
+// Note: graphs will not work on the default stream.
+namespace test_utils
+{
+
+// Prepares a graph handle for a test.
+//
+// beginCapture == true:  starts global-mode capture on `stream` and returns a
+//                        null handle. The graph itself is produced later by
+//                        endCaptureGraphHelper(); hipStreamEndCapture() writes a
+//                        new graph into its output argument, so a graph created
+//                        here would be overwritten and never destroyed.
+// beginCapture == false: creates and returns an empty graph via
+//                        hipGraphCreate(); `stream` is not touched.
+inline hipGraph_t createGraphHelper(hipStream_t& stream, const bool beginCapture=true)
+{
+    hipGraph_t graph{};
+
+    if (beginCapture)
+    {
+        HIP_CHECK(hipStreamBeginCapture(stream, hipStreamCaptureModeGlobal));
+    }
+    else
+    {
+        HIP_CHECK(hipGraphCreate(&graph, 0));
+    }
+
+    return graph;
+}
+
+// Destroys `graph` and then its executable `instance`.
+// Both handles are passed straight to the HIP destroy calls, so both must be
+// valid when this is called.
+inline void cleanupGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance)
+{
+    HIP_CHECK(hipGraphDestroy(graph));
+    HIP_CHECK(hipGraphExecDestroy(instance));
+}
+
+// Destroys the current `graph`/`instance` pair and re-arms `graph` through
+// createGraphHelper(), optionally beginning a new capture on `stream`.
+inline void resetGraphHelper(hipGraph_t& graph, hipGraphExec_t& instance, hipStream_t& stream, const bool beginCapture=true)
+{
+    // Destroy the old graph and instance
+    cleanupGraphHelper(graph, instance);
+
+    // Prepare the handle again and optionally begin capture
+    graph = createGraphHelper(stream, beginCapture);
+}
+
+// Ends the capture on `stream`, storing the captured graph in `graph`, and
+// instantiates it.
+//
+// launchGraph: when true, the new instance is launched on `stream`.
+// sync:        when true, `stream` is synchronized before returning.
+// Returns the executable graph instance.
+inline hipGraphExec_t endCaptureGraphHelper(hipGraph_t& graph, hipStream_t& stream, const bool launchGraph=false, const bool sync=false)
+{
+    // End the capture
+    HIP_CHECK(hipStreamEndCapture(stream, &graph));
+
+    // Instantiate the graph
+    hipGraphExec_t instance{};
+    HIP_CHECK(hipGraphInstantiate(&instance, graph, nullptr, nullptr, 0));
+
+    // Optionally launch the graph
+    if (launchGraph)
+    {
+        HIP_CHECK(hipGraphLaunch(instance, stream));
+    }
+
+    // Optionally synchronize the stream when we're done
+    if (sync)
+    {
+        HIP_CHECK(hipStreamSynchronize(stream));
+    }
+
+    return instance;
+}
+
+// Launches an already instantiated graph on `stream`.
+//
+// sync: when true, `stream` is synchronized after the launch.
+inline void launchGraphHelper(hipGraphExec_t& instance, hipStream_t& stream, const bool sync=false)
+{
+    HIP_CHECK(hipGraphLaunch(instance, stream));
+
+    // Optionally sync after the launch
+    if (sync)
+    {
+        HIP_CHECK(hipStreamSynchronize(stream));
+    }
+}
+
+} // end namespace test_utils
+
+#endif //ROCPRIM_TEST_UTILS_HIPGRAPHS_HPP