Skip to content

Commit

Permalink
Consolidating AllocationSize implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
Ashwin Srinath authored and shwina committed Feb 2, 2025
1 parent 5f1a400 commit 8bbc9a6
Showing 1 changed file with 74 additions and 20 deletions.
94 changes: 74 additions & 20 deletions cub/cub/agent/single_pass_scan_operators.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,65 @@ struct tile_state_with_memory_order
return tile_state.template LoadValid<Order>(tile_idx);
}
};

/**
 * @brief Maps a tile count to the number of tile-state entries to provision.
 *
 * @param[in] num_tiles
 *   Number of tiles
 *
 * @return `num_tiles` plus `CUB_PTX_WARP_THREADS` additional entries
 */
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE constexpr int num_tiles_to_num_tile_states(int num_tiles)
{
  // Single return statement keeps the function usable as a constexpr in every dialect
  return num_tiles + CUB_PTX_WARP_THREADS;
}

/**
 * @brief Computes the size of the temporary-storage blob needed for the tile-state arrays.
 *
 * Three sub-allocations are sized -- tile status descriptors, partials and inclusives -- each
 * holding `num_tiles_to_num_tile_states(num_tiles)` entries. AliasTemporaries is then called
 * with a null storage pointer so that it writes the combined size into a local counter,
 * which is returned.
 *
 * @param[in] bytes_per_description
 *   Size in bytes of one tile status descriptor
 *
 * @param[in] bytes_per_payload
 *   Size in bytes of one partial / inclusive entry (may be 0)
 *
 * @param[in] num_tiles
 *   Number of tiles
 *
 * @return Total size in bytes as reported by AliasTemporaries
 */
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE size_t
tile_state_allocation_size(int bytes_per_description, int bytes_per_payload, int num_tiles)
{
  // Widen to size_t *before* multiplying: an int * int product can exceed INT_MAX for large
  // tile counts or payload types, which would be undefined behavior and yield a bogus size.
  const size_t num_tile_states   = static_cast<size_t>(num_tiles_to_num_tile_states(num_tiles));
  const size_t description_bytes = num_tile_states * static_cast<size_t>(bytes_per_description);
  const size_t payload_bytes     = num_tile_states * static_cast<size_t>(bytes_per_payload);

  // Specify storage allocation requirements
  size_t allocation_sizes[3];
  allocation_sizes[0] = description_bytes; // bytes needed for tile status descriptors
  allocation_sizes[1] = payload_bytes; // bytes needed for partials
  allocation_sizes[2] = payload_bytes; // bytes needed for inclusives

  // Query the necessary size of the blob (null storage pointer: size query only)
  size_t temp_storage_bytes = 0;
  void* allocations[3]      = {};
  AliasTemporaries(nullptr, temp_storage_bytes, allocations, allocation_sizes);

  return temp_storage_bytes;
}

/**
 * @brief Carves the tile-state sub-allocations out of a caller-provided temporary-storage blob.
 *
 * The same three sub-allocation sizes used for the size query are computed -- tile status
 * descriptors, partials and inclusives, each with `num_tiles_to_num_tile_states(num_tiles)`
 * entries -- and AliasTemporaries is asked to place them inside `d_temp_storage`.
 *
 * @param[in] bytes_per_description
 *   Size in bytes of one tile status descriptor
 *
 * @param[in] bytes_per_payload
 *   Size in bytes of one partial / inclusive entry (may be 0)
 *
 * @param[in] num_tiles
 *   Number of tiles
 *
 * @param[in] d_temp_storage
 *   Storage blob to alias the sub-allocations into
 *
 * @param[in] temp_storage_bytes
 *   Size in bytes of the \t d_temp_storage allocation
 *
 * @param[out] allocations
 *   Receives the three sub-allocation pointers, in the order descriptors, partials, inclusives
 *
 * @return The status reported by AliasTemporaries (previously this was discarded and
 *   cudaSuccess was returned unconditionally)
 */
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE cudaError_t tile_state_init(
  int bytes_per_description,
  int bytes_per_payload,
  int num_tiles,
  void* d_temp_storage,
  size_t temp_storage_bytes,
  void* (&allocations)[3])
{
  // Widen to size_t *before* multiplying: an int * int product can exceed INT_MAX for large
  // tile counts or payload types, which would be undefined behavior and yield a bogus size.
  const size_t num_tile_states   = static_cast<size_t>(num_tiles_to_num_tile_states(num_tiles));
  const size_t description_bytes = num_tile_states * static_cast<size_t>(bytes_per_description);
  const size_t payload_bytes     = num_tile_states * static_cast<size_t>(bytes_per_payload);

  // Specify storage allocation requirements
  size_t allocation_sizes[3];
  allocation_sizes[0] = description_bytes; // bytes needed for tile status descriptors
  allocation_sizes[1] = payload_bytes; // bytes needed for partials
  allocation_sizes[2] = payload_bytes; // bytes needed for inclusives

  // Alias the sub-allocations into the blob and hand any failure back to the caller, so that
  // `allocations` is never consumed after an unsuccessful aliasing attempt.
  return AliasTemporaries(d_temp_storage, temp_storage_bytes, allocations, allocation_sizes);
}

} // namespace detail

