forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathattention_forward.cu
813 lines (705 loc) · 28.3 KB
/
attention_forward.cu
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
/*
Kernels for attention forward pass.
Compile example:
nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward -lcublas
version 1 is naive port from CPU code to kernel, parallelize over batch, time, heads only
./attention_forward 1
version 2 is a naive implementation of flash attention, adapted from
https://github.com/tspeterkim/flash-attention-minimal
and with help from
https://github.com/leloykun/flash-hyperbolic-attention-minimal
sadly, this flash attention version seems about 3X slower than the naive version
./attention_forward 2
version 3 is a cuBLAS + softmax version, similar to the PyTorch implementation
cuBLAS is used both to calculate the QK^T and the final weighted sum
the softmax is calculated using a custom, efficient kernel as well
this turns out to be ~20X faster than version 1
./attention_forward 3
*/
#include <stdio.h>
#include <stdlib.h>
#include <cublas_v2.h>
#include <cuda_runtime.h>
// ----------------------------------------------------------------------------
// CUDA utils
// integer ceiling division, e.g. CEIL_DIV(10, 4) == 3; intended for non-negative M and positive N
// (used to compute how many blocks of size N are needed to cover M elements)
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
// error checking
// Aborts the process with a diagnostic when `error` is anything other than cudaSuccess.
// `file` and `line` identify the call site; the cudaCheck(...) macro below fills them in.
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error == cudaSuccess) {
        return;
    }
    printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
           cudaGetErrorString(error));
    exit(EXIT_FAILURE);
}
// wraps a CUDA runtime call so that the caller's file/line are reported on failure
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))
// ----------------------------------------------------------------------------
// CPU code reference
// Reference CPU implementation of causal multi-head self-attention (forward pass),
// used to check the GPU versions.
// inp:    (B, T, 3C); for token (b, t) the row holds [query | key | value], each C wide,
//         and head h uses the slice [h*hs, (h+1)*hs) of each part
// preatt: (B, NH, T, T) scaled query.key scores; entries with t2 > t are -INFINITY
// att:    (B, NH, T, T) row-wise softmax of preatt; entries with t2 > t are 0
// out:    (B, T, C) attention-weighted sum of values, heads side by side
// hs = C / NH, so C is expected to be a multiple of NH.
void attention_forward_cpu(float* out, float* preatt, float* att,
                           float* inp,
                           int B, int T, int C, int NH) {
    int C3 = C * 3;
    int hs = C / NH; // head size
    float scale = 1.0f / sqrtf((float)hs);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // pass 1: calculate query dot key and maxval.
                // maxval starts at -INFINITY, the identity for max. A finite sentinel
                // (previously -10000) is larger than very negative scores, which made
                // every exp() in pass 2 underflow to 0 and zeroed the whole row.
                // t2 = 0 is always visited, so for finite scores maxval becomes a real score.
                float maxval = -INFINITY;
                for (int t2 = 0; t2 <= t; t2++) {
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;
                    if (val > maxval) {
                        maxval = val;
                    }

                    preatt_bth[t2] = val;
                }
                // pad with -INFINITY outside of autoregressive region for debugging comparisons
                for (int t2 = t+1; t2 < T; t2++) {
                    preatt_bth[t2] = -INFINITY;
                }

                // pass 2: calculate the exp and keep track of sum
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float expv = expf(preatt_bth[t2] - maxval);
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // pass 3: normalize to get the softmax
                for (int t2 = 0; t2 < T; t2++) {
                    if (t2 <= t) {
                        att_bth[t2] *= expsum_inv;
                    } else {
                        // causal attention mask. not strictly necessary to set to zero here
                        // only doing this explicitly for debugging and checking to PyTorch
                        att_bth[t2] = 0.0f;
                    }
                }

                // pass 4: accumulate weighted values into the output of attention
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
                        out_bth[i] += att_btht2 * value_t2[i];
                    }
                }
            }
        }
    }
}
// ----------------------------------------------------------------------------
// GPU kernels
// Version 1, step 1: scaled, causally-masked attention scores.
// One thread per element of preatt, which is (B, NH, T, T) indexed as
// idx = ((b * NH + h) * T + t) * T + t2  (t = query position, t2 = key position).
// inp is (B, T, 3C) holding [query | key | value] per token; head size is C / NH.
// Launch with at least B*NH*T*T threads in a 1D grid; surplus threads do nothing.
__global__ void attention_query_key_kernel1(float* preatt, float* inp,
                                            int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * NH * T * T;

    if (idx < total_threads) {
        int t2 = idx % T;
        int t = (idx / T) % T;
        if (t2 > t) {
            // autoregressive mask: keys after the query position are not visible
            preatt[idx] = -INFINITY;
            return;
        }
        int h = (idx / (T * T)) % NH;
        int b = idx / (NH * T * T);

        int C3 = C*3;
        int hs = C / NH; // head size
        const float* query_t = inp + b * T * C3 + t * C3 + h * hs;
        const float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

        // (query_t) dot (key_t2)
        float val = 0.0f;
        for (int i = 0; i < hs; i++) {
            val += query_t[i] * key_t2[i];
        }
        // float literal on purpose: `1.0 / sqrtf(hs)` is a double, which promotes the
        // multiply to double precision (much slower on most GPUs)
        val *= 1.0f / sqrtf((float)hs);

        preatt[idx] = val;
    }
}
// Version 1, step 2: causal softmax over each row of preatt.
// One thread per (b, t, h) row, with idx = (b * T + t) * NH + h.
// preatt and att are (B, NH, T, T); only entries t2 <= t of a row are read from preatt,
// and entries t2 > t of att are written as 0.
// Launch with at least B*T*NH threads in a 1D grid; surplus threads do nothing.
__global__ void attention_softmax_kernel1(float* att, float* preatt,
                                          int B, int T, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
        float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        // find maxval. Start from -INFINITY (the identity for max, as softmax_forward_kernel4
        // does): the old finite sentinel -10000 exceeded very negative scores, so every exp()
        // below underflowed to 0 and the whole row came out as zeros. t2 = 0 is always
        // visited, so for finite scores maxval becomes an actual score.
        float maxval = -INFINITY;
        for (int t2 = 0; t2 <= t; t2++) {
            if (preatt_bth[t2] > maxval) {
                maxval = preatt_bth[t2];
            }
        }

        // calculate the exp and keep track of sum
        float expsum = 0.0f;
        for (int t2 = 0; t2 <= t; t2++) {
            float expv = expf(preatt_bth[t2] - maxval);
            expsum += expv;
            att_bth[t2] = expv;
        }
        float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

        // normalize to get the softmax
        for (int t2 = 0; t2 < T; t2++) {
            if (t2 <= t) {
                att_bth[t2] *= expsum_inv;
            } else {
                // causal attention mask. not strictly necessary to set to zero here
                // only doing this explicitly for debugging and checking to PyTorch
                att_bth[t2] = 0.0f;
            }
        }
    }
}
// Warp-level max reduction using shuffle-down over offsets 16, 8, 4, 2, 1.
// After the loop, lane 0 holds the maximum over all 32 lanes (other lanes hold
// partial results). Uses the full 0xFFFFFFFF mask, so all 32 lanes must call it.
__device__ float warpReduceMax(float val) {
    for (int delta = 32 / 2; delta >= 1; delta >>= 1) {
        float other = __shfl_down_sync(0xFFFFFFFF, val, delta);
        val = fmaxf(val, other);
    }
    return val;
}
// Warp-level sum reduction using shuffle-down over offsets 16, 8, 4, 2, 1.
// After the loop, lane 0 holds the sum over all 32 lanes (other lanes hold
// partial sums). Uses the full 0xFFFFFFFF mask, so all 32 lanes must call it.
__device__ float warpReduceSum(float val) {
    for (int delta = 32 / 2; delta >= 1; delta >>= 1) {
        float other = __shfl_down_sync(0xFFFFFFFF, val, delta);
        val += other;
    }
    return val;
}
// Row-wise softmax: out[row, :] = softmax(inp[row, :]) for a row-major (N, C) matrix.
// Launch: one block per row (gridDim.x == N); blockDim.x must be a multiple of 32.
// Dynamic shared memory: 2 * (blockDim.x / 32) floats, used as
//   [0, numWarps)            per-warp maxima
//   [numWarps, 2 * numWarps) per-warp sums
// Each thread strides over the row by blockDim.x; warps reduce with shuffles and
// thread 0 combines the per-warp partials.
__global__ void softmax_forward_kernel4(float* out, float* inp, int N, int C) {
    extern __shared__ float shared[];
    const int row = blockIdx.x;
    const int lane = threadIdx.x % 32;
    const int warp = threadIdx.x / 32;
    const int numWarps = blockDim.x / 32;
    float* warpMax = shared;
    float* warpSum = shared + numWarps;
    const float* rowIn = inp + row * C;
    float* rowOut = out + row * C;

    // --- pass 1: maximum of the row (for numerical stability) ---
    float localMax = -INFINITY;
    for (int col = threadIdx.x; col < C; col += blockDim.x) {
        localMax = fmaxf(localMax, rowIn[col]);
    }
    localMax = warpReduceMax(localMax);
    if (lane == 0) {
        warpMax[warp] = localMax;
    }
    __syncthreads();
    if (threadIdx.x == 0) {
        float blockMax = warpMax[0];
        for (int w = 1; w < numWarps; w++) {
            blockMax = fmaxf(blockMax, warpMax[w]);
        }
        warpMax[0] = blockMax;
    }
    __syncthreads();
    const float rowMax = warpMax[0];

    // --- pass 2: exp(x - max), written straight to the output row ---
    for (int col = threadIdx.x; col < C; col += blockDim.x) {
        rowOut[col] = expf(rowIn[col] - rowMax);
    }

    // --- pass 3: sum of the exponentials (each thread re-reads only what it wrote) ---
    float localSum = 0.0f;
    for (int col = threadIdx.x; col < C; col += blockDim.x) {
        localSum += rowOut[col];
    }
    localSum = warpReduceSum(localSum);
    if (lane == 0) {
        warpSum[warp] = localSum;
    }
    __syncthreads();
    if (threadIdx.x == 0) {
        float blockSum = warpSum[0];
        for (int w = 1; w < numWarps; ++w) {
            blockSum += warpSum[w];
        }
        warpSum[0] = blockSum;
    }
    __syncthreads();
    const float rowSum = warpSum[0];

    // --- pass 4: normalize ---
    for (int col = threadIdx.x; col < C; col += blockDim.x) {
        rowOut[col] = rowOut[col] / rowSum;
    }
}
// Version 1, step 3: out[b, t, h*hs:(h+1)*hs] = sum over t2 <= t of att[b, h, t, t2] * value[b, t2, h].
// One thread per (b, t, h) triple, with idx = (b * T + t) * NH + h.
// inp is (B, T, 3C) holding [query | key | value] per token; att is (B, NH, T, T); out is (B, T, C).
// Launch with at least B*T*NH threads in a 1D grid; surplus threads do nothing.
__global__ void attention_value_kernel1(float* out, float* att, float* inp,
                                        int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * T * NH) {
        return;
    }
    int head = idx % NH;
    int pos = (idx / NH) % T;
    int batch = idx / (NH * T);

    int rowStride = 3 * C;   // distance between consecutive tokens in inp
    int headSize = C / NH;
    float* dst = out + batch * T * C + pos * C + head * headSize;
    const float* weights = att + batch * NH * T * T + head * T * T + pos * T;
    // start of this head's value slice for token 0; token t2 is rowStride * t2 further on
    const float* valueBase = inp + batch * T * rowStride + head * headSize + 2 * C;

    for (int k = 0; k < headSize; k++) {
        dst[k] = 0.0f;
    }
    for (int src = 0; src <= pos; src++) {
        const float w = weights[src];
        const float* v = valueBase + src * rowStride;
        for (int k = 0; k < headSize; k++) {
            dst[k] += w * v[k];
        }
    }
}
// Naive causal flash-attention forward kernel (version 2).
// Grid: (B, nh), one block per (batch, head); block: Br threads, one per query row of a Q tile.
// Q, K, V, O are (B, nh, N, d) contiguous; l and m are (B, nh, N) running softmax
// denominators / maxima. The caller must initialize l to 0, m to a large negative value
// and O to 0 before launch, because O, l and m are all read before they are first written.
// Dynamic shared memory: (3 * Bc * d + Bc * Br) floats for Qi, Kj, Vj and S.
// Qi tiles are indexed with Bc * d and K/V rows are loaded by threadIdx.x, so the kernel
// assumes Bc == Br == blockDim.x.
__global__
void attention_forward_kernel2(
    const float* Q,
    const float* K,
    const float* V,
    const int N,
    const int d,
    const int Tc,
    const int Tr,
    const int Bc,
    const int Br,
    const float softmax_scale,
    float* l,
    float* m,
    float* O
) {
    int tx = threadIdx.x;
    int bx = blockIdx.x; int by = blockIdx.y; // batch and head index

    // Offset into Q,K,V,O,l,m - different for each batch and head
    int qkv_offset = (bx * gridDim.y * N * d) + (by * N * d); // gridDim.y = nh
    int lm_offset = (bx * gridDim.y * N) + (by * N); // offset for l and m

    // Define SRAM for Q,K,V,S
    extern __shared__ float sram[];
    int tile_size = Bc * d; // size of Qi, Kj, Vj
    float* Qi = sram;
    float* Kj = &sram[tile_size];
    float* Vj = &sram[tile_size * 2];
    float* S = &sram[tile_size * 3];

    for (int j = 0; j < Tc; j++) {
        // Load Kj, Vj to SRAM. When N is not a multiple of Bc the last tile extends past
        // the end of the sequence; skip those rows so we never read past this
        // (batch, head)'s slice or past the end of K/V. The skipped rows are never used,
        // because every loop over y below stops at j * Bc + y >= N.
        if (j * Bc + tx < N) {
            for (int x = 0; x < d; x++) {
                Kj[(tx * d) + x] = K[qkv_offset + (tile_size * j) + (tx * d) + x];
                Vj[(tx * d) + x] = V[qkv_offset + (tile_size * j) + (tx * d) + x];
            }
        }
        __syncthreads(); // such that the inner loop can use the correct Kj, Vj

        for (int i = 0; i < Tr; i++) {
            // if past the end of the sequence, break
            if (i * Br + tx >= N) {
                break;
            }

            // Load Qi to SRAM (each thread writes and reads only its own row), l and m to registers
            for (int x = 0; x < d; x++) {
                Qi[(tx * d) + x] = Q[qkv_offset + (tile_size * i) + (tx * d) + x];
            }
            float row_m_prev = m[lm_offset + (Br * i) + tx];
            float row_l_prev = l[lm_offset + (Br * i) + tx];

            // S = QK^T, row_m = rowmax(S)
            // S[tx][y] = Sum_{x = 0}^{d-1} {Qi[tx][x] * Kj[y][x]}
            // row_m = Max_{y = 0}^{Bc-1} S[tx][y]
            // with causal masking
            float row_m = -INFINITY;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N) {
                    break;
                }
                float sum = 0.0f;
                for (int x = 0; x < d; x++) {
                    sum += Qi[(tx * d) + x] * Kj[(y * d) + x];
                }
                sum *= softmax_scale;
                if (i * Br + tx < j * Bc + y)
                    sum = -INFINITY;
                S[(Bc * tx) + y] = sum;

                if (sum > row_m)
                    row_m = sum;
            }

            // implement softmax with causal masking
            // P = exp(S - row_m), row_l = rowsum(P); masked entries contribute exactly 0
            // P[tx][y] = exp(S[tx][y] - row_m)
            float row_l = 0.0f;
            for (int y = 0; y < Bc; y++) {
                if (j * Bc + y >= N) {
                    break;
                }
                if (i * Br + tx < j * Bc + y)
                    S[(Bc * tx) + y] = 0;
                else
                    S[(Bc * tx) + y] = __expf(S[(Bc * tx) + y] - row_m);
                row_l += S[(Bc * tx) + y];
            }

            // Compute new m and l (online-softmax rescaling of the previous statistics)
            float row_m_new = max(row_m_prev, row_m);
            float row_l_new = (__expf(row_m_prev - row_m_new) * row_l_prev) + (__expf(row_m - row_m_new) * row_l);

            // Write O, l, m to HBM
            for (int x = 0; x < d; x++) {
                float pv = 0.0f; // Pij * Vj
                for (int y = 0; y < Bc; y++) {
                    if (j * Bc + y >= N) {
                        break;
                    }
                    pv += S[(Bc * tx) + y] * Vj[(y * d) + x];
                }
                O[qkv_offset + (tile_size * i) + (tx * d) + x] = (1.0f / row_l_new) \
                    * ((row_l_prev * __expf(row_m_prev - row_m_new) * O[qkv_offset + (tile_size * i) + (tx * d) + x]) \
                    + (__expf(row_m - row_m_new) * pv));
            }
            m[lm_offset + (Br * i) + tx] = row_m_new;
            l[lm_offset + (Br * i) + tx] = row_l_new;
        }
        __syncthreads(); // otherwise, thread can use the wrong Kj, Vj in inner loop
    }
}
// Splits the fused QKV tensor into separate, head-major Q, K and V tensors:
//   inp: (B, N, 3, NH, d)  ->  q, k, v: each (B, NH, N, d)
// i.e. q[b][head][n][e] = inp[b][n][0][head][e], and likewise k (part 1) and v (part 2).
// One thread per output element; launch at least B*NH*N*d threads in a 1D grid.
__global__ void permute_kernel(float* q, float* k, float* v,
                               const float* inp,
                               int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * N * d) {
        return;
    }
    // decompose the output index into (b, head, n, e), innermost axis first
    int e = idx % d;
    int n = (idx / d) % N;
    int head = (idx / (d * N)) % NH;
    int b = idx / (d * N * NH);

    int partStride = NH * d; // size of one of the Q / K / V parts of a token
    int src = (b * N * 3 * partStride)
            + (n * 3 * partStride)
            + (head * d)
            + e;
    q[idx] = inp[src];
    k[idx] = inp[src + partStride];
    v[idx] = inp[src + 2 * partStride];
}
// Inverse layout change for the attention output:
//   inp: (B, NH, N, d)  ->  out: (B, N, NH, d), i.e. each token's heads side by side.
// One thread per element; launch at least B*NH*N*d threads in a 1D grid.
__global__ void unpermute_kernel(float* inp, float *out, int B, int N, int NH, int d) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * N * d) {
        return;
    }
    // decompose the source index into (b, head, n, e), innermost axis first
    int e = idx % d;
    int n = (idx / d) % N;
    int head = (idx / (d * N)) % NH;
    int b = idx / (d * N * NH);

    int dst = ((b * N + n) * NH + head) * d + e;
    out[dst] = inp[idx];
}
// Scales raw attention scores in place and applies the causal mask.
// inp is (B, NH, T, T); within each (T, T) slice, the element at offset row * T + col
// is set to -INFINITY when col > row, and multiplied by `scale` otherwise.
// (attention_forward3 lays the scores out as [query][key], so this hides future keys.)
// One thread per element; launch at least B*NH*T*T threads in a 1D grid.
__global__ void scale_kernel(float* inp, float scale, int B, int NH, int T) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= B * NH * T * T) {
        return;
    }
    int within = idx % (T * T); // offset inside this (b, h) score matrix
    int row = within / T;       // query position
    int col = within % T;       // key position
    inp[idx] = (col > row) ? -INFINITY : inp[idx] * scale;
}
// ----------------------------------------------------------------------------
// kernel launcher
// Version 1 launcher: three naive kernels on the default stream, in order:
//   1) attention_query_key_kernel1, one thread per (b, h, t, t2) score   -> preatt
//   2) attention_softmax_kernel1,   one thread per (b, t, h) row          -> att
//   3) attention_value_kernel1,     one thread per (b, t, h) row          -> out
void attention_forward1(float* out, float* preatt, float* att,
                        float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
    const int score_count = B * NH * T * T;
    const int row_count = B * T * NH;
    const int score_blocks = CEIL_DIV(score_count, block_size);
    const int row_blocks = CEIL_DIV(row_count, block_size);

    attention_query_key_kernel1<<<score_blocks, block_size>>>(preatt, inp, B, T, C, NH);
    attention_softmax_kernel1<<<row_blocks, block_size>>>(att, preatt, B, T, NH);
    attention_value_kernel1<<<row_blocks, block_size>>>(out, att, inp, B, T, C, NH);
}
// Version 2 launcher: naive flash attention (attention_forward_kernel2).
// inp is (B, T, 3C) fused QKV; out is (B, T, C). out is also used as scratch for the
// kernel's (B, NH, T, C/NH) output before it is un-permuted back into (B, T, C).
void attention_forward2(float* out,
                        float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
    // TODO there should be no mallocs inside any of these functions!
    // not fixing this because we don't intend to use attention_forward2,
    // it seems to be way too slow as is

    // these are hardcoded to 32 for now; the kernel requires Bc == Br
    const int Bc = 32;
    const int Br = 32;
    // renaming these to be consistent with the kernel
    // const int B = B;
    const int nh = NH;
    const int N = T;
    const int d = C / NH;
    // more
    const int Tc = CEIL_DIV(N, Bc);
    const int Tr = CEIL_DIV(N, Br);
    const float softmax_scale = 1.0 / sqrt(d);

    // running softmax statistics: l (row sums) and m (row maxima), each (B, nh, N)
    const size_t lm_count = (size_t)B * nh * N;
    float* l;
    float* m;
    cudaCheck(cudaMalloc(&l, lm_count * sizeof(float)));
    cudaCheck(cudaMalloc(&m, lm_count * sizeof(float)));
    cudaCheck(cudaMemset(l, 0, lm_count * sizeof(float)));
    // cudaMemset sets every *byte* to (unsigned char)value, so it cannot fill floats with
    // -10000.0f (that call filled each byte with 0xF0). Fill on the host instead.
    // -INFINITY is the exact identity for the running max; it is safe because the first
    // K tile (j = 0) always contains key 0, which is never masked, so every row's maximum
    // becomes finite on the first iteration.
    float* m_init = (float*)malloc(lm_count * sizeof(float));
    if (m_init == NULL) {
        printf("Failed to allocate host buffer for m\n");
        exit(1);
    }
    for (size_t i = 0; i < lm_count; i++) {
        m_init[i] = -INFINITY;
    }
    cudaCheck(cudaMemcpy(m, m_init, lm_count * sizeof(float), cudaMemcpyHostToDevice));
    free(m_init);
    // the kernel reads O (scaled by row_l_prev, which starts at 0) before writing it;
    // uninitialized memory may hold NaN/Inf and 0 * NaN is NaN, so start from zeros
    cudaCheck(cudaMemset(out, 0, (size_t)B * T * C * sizeof(float)));

    // calculate SRAM size needed per block, ensure we have enough shared memory
    int col_tile_size = Bc * d;  // size of Kj, Vj
    int row_tile_size = Br * d;  // size of Qi
    const int sram_size =
        (2 * col_tile_size * sizeof(float))  // SRAM size for Kj, Vj
        + (row_tile_size * sizeof(float))  // SRAM size for Qi
        + (Bc * Br * sizeof(float));  // SRAM size for S
    int max_sram_size;
    cudaCheck(cudaDeviceGetAttribute(&max_sram_size, cudaDevAttrMaxSharedMemoryPerBlock, 0));
    if (sram_size > max_sram_size) {
        printf("Max shared memory: %d, requested shared memory: %d \n", max_sram_size, sram_size);
        printf("SRAM size exceeds maximum shared memory per block\n");
        printf("Try decreasing col_tile_size or row_tile_size further\n");
        exit(1);
    }

    // grid and block dims
    dim3 grid_dim(B, nh);  // batch_size x num_heads
    dim3 block_dim(Br);  // Br threads per block

    // okay so now, this kernel wants Q,K,V to all be of shape (B, nh, N, d)
    // but instead, we have a single tensor QKV (inp) of shape (B, N, 3, nh, d)
    // so we have to permute the tensor using a kernel with block_size
    float *q, *k, *v;
    cudaCheck(cudaMalloc(&q, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&k, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&v, B * T * C * sizeof(float)));
    int total_threads = B * N * nh * d;
    int num_blocks = CEIL_DIV(total_threads, block_size);
    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, N, nh, d);

    // now actually call the flash attention kernel
    attention_forward_kernel2<<<grid_dim, block_dim, sram_size>>>(
        q, k, v,
        N, d, Tc, Tr, Bc, Br, softmax_scale,
        l, m, out
    );
    // launch-configuration errors (e.g. too much shared memory) only show up here
    cudaCheck(cudaGetLastError());

    // out has shape (B, nh, N, d) but we need to unpermute it to (B, N, nh, d)
    unpermute_kernel<<<num_blocks, block_size>>>(out, q, B, N, nh, d);
    cudaCheck(cudaMemcpy(out, q, B * T * C * sizeof(float), cudaMemcpyDeviceToDevice));

    // free memory
    cudaCheck(cudaFree(l));
    cudaCheck(cudaFree(m));
    cudaCheck(cudaFree(q));
    cudaCheck(cudaFree(k));
    cudaCheck(cudaFree(v));
}
// Version 3 launcher: cuBLAS batched GEMMs for Q K^T and att @ V, plus a custom softmax.
// inp:    (B, T, 3C) fused QKV, laid out as (B, T, 3, NH, HS)
// qkvr:   scratch of 3 * B * T * C floats; receives Q, K, V permuted to (B, NH, T, HS)
// preatt: (B, NH, T, T) scratch for the scaled, causally-masked scores
// att:    (B, NH, T, T) softmax probabilities
// vaccum: (B, NH, T, HS) scratch for att @ V before the heads are re-interleaved
// out:    (B, T, C)
void attention_forward3(float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                        float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
    int HS = C / NH; // head size

    // permute and separate inp from (B, T, 3, NH, HS) to 3X (B, NH, T, HS)
    float *q, *k, *v;
    q = qkvr + 0 * B * T * C;
    k = qkvr + 1 * B * T * C;
    v = qkvr + 2 * B * T * C;
    int total_threads = B * NH * T * HS;
    int num_blocks = CEIL_DIV(total_threads, block_size);
    permute_kernel<<<num_blocks, block_size>>>(q, k, v, inp, B, T, NH, HS);

    // batched matrix multiply with cuBLAS
    // NOTE(review): the handle is created and destroyed on every call, which is costly
    // inside the benchmark loop; kept per-call so this function stays self-contained.
    cublasHandle_t handle;
    cublasStatus_t stat = cublasCreate(&handle);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf("cublasCreate failed\n");
        exit(1);
    }
    const float alpha = 1.0f;
    const float beta = 0.0f;
    // cuBLAS is column-major: computing K^T Q with ldc = T yields, in row-major terms,
    // preatt[b, h, t, t2] = q_t . k_t2 (query rows, key columns)
    stat = cublasSgemmStridedBatched(handle,
                                     CUBLAS_OP_T, CUBLAS_OP_N,
                                     T, T, HS,
                                     &alpha,
                                     k, HS, T * HS,
                                     q, HS, T * HS,
                                     &beta,
                                     preatt, T, T * T,
                                     B * NH);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf("cublasSgemm failed\n");
        exit(1);
    }

    // multiply all elements of preatt elementwise by scale (and apply the causal mask)
    float scale = 1.0f / sqrtf((float)HS);
    total_threads = B * NH * T * T;
    num_blocks = CEIL_DIV(total_threads, block_size);
    scale_kernel<<<num_blocks, block_size>>>(preatt, scale, B, NH, T);

    // softmax. preatt is (B, NH, T, T) but we view it as (B * NH * T, T) and use the softmax kernel
    int softmax_block_size = 256;
    int grid_size = B * NH * T;
    size_t shared_mem_size = 2 * softmax_block_size / 32 * sizeof(float);
    softmax_forward_kernel4<<<grid_size, softmax_block_size, shared_mem_size>>>(att, preatt, B * NH * T, T);

    // new approach: first cuBLAS another batched matmul
    // y = att @ v # (B, nh, T, T) @ (B, nh, T, hs) -> (B, nh, T, hs)
    stat = cublasSgemmStridedBatched(handle,
                                     CUBLAS_OP_N, CUBLAS_OP_N,
                                     HS, T, T,
                                     &alpha,
                                     v, HS, T * HS,
                                     att, T, T * T,
                                     &beta,
                                     vaccum, HS, T * HS,
                                     B * NH);
    if (stat != CUBLAS_STATUS_SUCCESS) {
        printf("cublasSgemm failed\n");
        exit(1);
    }

    // now unpermute
    // y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
    num_blocks = CEIL_DIV(B * T * C, block_size);
    unpermute_kernel<<<num_blocks, block_size>>>(vaccum, out, B, T, NH, HS);

    // cleanups
    cublasDestroy(handle);
}
// kernel version dispatch
// Dispatches to the requested attention implementation:
//   1 -> attention_forward1 (naive kernels, uses preatt/att)
//   2 -> attention_forward2 (naive flash attention, allocates its own scratch)
//   3 -> attention_forward3 (cuBLAS + softmax, uses vaccum/qkvr/preatt/att)
// Any other value prints an error and exits the process.
void attention_forward(int kernel_num,
                       float* out, float* vaccum, float* qkvr, float* preatt, float* att,
                       float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    if (kernel_num == 1) {
        attention_forward1(out, preatt, att, inp, B, T, C, NH, block_size);
    } else if (kernel_num == 2) {
        attention_forward2(out, inp, B, T, C, NH, block_size);
    } else if (kernel_num == 3) {
        attention_forward3(out, vaccum, qkvr, preatt, att, inp, B, T, C, NH, block_size);
    } else {
        printf("Invalid kernel number\n");
        exit(1);
    }
}
// ----------------------------------------------------------------------------
// random utils
// Returns a malloc'ed array of N floats drawn from rand(), mapped into [-1, 1].
// The caller owns the array and must free() it. The result is not checked for NULL.
float* make_random_float(int N) {
    float* values = (float*)malloc(N * sizeof(float));
    for (int k = 0; k < N; ++k) {
        float unit = (float)rand() / RAND_MAX; // in [0, 1]
        values[k] = unit * 2.0 - 1.0;
    }
    return values;
}
// ----------------------------------------------------------------------------
// Benchmark driver: checks the selected GPU attention kernel (argv[1], default 1)
// against the CPU reference, then times it at several block sizes.
int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;
    int NH = 12;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * C * sizeof(float));
    float* preatt = (float*)malloc(B * NH * T * T * sizeof(float));
    float* att = (float*)malloc(B * NH * T * T * sizeof(float));
    float* inp = make_random_float(B * T * 3 * C);

    // move to GPU
    float* d_out;
    float* d_vaccum;
    float* d_qkvr;
    float* d_preatt;
    float* d_att;
    float* d_inp;
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_vaccum, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_qkvr, B * T * 3 * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);
    attention_forward(kernel_num, d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, 256);
    // kernel launches report configuration errors only through cudaGetLastError
    cudaCheck(cudaGetLastError());

    // compare the output (the blocking cudaMemcpy also waits for the kernels to finish)
    float* out_gpu = (float*)malloc(B * T * C * sizeof(float));
    cudaCheck(cudaMemcpy(out_gpu, d_out, B * T * C * sizeof(float), cudaMemcpyDeviceToHost));
    for (int i = 0; i < B * T * C; i++) {
        // print the first few comparisons
        if (i < 5) {
            printf("%f %f\n", out[i], out_gpu[i]);
        }
        // ensure correctness for all elements. Written as !(diff <= tol) so that a NaN
        // from the GPU counts as a mismatch: `diff > tol` is false for NaN.
        if (!(fabs(out[i] - out_gpu[i]) <= 1e-4)) {
            printf("Mismatch at %d: %f vs %f\n", i, out[i], out_gpu[i]);
            exit(1);
        }
    }
    printf("Results match!\n");

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512};
    const int num_block_sizes = (int)(sizeof(block_sizes) / sizeof(block_sizes[0]));
    const int repeat_times = 10;
    // one pair of events reused for every block size (previously leaked per iteration)
    cudaEvent_t start, stop;
    cudaCheck(cudaEventCreate(&start));
    cudaCheck(cudaEventCreate(&stop));

    for (int j = 0; j < num_block_sizes; j++) {
        int block_size = block_sizes[j];

        cudaCheck(cudaEventRecord(start, 0));
        for (int i = 0; i < repeat_times; i++) {
            attention_forward(kernel_num, d_out, d_vaccum, d_qkvr, d_preatt, d_att, d_inp, B, T, C, NH, block_size);
        }
        cudaCheck(cudaGetLastError());
        cudaCheck(cudaEventRecord(stop, 0));
        cudaCheck(cudaEventSynchronize(stop));
        float elapsed_time;
        cudaCheck(cudaEventElapsedTime(&elapsed_time, start, stop));

        // report the average time of one forward pass, not the total over all repeats
        printf("block_size %4d | time %f ms\n", block_size, elapsed_time / repeat_times);
    }
    cudaCheck(cudaEventDestroy(start));
    cudaCheck(cudaEventDestroy(stop));

    // free memory
    free(out);
    free(preatt);
    free(att);
    free(inp);
    free(out_gpu);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_vaccum));
    cudaCheck(cudaFree(d_qkvr));
    cudaCheck(cudaFree(d_preatt));
    cudaCheck(cudaFree(d_att));
    cudaCheck(cudaFree(d_inp));

    return 0;
}