Skip to content

Commit

Permalink
Prep
Browse files Browse the repository at this point in the history
  • Loading branch information
shwina committed Jan 22, 2025
1 parent 1048b9f commit 811ba03
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 6 deletions.
8 changes: 7 additions & 1 deletion c/parallel/src/scan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -367,7 +367,13 @@ struct scan_kernel_source
return scan_tile_state(build.accumulator_type);
}

cudaError_t GetTileStateAllocationSize(int num_tiles, size_t& temp_storage_bytes)
cudaError_t TileStateInit(scan_tile_state& tile_state, int num_tiles, void* d_temp_storage, size_t temp_storage_bytes)
{
return tile_state.Init(num_tiles, d_temp_storage, temp_storage_bytes);
}

cudaError_t
GetTileStateAllocationSize(scan_tile_state const& /*tile_state*/, int num_tiles, size_t& temp_storage_bytes)
{
return scan_tile_state::AllocationSize(num_tiles, temp_storage_bytes, GetWordSize());
}
Expand Down
19 changes: 14 additions & 5 deletions cub/cub/device/dispatch/dispatch_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,16 @@ struct DeviceScanKernelSource
return ScanTileStateT();
}

CUB_RUNTIME_FUNCTION cudaError_t GetTileStateAllocationSize(int num_tiles, size_t& temp_storage_bytes)
CUB_RUNTIME_FUNCTION cudaError_t
TileStateInit(ScanTileStateT& tile_state, int num_tiles, void* d_temp_storage, size_t temp_storage_bytes)
{
return ScanTileStateT::AllocationSize(num_tiles, temp_storage_bytes);
return tile_state.Init(num_tiles, d_temp_storage, temp_storage_bytes);
}

CUB_RUNTIME_FUNCTION cudaError_t
GetTileStateAllocationSize(ScanTileStateT const& tile_state, int num_tiles, size_t& temp_storage_bytes)
{
return decltype(tile_state)::AllocationSize(num_tiles, temp_storage_bytes);
}
};

Expand Down Expand Up @@ -283,9 +290,12 @@ struct DispatchScan
int tile_size = policy.Scan().BlockThreads() * policy.Scan().ItemsPerThread();
int num_tiles = static_cast<int>(::cuda::ceil_div(num_items, tile_size));

auto tile_state = kernel_source.TileState();

// Specify temporary storage allocation requirements
size_t allocation_sizes[1];
error = CubDebug(kernel_source.GetTileStateAllocationSize(num_tiles, allocation_sizes[0]));

error = CubDebug(kernel_source.GetTileStateAllocationSize(tile_state, num_tiles, allocation_sizes[0]));
if (cudaSuccess != error)
{
break; // bytes needed for tile status descriptors
Expand Down Expand Up @@ -315,8 +325,7 @@ struct DispatchScan
}

// Construct the tile status interface
auto tile_state = kernel_source.TileState();
error = CubDebug(tile_state.Init(num_tiles, allocations[0], allocation_sizes[0]));
error = CubDebug(kernel_source.TileStateInit(tile_state, num_tiles, allocations[0], allocation_sizes[0]));
if (cudaSuccess != error)
{
break;
Expand Down

0 comments on commit 811ba03

Please sign in to comment.