/**
Expand Down Expand Up @@ -583,6 +642,9 @@ struct ScanTileState<T, true>
// Device storage
TxnWord* d_tile_descriptors;

static constexpr size_t description_bytes_per_tile = sizeof(TxnWord);
static constexpr size_t payload_bytes_per_tile = 0;

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanTileState()
: d_tile_descriptors(nullptr)
Expand Down Expand Up @@ -618,10 +680,11 @@ struct ScanTileState<T, true>
* @param[out] temp_storage_bytes
* Size in bytes of \t d_temp_storage allocation
*/
// NOTE(review): this span appears to be a rendered diff hunk without +/- markers, so two
// versions of AllocationSize are interleaved. The first signature and the
// TILE_STATUS_PADDING assignment look like the replaced lines; the `constexpr` signature and
// the detail::tile_state_allocation_size call look like their replacements. Confirm against
// the repository before treating this text as compilable code.
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static constexpr cudaError_t
AllocationSize(int num_tiles, size_t& temp_storage_bytes)
{
// bytes needed for tile status descriptors
temp_storage_bytes = (num_tiles + TILE_STATUS_PADDING) * sizeof(TxnWord);
// Delegates the size computation to the shared helper, passing this specialization's
// per-tile descriptor size (sizeof(TxnWord)) and payload size (0).
temp_storage_bytes =
detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles);
// NOTE(review): the helper is not declared constexpr at its definition in this file, while
// this signature is -- verify the function is never required in a constant expression.
return cudaSuccess;
}

Expand Down Expand Up @@ -782,6 +845,9 @@ struct ScanTileState<T, false>
T* d_tile_partial;
T* d_tile_inclusive;

static constexpr size_t description_bytes_per_tile = sizeof(StatusWord);
static constexpr size_t payload_bytes_per_tile = sizeof(Uninitialized<T>);

/// Constructor
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE ScanTileState()
: d_tile_status(nullptr)
Expand Down Expand Up @@ -847,25 +913,13 @@ struct ScanTileState<T, false>
* @param[out] temp_storage_bytes
* Size in bytes of \t d_temp_storage allocation
*/
// NOTE(review): this span appears to be a rendered diff hunk without +/- markers, so two
// versions of AllocationSize are interleaved. The first signature and the statements from
// the allocation_sizes declaration through the CubDebug(AliasTemporaries(...)) return look
// like the replaced lines; the `constexpr` signature, the
// detail::tile_state_allocation_size call and `return cudaSuccess;` look like their
// replacements. Confirm against the repository before treating this text as compilable code.
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static cudaError_t AllocationSize(int num_tiles, size_t& temp_storage_bytes)
_CCCL_HOST_DEVICE _CCCL_FORCEINLINE static constexpr cudaError_t
AllocationSize(int num_tiles, size_t& temp_storage_bytes)
{
// Specify storage allocation requirements
size_t allocation_sizes[3];

// bytes needed for tile status descriptors
allocation_sizes[0] = (num_tiles + TILE_STATUS_PADDING) * sizeof(StatusWord);

// bytes needed for partials
allocation_sizes[1] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);

// bytes needed for inclusives
allocation_sizes[2] = (num_tiles + TILE_STATUS_PADDING) * sizeof(Uninitialized<T>);

// Set the necessary size of the blob
void* allocations[3] = {};
return CubDebug(AliasTemporaries(nullptr, temp_storage_bytes, allocations, allocation_sizes));
// Delegates the size computation to the shared helper, passing this specialization's
// per-tile descriptor size (sizeof(StatusWord)) and payload size (sizeof(Uninitialized<T>)).
temp_storage_bytes =
detail::tile_state_allocation_size(description_bytes_per_tile, payload_bytes_per_tile, num_tiles);
// NOTE(review): the helper is not declared constexpr at its definition in this file, while
// this signature is -- verify the function is never required in a constant expression.
return cudaSuccess;
}

/**
* Initialize (from device)
*/
Expand Down

0 comments on commit 8bbc9a6

Please sign in to comment.