forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsoftmax_forward.cu
433 lines (372 loc) · 13.7 KB
/
softmax_forward.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
/*
Kernels for softmax forward pass.
Compile example:
nvcc -O3 --use_fast_math softmax_forward.cu -o softmax_forward
version 1 is naive port from CPU code to kernel: parallelizes over B,T, loops over C
./softmax_forward 1
version 2 is a fused kernel that parallelizes over all of B,T,C
./softmax_forward 2
version 3 uses intra-warp reductions for maxval and sumval, must use block_size=32
./softmax_forward 3
version 4 uses both intra-warp reductions and shared memory for inter-warp reductions
so it can tolerate any block_size % 32 == 0. this is hopefully the most efficient version
./softmax_forward 4
*/
#include <stdio.h>
#include <stdlib.h>
#include <cuda_runtime.h>
// ----------------------------------------------------------------------------
// CUDA utils
// Integer ceiling division: (M + N - 1) / N, used below to size a grid so that
// grid_size * block_size covers all N rows.
// NOTE: N is expanded twice, so do not pass an argument with side effects.
// NOTE(review): the rounding-up identity assumes non-negative M and positive N.
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
// error checking
// Reports a failed CUDA runtime call (with the call site's file and line) and
// terminates the process. Does nothing when `error` is cudaSuccess.
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error == cudaSuccess) {
        return;
    }
    const char *message = cudaGetErrorString(error);
    printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line, message);
    exit(EXIT_FAILURE);
}
// Call-site wrapper: cudaCheck(expr) forwards to the function above with the
// current __FILE__ and __LINE__ filled in.
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
// ----------------------------------------------------------------------------
// CPU code reference
void softmax_forward_cpu(float* out, float* inp, int N, int C) {
    // Host reference implementation of a row-wise softmax.
    // inp is (N, C); out is (N, C). Row r of out receives softmax(inp[r, :]).
    for (int row = 0; row < N; row++) {
        float* src = inp + row * C;
        float* dst = out + row * C;

        // pass 1: row maximum, subtracted below before exponentiating
        float row_max = -INFINITY;
        for (int c = 0; c < C; c++) {
            row_max = (src[c] > row_max) ? src[c] : row_max;
        }

        // pass 2: store exp(x - max) and accumulate the normaliser
        float total = 0.0f;
        for (int c = 0; c < C; c++) {
            dst[c] = expf(src[c] - row_max);
            total += dst[c];
        }

        // pass 3: normalise the row
        for (int c = 0; c < C; c++) {
            dst[c] = dst[c] / total;
        }
    }
}
// ----------------------------------------------------------------------------
// GPU kernels
__global__ void softmax_forward_kernel1(float* out, float* inp, int N, int C) {
    // Naive kernel: each thread computes the softmax of one whole row.
    // inp is (N, C); out is (N, C).
    // Threads whose global index is >= N do nothing.
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= N) {
        return;
    }
    float* src = inp + row * C;
    float* dst = out + row * C;

    // row maximum
    float row_max = -INFINITY;
    for (int c = 0; c < C; c++) {
        row_max = (src[c] > row_max) ? src[c] : row_max;
    }

    // exp(x - max) written to out, accumulating the row sum
    float total = 0.0f;
    for (int c = 0; c < C; c++) {
        dst[c] = expf(src[c] - row_max);
        total += dst[c];
    }

    // normalise
    for (int c = 0; c < C; c++) {
        dst[c] = dst[c] / total;
    }
}
__global__ void softmax_forward_kernel2(float* out, float* inp, int N, int C) {
    // Row-wise softmax where one thread block cooperates on one row.
    // inp is (N, C); out is (N, C). Block blockIdx.x handles row blockIdx.x.
    //
    // Launch requirements:
    //   - one block per row (gridDim.x == N)
    //   - dynamic shared memory of blockDim.x floats
    //   - blockDim.x must be a power of two: the halving tree reductions below
    //     would skip elements otherwise
    extern __shared__ float shared[];
    int idx = blockIdx.x; // ranges [0, N)
    int tid = threadIdx.x; // ranges [0, block_size)
    int block_size = blockDim.x;
    // idx is identical for every thread of the block, so the whole block leaves
    // together and no thread is left waiting at a later barrier
    if (idx >= N) {
        return;
    }
    float* x = inp + idx * C; // idx-th row of inp

    // thread coarsening: each thread folds a strided slice of the row
    float maxval = -INFINITY;
    for (int i = tid; i < C; i += block_size) {
        maxval = fmaxf(maxval, x[i]);
    }
    shared[tid] = maxval;
    __syncthreads();

    // tree reduction for the max. The barrier sits at the END of every step so
    // that the final write to shared[0] is visible to all threads before it is
    // read below (reading shared[0] before that barrier is a data race).
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        if (tid < stride) {
            shared[tid] = fmaxf(shared[tid], shared[tid + stride]);
        }
        __syncthreads();
    }
    float offset = shared[0];

    // compute expf and write the result to global memory
    for (int i = tid; i < C; i += block_size) {
        out[idx * C + i] = expf(x[i] - offset);
    }
    // every thread must have read shared[0] before shared[] is reused for the sum
    __syncthreads();

    // thread coarsening again, for the sum. Each thread reads back exactly the
    // elements it wrote above, so no cross-thread global-memory ordering is needed.
    x = out + idx * C; // idx-th row of out
    float sumval = 0.0f;
    for (int i = tid; i < C; i += block_size) {
        sumval += x[i];
    }
    shared[tid] = sumval;
    __syncthreads();

    // tree reduction for the sum, same barrier placement as above
    for (int stride = block_size / 2; stride >= 1; stride /= 2) {
        if (tid < stride) {
            shared[tid] += shared[tid + stride];
        }
        __syncthreads();
    }
    // broadcast the sum to all threads in the block
    float sum = shared[0];

    // divide the exponentials by the sum
    for (int i = tid; i < C; i += block_size) {
        out[idx * C + i] = x[i] / sum;
    }
}
// warp-level reduction for finding the maximum value
// Uses the full-warp mask 0xFFFFFFFF with shuffle distances 16, 8, 4, 2, 1.
// The callers in this file read the reduced value from lane 0.
__device__ float warpReduceMax(float val) {
    int delta = 16;
    while (delta > 0) {
        const float other = __shfl_down_sync(0xFFFFFFFF, val, delta);
        val = fmaxf(val, other);
        delta /= 2;
    }
    return val;
}
// warp-level reduction for summing values
// Uses the full-warp mask 0xFFFFFFFF with shuffle distances 16, 8, 4, 2, 1.
// The callers in this file read the reduced value from lane 0.
__device__ float warpReduceSum(float val) {
    int delta = 16;
    while (delta > 0) {
        const float other = __shfl_down_sync(0xFFFFFFFF, val, delta);
        val += other;
        delta /= 2;
    }
    return val;
}
__global__ void softmax_forward_kernel3(float* out, float* inp, int N, int C) {
    // Row-wise softmax with one warp per row; block blockIdx.x handles row blockIdx.x.
    // kernel must use block size of 32: the reductions below only combine
    // values within a single warp.
    extern __shared__ float shared[]; // declared but not referenced by this kernel
    const int row = blockIdx.x;
    const int lane = threadIdx.x;
    float* src = inp + row * C;
    float* dst = out + row * C;

    // per-thread partial max over a strided slice, then reduce across the warp
    float lane_max = -INFINITY;
    for (int c = lane; c < C; c += blockDim.x) {
        lane_max = fmaxf(lane_max, src[c]);
    }
    lane_max = warpReduceMax(lane_max);
    // lane 0 holds the reduced value; hand it to every lane
    const float row_max = __shfl_sync(0xFFFFFFFF, lane_max, 0);

    // exp(x - max) goes to global memory
    for (int c = lane; c < C; c += blockDim.x) {
        dst[c] = expf(src[c] - row_max);
    }

    // per-thread partial sum (each thread re-reads the elements it just wrote),
    // then reduce across the warp
    float lane_sum = 0.0f;
    for (int c = lane; c < C; c += blockDim.x) {
        lane_sum += dst[c];
    }
    lane_sum = warpReduceSum(lane_sum);
    const float row_sum = __shfl_sync(0xFFFFFFFF, lane_sum, 0);

    // normalise the row
    for (int c = lane; c < C; c += blockDim.x) {
        dst[c] = dst[c] / row_sum;
    }
}
__global__ void softmax_forward_kernel4(float* out, float* inp, int N, int C) {
    // Row-wise softmax: out is (N, C) just like inp, block blockIdx.x handles row blockIdx.x.
    // Like kernel3, but works for any block size that is a multiple of 32:
    //   - inside a warp, warpReduceMax/warpReduceSum combine the per-thread partials
    //   - across warps, the per-warp results are combined through shared memory
    // shared[] must hold 2 * (blockDim.x / 32) floats: the first half stores the
    // per-warp maxima, the second half the per-warp sums.
    extern __shared__ float shared[];
    const int row = blockIdx.x;
    const int tid = threadIdx.x;
    const int warp = threadIdx.x / 32;        // warp index within the block
    const int lane = threadIdx.x % 32;        // thread index within its warp
    const int num_warps = blockDim.x / 32;    // warps per block

    float* warp_max = shared;
    float* warp_sum = shared + num_warps;

    float* src = inp + row * C;   // inp[row, :]
    float* dst = out + row * C;   // out[row, :]

    // ---- step 1: row maximum ----
    // each thread folds a strided slice of the row straight from global memory
    float partial_max = -INFINITY;
    for (int c = tid; c < C; c += blockDim.x) {
        partial_max = fmaxf(partial_max, src[c]);
    }
    partial_max = warpReduceMax(partial_max);
    // lane 0 of every warp publishes that warp's maximum
    if (lane == 0) {
        warp_max[warp] = partial_max;
    }
    __syncthreads();
    // thread 0 folds the per-warp maxima and leaves the result in slot 0
    if (tid == 0) {
        float block_max = warp_max[0];
        for (int w = 1; w < num_warps; w++) {
            block_max = fmaxf(block_max, warp_max[w]);
        }
        warp_max[0] = block_max;
    }
    __syncthreads();
    const float row_max = warp_max[0];

    // write exp(x - max) to global memory
    for (int c = tid; c < C; c += blockDim.x) {
        dst[c] = expf(src[c] - row_max);
    }

    // ---- step 2: row sum and normalisation ----
    // each thread re-reads exactly the elements it wrote above
    float partial_sum = 0.0f;
    for (int c = tid; c < C; c += blockDim.x) {
        partial_sum += dst[c];
    }
    partial_sum = warpReduceSum(partial_sum);
    if (lane == 0) {
        warp_sum[warp] = partial_sum;
    }
    __syncthreads();
    // thread 0 folds the per-warp sums and leaves the result in slot 0
    if (tid == 0) {
        float block_sum = warp_sum[0];
        for (int w = 1; w < num_warps; ++w) {
            block_sum += warp_sum[w];
        }
        warp_sum[0] = block_sum;
    }
    __syncthreads();
    const float row_sum = warp_sum[0];

    // divide the whole row by the sum
    for (int c = tid; c < C; c += blockDim.x) {
        dst[c] = dst[c] / row_sum;
    }
}
// ----------------------------------------------------------------------------
// kernel launcher
// Launches kernel 1 with one thread per row: enough blocks of `block_size`
// threads to cover all N rows, then checks for a launch error.
void softmax_forward1(float* out, float* inp, int N, int C, const int block_size) {
    const int num_blocks = CEIL_DIV(N, block_size);
    softmax_forward_kernel1<<<num_blocks, block_size>>>(out, inp, N, C);
    cudaError_t launch_status = cudaGetLastError();
    cudaCheck(launch_status);
}
// Launches kernel 2 with one block per row and block_size floats of dynamic
// shared memory (one reduction slot per thread).
// The kernel's tree reduction expects block_size to be a power of two.
void softmax_forward2(float* out, float* inp, int N, int C, const int block_size) {
    int grid_size = N;
    size_t shared_mem_size = block_size * sizeof(float);
    softmax_forward_kernel2<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
    // a kernel launch returns no status itself; surface bad launch configurations here
    cudaCheck(cudaGetLastError());
}
// Launches kernel 3 with one 32-thread block (a single warp) per row.
// The requested block_size is ignored because the kernel only reduces within one warp.
void softmax_forward3(float* out, float* inp, int N, int C, int block_size) {
    block_size = 32; // awkward but ok. this one only works with block size 32
    int grid_size = N;
    // launch configuration kept as-is; kernel 3 declares shared[] but never touches it
    size_t shared_mem_size = block_size * sizeof(float);
    softmax_forward_kernel3<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
    // a kernel launch returns no status itself; surface bad launch configurations here
    cudaCheck(cudaGetLastError());
}
// Launches kernel 4 with one block per row.
// Dynamic shared memory holds two floats per warp (one max slot, one sum slot),
// i.e. 2 * (block_size / 32) floats. block_size is expected to be a multiple of 32.
void softmax_forward4(float* out, float* inp, int N, int C, int block_size) {
    int grid_size = N;
    size_t shared_mem_size = 2 * block_size / 32 * sizeof(float);
    softmax_forward_kernel4<<<grid_size, block_size, shared_mem_size>>>(out, inp, N, C);
    // a kernel launch returns no status itself; surface bad launch configurations here
    cudaCheck(cudaGetLastError());
}
// kernel version dispatch
// Runs the launcher selected by kernel_num (1..4). Any other value prints an
// error message and exits with status 1.
void softmax_forward(int kernel_num, float* out, float* inp, int N, int C, const int block_size) {
    typedef void (*launcher_fn)(float*, float*, int, int, int);
    // entry k holds the launcher for kernel number k + 1
    static const launcher_fn launchers[] = {
        softmax_forward1,
        softmax_forward2,
        softmax_forward3,
        softmax_forward4,
    };
    const int num_launchers = (int)(sizeof(launchers) / sizeof(launchers[0]));

    if (kernel_num < 1 || kernel_num > num_launchers) {
        printf("Invalid kernel number\n");
        exit(1);
    }
    launchers[kernel_num - 1](out, inp, N, C, block_size);
}
// ----------------------------------------------------------------------------
// random utils
// Allocates N floats on the host and fills them with rand()-derived values
// in the range [-1, 1]. Seed with srand() beforehand for reproducible data.
// The caller owns the returned buffer and must free() it.
// Exits the process if a non-empty allocation fails.
float* make_random_float(int N) {
    float* arr = (float*)malloc(N * sizeof(float));
    if (arr == NULL && N > 0) {
        fprintf(stderr, "make_random_float: failed to allocate %d floats\n", N);
        exit(EXIT_FAILURE);
    }
    for (int i = 0; i < N; i++) {
        // map rand() / RAND_MAX from [0, 1] onto [-1, 1]; float literals keep
        // the arithmetic in single precision
        arr[i] = ((float)rand() / RAND_MAX) * 2.0f - 1.0f;
    }
    return arr;
}
// ----------------------------------------------------------------------------
int main(int argc, char **argv) {
    // Test and benchmark driver.
    // Usage: ./softmax_forward [kernel_num]   (kernel_num defaults to 1)
    // 1) checks the selected GPU kernel against the CPU reference at several block sizes
    // 2) times the kernel at each of those block sizes
    srand(0);
    int B = 8;
    int T = 1024;
    // the data is processed as (B * T) rows of T columns
    const int rows = B * T;
    const int cols = T;
    const int numel = B * T * T;
    const size_t nbytes = (size_t)numel * sizeof(float);

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(nbytes);
    float* out_gpu = (float*)malloc(nbytes);
    if (out == NULL || out_gpu == NULL) {
        printf("Failed to allocate host memory\n");
        exit(EXIT_FAILURE);
    }
    float* inp = make_random_float(numel);

    // move to GPU
    float* d_out;
    float* d_inp;
    cudaCheck(cudaMalloc(&d_out, nbytes));
    cudaCheck(cudaMalloc(&d_inp, nbytes));
    cudaCheck(cudaMemcpy(d_inp, inp, nbytes, cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    int block_sizes[] = {32, 64, 128, 256, 512, 1024};
    const int num_block_sizes = (int)(sizeof(block_sizes) / sizeof(block_sizes[0]));

    // the CPU reference does not depend on block_size, so compute it once
    softmax_forward_cpu(out, inp, rows, cols);

    // first check the correctness of the kernel
    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];
        // clear the device output so results left over from the previous block
        // size cannot make a kernel that failed to run look correct
        cudaCheck(cudaMemset(d_out, 0, nbytes));
        softmax_forward(kernel_num, d_out, d_inp, rows, cols, block_size);
        cudaCheck(cudaGetLastError());
        cudaCheck(cudaMemcpy(out_gpu, d_out, nbytes, cudaMemcpyDeviceToHost));
        for (int i = 0; i < numel; i++) {
            // print the first few comparisons
            if (i < 5) {
                printf("%f %f\n", out[i], out_gpu[i]);
            }
            // ensure correctness for all elements; the negated comparison also
            // rejects NaN, which a plain "diff > tolerance" test would let through
            float diff = fabsf(out[i] - out_gpu[i]);
            if (!(diff <= 1e-4f)) {
                printf("Mismatch at %d: %f vs %f\n", i, out[i], out_gpu[i]);
                exit(1);
            }
        }
        printf("Results match at block_size=%d\n", block_size);
    }

    // time the kernel at different block sizes
    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];
        int repeat_times = 1000;
        cudaEvent_t start, stop;
        cudaCheck(cudaEventCreate(&start));
        cudaCheck(cudaEventCreate(&stop));
        cudaCheck(cudaEventRecord(start, 0));
        for (int i = 0; i < repeat_times; i++) {
            softmax_forward(kernel_num, d_out, d_inp, rows, cols, block_size);
        }
        cudaCheck(cudaEventRecord(stop, 0));
        // waiting on `stop` is sufficient: it is recorded after `start` on the same stream
        cudaCheck(cudaEventSynchronize(stop));
        cudaCheck(cudaGetLastError());
        float elapsed_time;
        cudaCheck(cudaEventElapsedTime(&elapsed_time, start, stop));
        printf("block_size %4d | time %f ms\n", block_size, elapsed_time / repeat_times);
        // release the events created for this block size
        cudaCheck(cudaEventDestroy(start));
        cudaCheck(cudaEventDestroy(stop));
    }

    // free memory
    free(out);
    free(inp);
    free(out_gpu);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp));
    return 0;
}