Add span to example and templated block size (#2470)
* add span to example and templated block size
Kh4ster authored Sep 28, 2024
1 parent 7c668e8 commit e3800d7
Showing 1 changed file with 11 additions and 10 deletions.
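
In short, the kernel in the README example now takes a `cuda::std::span` (which bundles the pointer and element count into one argument) instead of a raw pointer plus a separate `N`, the block size becomes a template parameter chosen at the launch site, and the manual rounding idiom for the grid size is replaced by `cuda::ceil_div`. A minimal standalone sketch of that pattern, separate from the diff below (the `scale` kernel here is hypothetical, not part of the commit):

```cpp
#include <cuda/cmath>     // cuda::ceil_div
#include <cuda/std/span>  // cuda::std::span, usable in device code

// Hypothetical kernel showing the span-based interface: the bounds check
// reads the element count from the span, so no separate N parameter is needed.
template <int block_size>
__global__ void scale(cuda::std::span<int> data, int factor) {
  // block_size is a compile-time constant here, unlike blockDim.x.
  int const index = threadIdx.x + blockIdx.x * block_size;
  if (index < data.size()) {
    data[index] *= factor;
  }
}

void launch_scale(int* device_data, int n, int factor) {
  constexpr int block_size = 256;
  // cuda::ceil_div(n, block_size) computes (n + block_size - 1) / block_size
  int const num_blocks = cuda::ceil_div(n, block_size);
  scale<block_size><<<num_blocks, block_size>>>(cuda::std::span<int>(device_data, n), factor);
}
```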
README.md: 11 additions & 10 deletions
@@ -41,30 +41,31 @@ The sum of each block is then reduced to a single value using an atomic add via
 
 It then shows how the same reduction can be done using Thrust's `reduce` algorithm and compares the results.
 
-[Try it live on Godbolt!](https://godbolt.org/z/x4G73af9a)
+[Try it live on Godbolt!](https://godbolt.org/z/aMx4j9f4T)
 
 ```cpp
 #include <thrust/execution_policy.h>
 #include <thrust/device_vector.h>
 #include <cub/block/block_reduce.cuh>
 #include <cuda/atomic>
+#include <cuda/cmath>
+#include <cuda/std/span>
 #include <cstdio>
 
-constexpr int block_size = 256;
-
-__global__ void reduce(int const* data, int* result, int N) {
+template <int block_size>
+__global__ void reduce(cuda::std::span<int const> data, cuda::std::span<int> result) {
   using BlockReduce = cub::BlockReduce<int, block_size>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
 
   int const index = threadIdx.x + blockIdx.x * blockDim.x;
   int sum = 0;
-  if (index < N) {
+  if (index < data.size()) {
     sum += data[index];
   }
   sum = BlockReduce(temp_storage).Sum(sum);
 
   if (threadIdx.x == 0) {
-    cuda::atomic_ref<int, cuda::thread_scope_device> atomic_result(*result);
+    cuda::atomic_ref<int, cuda::thread_scope_device> atomic_result(result.front());
     atomic_result.fetch_add(sum, cuda::memory_order_relaxed);
   }
 }
@@ -80,10 +81,10 @@ int main() {
   thrust::device_vector<int> kernel_result(1);
 
   // Compute the sum reduction of `data` using a custom kernel
-  int const num_blocks = (N + block_size - 1) / block_size;
-  reduce<<<num_blocks, block_size>>>(thrust::raw_pointer_cast(data.data()),
-                                     thrust::raw_pointer_cast(kernel_result.data()),
-                                     N);
+  constexpr int block_size = 256;
+  int const num_blocks = cuda::ceil_div(N, block_size);
+  reduce<block_size><<<num_blocks, block_size>>>(cuda::std::span<int const>(thrust::raw_pointer_cast(data.data()), data.size()),
+                                                 cuda::std::span<int>(thrust::raw_pointer_cast(kernel_result.data()), 1));
 
   auto const err = cudaDeviceSynchronize();
   if (err != cudaSuccess) {
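The diff stops here; the rest of the example, where the README says the same reduction is done with Thrust's `reduce` algorithm and the results are compared, is not shown. A minimal sketch of that comparison step, assuming the `data` and `kernel_result` vectors from the hunk above (not the commit's actual code):

```cpp
// Sketch of the comparison step (not shown in this diff): thrust::reduce on
// device_vector iterators runs the same sum reduction on the device.
int const thrust_result = thrust::reduce(data.begin(), data.end());

// Reading kernel_result[0] copies the single value back to the host.
if (kernel_result[0] == thrust_result) {
  std::printf("Results match: %d\n", thrust_result);
} else {
  std::printf("Mismatch: kernel %d vs thrust %d\n",
              static_cast<int>(kernel_result[0]), thrust_result);
}
```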
