intra-warp reductions and inter-warp shared memory reductions for B,T,C parallelism of softmax kernel nice
karpathy committed Apr 9, 2024
1 parent c29d70a commit 8386e53
Showing 1 changed file with 170 additions and 0 deletions.
dev/cuda/softmax_forward.cu
@@ -9,6 +9,13 @@ version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
version 2 is a fused kernel that parallelizes over all of B,T,C
./softmax_forward 2
version 3 uses intra-warp reductions for maxval and sumval, must use block_size=32
./softmax_forward 3
version 4 uses both intra-warp reductions and shared memory for inter-warp reductions,
so it can tolerate any block size that is a multiple of 32. this is hopefully the most efficient version
./softmax_forward 4
*/
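// Reference formula (the numerically stable form the kernels in this file implement):
//   softmax(x)_i = exp(x_i - max_j x_j) / sum_k exp(x_k - max_j x_j)
// subtracting the row max before expf keeps the exponentials from overflowing.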

#include <stdio.h>
@@ -138,6 +145,150 @@ __global__ void softmax_forward_kernel2(float* out, float* inp, int N, int C) {
    }
}

// warp-level reduction for finding the maximum value
__device__ float warpReduceMax(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val = fmaxf(val, __shfl_down_sync(0xFFFFFFFF, val, offset));
    }
    return val;
}

// warp-level reduction for summing values
__device__ float warpReduceSum(float val) {
    for (int offset = 16; offset > 0; offset /= 2) {
        val += __shfl_down_sync(0xFFFFFFFF, val, offset);
    }
    return val;
}
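// To illustrate the shuffle pattern (a worked trace, not part of the commit):
// with 32 lanes holding values v0..v31, offset=16 leaves lane i with
// reduce(v_i, v_{i+16}) for i < 16; offset=8 then combines lanes 0..7 with
// 8..15, and so on, so after the final offset=1 step lane 0 holds the
// reduction over all 32 lanes. Lanes other than 0 end up with partial
// results, which is why the kernels below broadcast lane 0's value before
// using it.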

__global__ void softmax_forward_kernel3(float* out, float* inp, int N, int C) {
    // kernel must be launched with a block size of exactly 32, i.e. one warp per row
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    float* x = inp + idx * C;

    // thread coarsening and within-warp reduction for maxval
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x) {
        maxval = fmaxf(maxval, x[i]);
    }
    maxval = warpReduceMax(maxval);

    // broadcast the maxval from lane 0 to all threads within the warp
    float offset = __shfl_sync(0xFFFFFFFF, maxval, 0);

    // compute expf and write the result to global memory
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = expf(x[i] - offset);
    }

    // thread coarsening and within-warp reduction for sumval
    x = out + idx * C;
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        sumval += x[i];
    }
    sumval = warpReduceSum(sumval);

    // broadcast the sumval from lane 0 to all threads within the warp
    float sum = __shfl_sync(0xFFFFFFFF, sumval, 0);

    // divide the exponentiated values by the sum to normalize
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = x[i] / sum;
    }
}
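// Example launch for kernel3 (this matches what softmax_forward3 below does):
// one block of exactly 32 threads, i.e. one warp, per row of C elements:
//   softmax_forward_kernel3<<<N, 32>>>(out, inp, N, C);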

__global__ void softmax_forward_kernel4(float* out, float* inp, int N, int C) {
    // out is (N, C) just like inp. Each row of inp will get softmaxed.
    // same as kernel3, but can handle any block size (multiple of 32)
    // each row of C elements is handled by block_size threads
    // furthermore, the block_size threads execute in warps of 32 threads each

    // special reduction operations warpReduceMax/warpReduceSum are used for intra-warp reductions
    // shared memory is used for inter-warp reduction
    extern __shared__ float shared[];
    int idx = blockIdx.x;
    int tid = threadIdx.x;
    int warpId = threadIdx.x / 32; // warp index within a block
    int laneId = threadIdx.x % 32; // thread index within a warp

    // the number of warps per block. recall that blockDim.x is block_size
    int warpsPerBlock = blockDim.x / 32;

    // shared[] must be allocated to have 2 * warpsPerBlock elements
    // first half for max values, the second half for sum values
    float* maxvals = shared;
    float* sumvals = &shared[warpsPerBlock];

    // one row of inp, i.e. inp[idx, :] of shape (C,)
    float* x = inp + idx * C;

    // first, thread coarsening by directly accessing global memory in series
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += blockDim.x) {
        maxval = fmaxf(maxval, x[i]);
    }
    // now within-warp reductions for maxval
    maxval = warpReduceMax(maxval);

    // the 0th thread of each warp writes the maxval of that warp to shared memory
    if (laneId == 0) maxvals[warpId] = maxval;
    __syncthreads();

    // now the 0th thread of the block reduces the per-warp maxvals in shared memory, i.e. across warps
    if (tid == 0) {
        float val = maxvals[0];
        for (int i = 1; i < warpsPerBlock; i++) {
            val = fmaxf(val, maxvals[i]);
        }
        // store the final max in the first position
        maxvals[0] = val;
    }
    __syncthreads();
    // broadcast the max to all threads
    float offset = maxvals[0];

    // compute expf and write the result to global memory
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = expf(x[i] - offset);
    }

    // okay now we calculated exp(x - max(x))
    // step 2: sum all the values and divide by the sum

    // thread coarsening for sum
    x = out + idx * C;
    float sumval = 0.0f;
    for (int i = tid; i < C; i += blockDim.x) {
        sumval += x[i];
    }
    // within-warp reduction for sumval
    sumval = warpReduceSum(sumval);

    // the 0th thread of each warp writes the sumval of that warp to shared memory
    if (laneId == 0) sumvals[warpId] = sumval;
    __syncthreads();

    // inter-warp reduction of the per-warp sums, done by the 0th thread
    if (tid == 0) {
        float val = sumvals[0];
        for (int i = 1; i < warpsPerBlock; i++) {
            val += sumvals[i];
        }
        sumvals[0] = val;
    }
    __syncthreads();
    // broadcast the sum to all threads
    float sum = sumvals[0];

    // divide the whole row by the sum to normalize
    for (int i = tid; i < C; i += blockDim.x) {
        out[idx * C + i] = x[i] / sum;
    }
}
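// A hedged sketch (not part of this commit) of how kernel4's two-level pattern
// could be factored into a reusable block-wide max reduction. `shared` must
// point at warpsPerBlock floats of shared memory, all threads of the block
// must call it, and blockDim.x % 32 == 0 is assumed:
__device__ float blockReduceMax(float val, float* shared) {
    int laneId = threadIdx.x % 32;
    int warpId = threadIdx.x / 32;
    int warpsPerBlock = blockDim.x / 32;
    val = warpReduceMax(val);               // step 1: intra-warp reduction
    if (laneId == 0) shared[warpId] = val;  // step 2: one partial max per warp
    __syncthreads();
    if (threadIdx.x == 0) {                 // step 3: reduce the per-warp partials
        float m = shared[0];
        for (int i = 1; i < warpsPerBlock; i++) { m = fmaxf(m, shared[i]); }
        shared[0] = m;
    }
    __syncthreads();
    return shared[0];                       // step 4: broadcast to all threads
}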

// ----------------------------------------------------------------------------
// kernel launcher

@@ -153,6 +304,19 @@ void softmax_forward2(float* out, float* inp, int N, int C, const int block_size) {
    softmax_forward_kernel2<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
}

void softmax_forward3(float* out, float* inp, int N, int C, int block_size) {
    block_size = 32; // awkward but ok. this kernel only works with a block size of 32
    int grid_size = N;
    // kernel3 uses no dynamic shared memory, so none is passed at launch
    softmax_forward_kernel3<<<grid_size, block_size>>>(out, inp, N, C);
}

void softmax_forward4(float* out, float* inp, int N, int C, int block_size) {
    int grid_size = N;
    // 2 floats per warp: one for the partial max, one for the partial sum
    size_t shared_mem_size = 2 * (block_size / 32) * sizeof(float);
    softmax_forward_kernel4<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
}
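// Minimal host-side usage sketch (illustrative only; the d_out/d_inp/h_inp
// names and the block size choice are assumptions, not part of this commit):
//   float *d_out, *d_inp; // device buffers of N * C floats each
//   cudaMalloc(&d_out, N * C * sizeof(float));
//   cudaMalloc(&d_inp, N * C * sizeof(float));
//   cudaMemcpy(d_inp, h_inp, N * C * sizeof(float), cudaMemcpyHostToDevice);
//   softmax_forward4(d_out, d_inp, N, C, 128); // any block size % 32 == 0
//   cudaMemcpy(h_out, d_out, N * C * sizeof(float), cudaMemcpyDeviceToHost);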

// kernel version dispatch
void softmax_forward(int kernel_num, float* out, float* inp, int N, int C, const int block_size) {
    switch (kernel_num) {
@@ -162,6 +326,12 @@ void softmax_forward(int kernel_num, float* out, float* inp, int N, int C, const int block_size) {
        case 2:
            softmax_forward2(out, inp, N, C, block_size);
            break;
        case 3:
            softmax_forward3(out, inp, N, C, block_size);
            break;
        case 4:
            softmax_forward4(out, inp, N, C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
