From 811ba03ecab8fe7dcf9aca7dbb50b437e0075da2 Mon Sep 17 00:00:00 2001
From: Ashwin Srinath
Date: Wed, 22 Jan 2025 17:00:17 -0500
Subject: [PATCH] Prep

Route scan tile-state handling through the kernel source object.

Add TileStateInit() to scan_kernel_source (c/parallel) and to
DeviceScanKernelSource (cub); it forwards to tile_state.Init().
GetTileStateAllocationSize() now takes the tile state as its first
argument. Neither implementation reads that argument: both keep calling
the static AllocationSize() of their tile-state type, so the parameter
is left unnamed.

DispatchScan now creates the tile state before sizing temporary storage
and initializes it via kernel_source.TileStateInit() instead of calling
tile_state.Init() directly.

---
 c/parallel/src/scan.cu                    |  8 +++++++-
 cub/cub/device/dispatch/dispatch_scan.cuh | 19 ++++++++++++++-----
 2 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/c/parallel/src/scan.cu b/c/parallel/src/scan.cu
index 04501dd1a77..5628381e3fd 100644
--- a/c/parallel/src/scan.cu
+++ b/c/parallel/src/scan.cu
@@ -367,7 +367,13 @@ struct scan_kernel_source
     return scan_tile_state(build.accumulator_type);
   }
 
-  cudaError_t GetTileStateAllocationSize(int num_tiles, size_t& temp_storage_bytes)
+  cudaError_t TileStateInit(scan_tile_state& tile_state, int num_tiles, void* d_temp_storage, size_t temp_storage_bytes)
+  {
+    return tile_state.Init(num_tiles, d_temp_storage, temp_storage_bytes);
+  }
+
+  cudaError_t
+  GetTileStateAllocationSize(scan_tile_state const& /*tile_state*/, int num_tiles, size_t& temp_storage_bytes)
   {
     return scan_tile_state::AllocationSize(num_tiles, temp_storage_bytes, GetWordSize());
   }
diff --git a/cub/cub/device/dispatch/dispatch_scan.cuh b/cub/cub/device/dispatch/dispatch_scan.cuh
index 6746d2a43dd..983b6019dbb 100644
--- a/cub/cub/device/dispatch/dispatch_scan.cuh
+++ b/cub/cub/device/dispatch/dispatch_scan.cuh
@@ -104,9 +104,16 @@ struct DeviceScanKernelSource
     return ScanTileStateT();
   }
 
-  CUB_RUNTIME_FUNCTION cudaError_t GetTileStateAllocationSize(int num_tiles, size_t& temp_storage_bytes)
+  CUB_RUNTIME_FUNCTION cudaError_t
+  TileStateInit(ScanTileStateT& tile_state, int num_tiles, void* d_temp_storage, size_t temp_storage_bytes)
   {
-    return ScanTileStateT::AllocationSize(num_tiles, temp_storage_bytes);
+    return tile_state.Init(num_tiles, d_temp_storage, temp_storage_bytes);
+  }
+
+  CUB_RUNTIME_FUNCTION cudaError_t
+  GetTileStateAllocationSize(ScanTileStateT const& /*tile_state*/, int num_tiles, size_t& temp_storage_bytes)
+  {
+    return ScanTileStateT::AllocationSize(num_tiles, temp_storage_bytes);
   }
 };
 
@@ -283,9 +290,12 @@ struct DispatchScan
       int tile_size = policy.Scan().BlockThreads() * policy.Scan().ItemsPerThread();
       int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size));
 
+      auto tile_state = kernel_source.TileState();
+
       // Specify temporary storage allocation requirements
       size_t allocation_sizes[1];
-      error = CubDebug(kernel_source.GetTileStateAllocationSize(num_tiles, allocation_sizes[0]));
+
+      error = CubDebug(kernel_source.GetTileStateAllocationSize(tile_state, num_tiles, allocation_sizes[0]));
       if (cudaSuccess != error)
       {
         break; // bytes needed for tile status descriptors
@@ -315,8 +325,7 @@ struct DispatchScan
       }
 
       // Construct the tile status interface
-      auto tile_state = kernel_source.TileState();
-      error = CubDebug(tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]));
+      error = CubDebug(kernel_source.TileStateInit(tile_state, num_tiles, allocations[0], allocation_sizes[0]));
       if (cudaSuccess != error)
       {
         break;