Skip to content

Commit

Permalink
Introduce detail functions to allocate/initialize tile state
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashwin Srinath authored and shwina committed Feb 2, 2025
1 parent 5f1a400 commit a089269
Showing 1 changed file with 76 additions and 35 deletions.
111 changes: 76 additions & 35 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,65 @@ struct tile_state_with_memory_order
return tile_state.template LoadValid<Order>(tile_idx);
}
};

/**
 * @brief Converts a tile count into the number of tile-state entries to allocate.
 *
 * @param[in] num_tiles
 *   Number of tiles
 *
 * @return
 *   @p num_tiles plus CUB_PTX_WARP_THREADS additional entries
 */
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE constexpr int num_tiles_to_num_tile_states(int num_tiles)
{
  return num_tiles + CUB_PTX_WARP_THREADS;
}

/**
 * @brief Computes the size in bytes of the single temporary-storage blob that holds
 *        the tile status descriptors, the tile partials and the tile inclusives.
 *
 * @param[in] bytes_per_description
 *   Bytes occupied by one tile status descriptor
 *
 * @param[in] bytes_per_payload
 *   Bytes occupied by one partial value (the same size is used for one inclusive value)
 *
 * @param[in] num_tiles
 *   Number of tiles
 *
 * @return
 *   The blob size that AliasTemporaries reports for the three sub-allocations
 *
 * NOTE(review): this function is not declared constexpr.
 */
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE size_t
tile_state_allocation_size(int bytes_per_description, int bytes_per_payload, int num_tiles)
{
  // Widen to size_t before multiplying so the byte counts are not computed in
  // (overflow-prone) signed int arithmetic.
  const size_t num_tile_states = static_cast<size_t>(num_tiles_to_num_tile_states(num_tiles));
  const size_t payload_bytes   = num_tile_states * static_cast<size_t>(bytes_per_payload);

  // Storage requirements: [0] tile status descriptors, [1] partials, [2] inclusives
  size_t allocation_sizes[3] = {
    num_tile_states * static_cast<size_t>(bytes_per_description), payload_bytes, payload_bytes};

  // Called with a null storage pointer to obtain the necessary size of the blob in
  // temp_storage_bytes. The returned status is not propagated because this function
  // only returns the size.
  size_t temp_storage_bytes = 0;
  void* allocations[3]      = {};
  AliasTemporaries(nullptr, temp_storage_bytes, allocations, allocation_sizes);

  return temp_storage_bytes;
}

/**
 * @brief Partitions a single temporary-storage blob into the three tile-state
 *        sub-allocations (status descriptors, partials, inclusives).
 *
 * @param[in] bytes_per_description
 *   Bytes occupied by one tile status descriptor
 *
 * @param[in] bytes_per_payload
 *   Bytes occupied by one partial value (the same size is used for one inclusive value)
 *
 * @param[in] num_tiles
 *   Number of tiles
 *
 * @param[in] d_temp_storage
 *   Storage blob to partition
 *
 * @param[in] temp_storage_bytes
 *   Size in bytes of @p d_temp_storage
 *
 * @param[out] allocations
 *   Receives the sub-allocation pointers: [0] status descriptors, [1] partials, [2] inclusives
 *
 * @return
 *   The status returned by AliasTemporaries
 */
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t tile_state_init(
  int bytes_per_description,
  int bytes_per_payload,
  int num_tiles,
  void* d_temp_storage,
  size_t temp_storage_bytes,
  void* (&allocations)[3])
{
  // Widen to size_t before multiplying so the byte counts are not computed in
  // (overflow-prone) signed int arithmetic.
  const size_t num_tile_states = static_cast<size_t>(num_tiles_to_num_tile_states(num_tiles));
  const size_t payload_bytes   = num_tile_states * static_cast<size_t>(bytes_per_payload);

  // Storage requirements: [0] tile status descriptors, [1] partials, [2] inclusives
  size_t allocation_sizes[3] = {
    num_tile_states * static_cast<size_t>(bytes_per_description), payload_bytes, payload_bytes};

  // Compute allocation pointers into the single storage blob. The status is forwarded
  // (rather than unconditionally reporting cudaSuccess) so that callers checking the
  // result can detect a failed aliasing.
  return AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes);
}

} // namespace detail

/**
Expand Down Expand Up @@ -583,6 +642,9 @@ struct ScanTileState<T, true>
// Device storage
TxnWord* d_tile_descriptors;

static constexpr size_t description_bytes_per_tile = sizeof(TxnWord);
static constexpr size_t payload_bytes_per_tile = 0;

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanTileState()
: d_tile_descriptors(nullptr)
Expand Down Expand Up @@ -618,10 +680,11 @@ struct ScanTileState<T, true>
* @param[out] temp_storage_bytes
* Size in bytes of \p d_temp_storage allocation
*/
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static constexpr cudaError_t
AllocationSize(int num_tiles, size_t& temp_storage_bytes)
{
// bytes needed for tile status descriptors
temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TxnWord);
temp_storage_bytes =
detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles);
return cudaSuccess;
}

Expand Down Expand Up @@ -782,6 +845,9 @@ struct ScanTileState<T, false>
T* d_tile_partial;
T* d_tile_inclusive;

static constexpr size_t description_bytes_per_tile = sizeof(StatusWord);
static constexpr size_t payload_bytes_per_tile = sizeof(Uninitialized<T>);

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanTileState()
: d_tile_status(nullptr)
Expand Down Expand Up @@ -810,25 +876,12 @@ struct ScanTileState<T, false>
do
{
void* allocations[3] = {};
size_t allocation_sizes[3];

// bytes needed for tile status descriptors
allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);

// bytes needed for partials
allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);

// bytes needed for inclusives
allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);

// Compute allocation pointers into the single storage blob
error = CubDebug(AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes));

error = detail::tile_state_init(
description_bytes_per_tile, payload_bytes_per_tile, num_tiles, d_temp_storage, temp_storage_bytes, allocations);
if (cudaSuccess != error)
{
break;
}

// Alias the offsets
d_tile_status = reinterpret_cast<StatusWord*>(allocations[0]);
d_tile_partial = reinterpret_cast<T*>(allocations[1]);
Expand All @@ -847,25 +900,13 @@ struct ScanTileState<T, false>
* @param[out] temp_storage_bytes
* Size in bytes of \p d_temp_storage allocation
*/
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static constexpr cudaError_t
AllocationSize(int num_tiles, size_t& temp_storage_bytes)
{
// Specify storage allocation requirements
size_t allocation_sizes[3];

// bytes needed for tile status descriptors
allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);

// bytes needed for partials
allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);

// bytes needed for inclusives
allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);

// Set the necessary size of the blob
void* allocations[3] = {};
return CubDebug(AliasTemporaries(nullptr, temp_storage_bytes, allocations, allocation_sizes));
temp_storage_bytes =
detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles);
return cudaSuccess;
}

/**
* Initialize (from device)
*/
Expand Down

0 comments on commit a089269

Please sign in to comment.