Rename to for_each_cancelled_block and extend docs
gonzalobg committed Feb 6, 2025
1 parent 7e99c76 commit 0b62b46
Showing 3 changed files with 46 additions and 25 deletions.
47 changes: 34 additions & 13 deletions docs/libcudacxx/extended_api/work_stealing.rst
@@ -9,30 +9,38 @@ Work stealing
template <int ThreadBlockRank = 3, typename UnaryFunction = ..unspecified..>
requires std::is_invocable_r_v<void, UnaryFunction, dim3>
__device__ void try_cancel_blocks(UnaryFunction uf);
__device__ void for_each_cancelled_block(UnaryFunction uf);
} // namespace cuda::experimental
**WARNING**: this is an experimental API.
**WARNING**: this API requires C++20 or newer.
**WARNING**: This is an **Experimental API**.

**WARNING**: This API requires C++20 or newer.

This API is useful for implementing work-stealing at thread-block granularity.
On devices with compute capability 10.0 or higher, it may leverage hardware acceleration for work-stealing.
Compared against alternative work-distribution techniques such as `grid-stride loops <https://developer.nvidia.com/blog/cuda-pro-tip-write-flexible-kernels-grid-stride-loops/>`__, which distribute load statically, or against other dynamic work-distribution techniques built on global memory concurrency, the main advantages of this API are (a grid-stride version is sketched after the list for contrast):

- It performs work-stealing dynamically: thread blocks that finish work sooner may do more work than thread blocks whose work takes longer.
- It may cooperate with the GPU work-scheduler to respect work priorities and perform load-balancing.
- It may have lower work-stealing latency than global memory atomics.
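
For contrast, a minimal grid-stride-loop version of the vector addition used in the example below (an illustration of the statically distributed alternative; not part of this API). Each thread's work assignment is fixed at launch time, so thread blocks that finish early cannot pick up additional work:

.. code:: cuda

   // Grid-stride loop: static work distribution, shown only for contrast with
   // for_each_cancelled_block. Each thread strides over the input by the total
   // number of threads in the grid.
   __global__ void vec_add_grid_stride(const int* a, const int* b, int* c, int n) {
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
          i += gridDim.x * blockDim.x) {
       c[i] = a[i] + b[i];
     }
   }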

For better performance, extract the thread-block prologue, i.e. code and data common to all thread blocks (e.g. initialization, ``__shared__`` memory allocations, common constants loaded to shared memory, etc.), and the thread-block epilogue (e.g. writing shared memory back to global memory), out of the lambda passed to this API (see the example below).

**Mandates**:

- ``ThreadBlockRank`` equals the rank of the thread block: ``1``, ``2``, or ``3`` for one-dimensional, two-dimensional, and three-dimensional thread blocks, respectively.

**Preconditions**:
- All threads of the current thread block call ``try_cancel_blocks`` exactly once.

- All threads of the current thread block call ``for_each_cancelled_block`` **exactly once** (a kernel that violates this is sketched below).
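
For illustration, a hypothetical kernel that violates this precondition by calling the API from a divergent branch, so that some threads of the block never reach the call:

.. code:: cuda

   #include <cuda/try_cancel>

   // Incorrect (hypothetical): only threads with threadIdx.x < 16 reach the
   // call, so the precondition that all threads of the block call the API
   // exactly once is violated and the behavior is undefined.
   __global__ void bad_vec_zero(int* data, int n) {
     if (threadIdx.x < 16) {
       cuda::experimental::for_each_cancelled_block<1>([=](dim3 tb) {
         int idx = threadIdx.x + tb.x * blockDim.x;
         if (idx < n) { data[idx] = 0; }
       });
     }
   }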

**Effects**:
- Invokes ``uf`` with a ``dim3`` equal to ``blockIdx``, then repeatedly attempts to cancel the launch of a thread block of the current grid, and:

  - on success, calls ``uf`` with that thread block's ``blockIdx``,
  - otherwise, it returns.

- Invokes ``uf`` with a ``dim3`` equal to ``blockIdx``, then repeatedly attempts to cancel the launch of a thread block of the current grid, and:

  - on success, calls ``uf`` with that thread block's ``blockIdx`` and repeats,
  - otherwise, it failed to cancel the launch of a thread block and returns (a sketch of this loop follows).
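
The following is a minimal sketch of the observable loop structure only, assuming a hypothetical ``__try_cancel_one_block`` primitive that attempts to cancel one not-yet-started thread block of the current grid and reports its index; the real implementation (including any hardware acceleration) is not shown:

.. code:: cuda

   // Hypothetical primitive, assumed for illustration only: returns true and
   // writes the cancelled block's index on success, false otherwise.
   __device__ bool __try_cancel_one_block(dim3* cancelled);

   template <typename UnaryFunction>
   __device__ void for_each_cancelled_block_sketch(UnaryFunction uf) {
     uf(blockIdx);                        // first, process this block's own index
     dim3 idx;
     while (__try_cancel_one_block(&idx)) {
       uf(idx);                           // process each successfully stolen block
     }
     // returns once no further thread block could be cancelled
   }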

Example
-------
@@ -41,17 +49,30 @@ This example shows how to perform work-stealing at thread-block granularity usin

.. code:: cuda
// Before:
#include <cuda/math>
#include <cuda/try_cancel>
__global__ void vec_add(int* a, int* b, int* c, int n) {
cuda::experimental::try_cancel_blocks<1>([=](dim3 tb) {
// Extract common prologue outside the lambda, e.g.,
// - __shared__ or global (malloc) memory allocation
// - common initialization code
// - etc.
cuda::experimental::for_each_cancelled_block<1>([=](dim3 tb) {
int idx = threadIdx.x + tb.x * blockDim.x;
if (idx < n) {
c[idx] += a[idx] + b[idx];
}
});
// Note: Calling try_cancel_blocks<1> again from this
// Note: Calling for_each_cancelled_block<1> again from this
// thread block exhibits undefined behavior.
// Extract common epilogue outside the lambda, e.g.,
// - write back shared memory to global memory
// - external synchronization
// - global memory deallocation (free)
// - etc.
}
int main() {
12 changes: 6 additions & 6 deletions libcudacxx/include/cuda/try_cancel
@@ -60,7 +60,7 @@ _CCCL_NODISCARD _CCCL_DEVICE _CCCL_HIDE_FROM_ABI int __cluster_get_dim(__int128
/// - Exactly one thread of the thread block shall call this API with `__is_leader` equal to `true`.
template <int __ThreadBlockRank = 3, typename __UnaryFunction = __detail::__empty_t>
requires std::is_invocable_r_v<void, __UnaryFunction, dim3>
_CCCL_DEVICE _CCCL_HIDE_FROM_ABI void __try_cancel_blocks(bool __is_leader, __UnaryFunction __uf) {
_CCCL_DEVICE _CCCL_HIDE_FROM_ABI void __for_each_cancelled_block(bool __is_leader, __UnaryFunction __uf) {
static_assert(__ThreadBlockRank >= 1 && __ThreadBlockRank <= 3,
"ThreadBlockRank out-of-range [1, 3].");
dim3 __block_idx = dim3(blockIdx.x, 1, 1);
@@ -193,15 +193,15 @@ _CCCL_DEVICE _CCCL_HIDE_FROM_ABI void __try_cancel_blocks(bool __is_leader, __Un
/// - Exactly one thread of the thread block shall call this API with `__is_leader` equal to `true`.
template <int __ThreadBlockRank = 3, typename __UnaryFunction = __detail::__empty_t>
requires std::is_invocable_r_v<void, __UnaryFunction, dim3>
_CCCL_DEVICE _CCCL_HIDE_FROM_ABI void try_cancel_blocks(__UnaryFunction __uf) {
_CCCL_DEVICE _CCCL_HIDE_FROM_ABI void for_each_cancelled_block(__UnaryFunction __uf) {
static_assert(__ThreadBlockRank >= 1 && __ThreadBlockRank <= 3,
"try_cancel_blocks<ThreadBlockRank>: ThreadBlockRank out-of-range [1, 3].");
"for_each_cancelled_block<ThreadBlockRank>: ThreadBlockRank out-of-range [1, 3].");
if constexpr (__ThreadBlockRank == 1) {
__detail::__try_cancel_blocks<1>(threadIdx.x == 0, ::cuda::std::move(__uf));
__detail::__for_each_cancelled_block<1>(threadIdx.x == 0, ::cuda::std::move(__uf));
} else if constexpr (__ThreadBlockRank == 2) {
__detail::__try_cancel_blocks<2>(threadIdx.x == 0 && threadIdx.y == 0, ::cuda::std::move(__uf));
__detail::__for_each_cancelled_block<2>(threadIdx.x == 0 && threadIdx.y == 0, ::cuda::std::move(__uf));
} else if constexpr (__ThreadBlockRank == 3) {
__detail::__try_cancel_blocks<3>(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0, ::cuda::std::move(__uf));
__detail::__for_each_cancelled_block<3>(threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0, ::cuda::std::move(__uf));
}
}

Expand Down
12 changes: 6 additions & 6 deletions libcudacxx/test/libcudacxx/cuda/try_cancel/try_cancel.pass.cpp
@@ -42,37 +42,37 @@ __device__ void vec_add_impl3(int* a, int* b, int* c, int n, dim3 tb) {
}

__global__ void vec_add_det1(int* a, int* b, int* c, int n, int tidx = 0) {
cuda::experimental::__detail::__try_cancel_blocks<1>(threadIdx.x == tidx, [=](dim3 tb) {
cuda::experimental::__detail::__for_each_cancelled_block<1>(threadIdx.x == tidx, [=](dim3 tb) {
vec_add_impl1(a, b, c, n, tb);
});
}

__global__ void vec_add_det2(int* a, int* b, int* c, int n, int tidx = 0) {
cuda::experimental::__detail::__try_cancel_blocks<2>(threadIdx.x == tidx && threadIdx.y == tidx, [=](dim3 tb) {
cuda::experimental::__detail::__for_each_cancelled_block<2>(threadIdx.x == tidx && threadIdx.y == tidx, [=](dim3 tb) {
vec_add_impl2(a, b, c, n, tb);
});
}

__global__ void vec_add_det3(int* a, int* b, int* c, int n, int tidx = 0) {
cuda::experimental::__detail::__try_cancel_blocks<3>(threadIdx.x == tidx && threadIdx.y == tidx && threadIdx.z == tidx, [=](dim3 tb) {
cuda::experimental::__detail::__for_each_cancelled_block<3>(threadIdx.x == tidx && threadIdx.y == tidx && threadIdx.z == tidx, [=](dim3 tb) {
vec_add_impl3(a, b, c, n, tb);
});
}

__global__ void vec_add1(int* a, int* b, int* c, int n, int tidx = 0) {
cuda::experimental::try_cancel_blocks<1>([=](dim3 tb) {
cuda::experimental::for_each_cancelled_block<1>([=](dim3 tb) {
vec_add_impl1(a, b, c, n, tb);
});
}

__global__ void vec_add2(int* a, int* b, int* c, int n, int tidx = 0) {
cuda::experimental::try_cancel_blocks<2>([=](dim3 tb) {
cuda::experimental::for_each_cancelled_block<2>([=](dim3 tb) {
vec_add_impl2(a, b, c, n, tb);
});
}

__global__ void vec_add3(int* a, int* b, int* c, int n, int tidx = 0) {
cuda::experimental::try_cancel_blocks<3>([=](dim3 tb) {
cuda::experimental::for_each_cancelled_block<3>([=](dim3 tb) {
vec_add_impl3(a, b, c, n, tb);
});
}
