forked from karpathy/llm.c
Commit: naive attention kernel only parallelizing over batch, time, heads. have to speed this up a lot
Showing 1 changed file with 340 additions and 0 deletions.
@@ -0,0 +1,340 @@
/*
Kernels for attention forward pass.

Compile example:
nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward

version 1 is a naive port from CPU code to kernel, parallelizing over batch, time, heads only
./attention_forward 1
*/
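
// For reference, the memory layout used throughout this file: the input is packed as
// (B, T, 3C), so for token (b, t) the query of head h starts at
//     inp + b*T*3C + t*3C + h*hs            (hs = C/NH)
// and the key and value for the same (b, t, h) sit at offsets +C and +2C from it.
// preatt/att are (B, NH, T, T) and the output is (B, T, C), with head h occupying
// out[b, t, h*hs : (h+1)*hs].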

#include <stdio.h>
#include <stdlib.h>
#include <math.h> // sqrtf, expf, fabs in the CPU reference and the correctness check
#include <cuda_runtime.h>

// ----------------------------------------------------------------------------
// CUDA utils

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

// error checking
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error != cudaSuccess) {
        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
               cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
}
// wrap call sites in this macro so cudaCheck automatically passes __FILE__ and __LINE__
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

// ----------------------------------------------------------------------------
// CPU code reference

void attention_forward_cpu(float* out, float* preatt, float* att,
                           float* inp,
                           int B, int T, int C, int NH) {
    // input is (B, T, 3C) Q,K,V
    // preatt, att are (B, NH, T, T)
    // output is (B, T, C)
    int C3 = C*3;
    int hs = C / NH; // head size
    float scale = 1.0 / sqrtf(hs);

    for (int b = 0; b < B; b++) {
        for (int t = 0; t < T; t++) {
            for (int h = 0; h < NH; h++) {
                float* query_t = inp + b * T * C3 + t * C3 + h * hs;
                float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
                float* att_bth = att + b*NH*T*T + h*T*T + t*T;

                // pass 1: calculate query dot key and maxval
                float maxval = -10000.0f; // TODO something better
                for (int t2 = 0; t2 <= t; t2++) {
                    float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

                    // (query_t) dot (key_t2)
                    float val = 0.0f;
                    for (int i = 0; i < hs; i++) {
                        val += query_t[i] * key_t2[i];
                    }
                    val *= scale;
                    if (val > maxval) {
                        maxval = val;
                    }

                    preatt_bth[t2] = val;
                }

                // pass 2: calculate the exp and keep track of sum
                float expsum = 0.0f;
                for (int t2 = 0; t2 <= t; t2++) {
                    float expv = expf(preatt_bth[t2] - maxval);
                    expsum += expv;
                    att_bth[t2] = expv;
                }
                float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

                // pass 3: normalize to get the softmax
                for (int t2 = 0; t2 < T; t2++) {
                    if (t2 <= t) {
                        att_bth[t2] *= expsum_inv;
                    } else {
                        // causal attention mask. not strictly necessary to set to zero here
                        // only doing this explicitly for debugging and checking against PyTorch
                        att_bth[t2] = 0.0f;
                    }
                }

                // pass 4: accumulate weighted values into the output of attention
                float* out_bth = out + b * T * C + t * C + h * hs;
                for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
                for (int t2 = 0; t2 <= t; t2++) {
                    float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
                    float att_btht2 = att_bth[t2];
                    for (int i = 0; i < hs; i++) {
                        out_bth[i] += att_btht2 * value_t2[i];
                    }
                }
            }
        }
    }
}
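
// For reference, the loops above compute causal scaled dot-product attention per head,
// i.e. roughly:
//     preatt[b,h,t,t2] = dot(q[b,t,h,:], k[b,t2,h,:]) / sqrt(hs)    for t2 <= t
//     att[b,h,t,:]     = softmax(preatt[b,h,t,0..t])                (zeros for t2 > t)
//     out[b,t,h,:]     = sum over t2 <= t of att[b,h,t,t2] * v[b,t2,h,:]
// where q, k, v are the three C-sized slices of the (B, T, 3C) input.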

// ----------------------------------------------------------------------------
// GPU kernels

__global__ void attention_query_key_kernel1(float* preatt, float* inp,
                                            int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * NH * T * T;

    if (idx < total_threads) {
        int t2 = idx % T;
        int t = (idx / T) % T;
        if (t2 > t) { return; } // autoregressive mask
        int h = (idx / (T * T)) % NH;
        int b = idx / (NH * T * T);

        int C3 = C*3;
        int hs = C / NH; // head size
        float* query_t = inp + b * T * C3 + t * C3 + h * hs;
        float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

        // (query_t) dot (key_t2)
        float val = 0.0f;
        for (int i = 0; i < hs; i++) {
            val += query_t[i] * key_t2[i];
        }
        val *= 1.0 / sqrtf(hs);

        preatt[idx] = val;
    }
}
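
// The flat index above is idx = ((b*NH + h)*T + t)*T + t2, so the modulo/divide chain
// recovers (b, h, t, t2) in row-major order and preatt[idx] writes to the right spot in
// the (B, NH, T, T) layout. One thread per (query, key) pair; one dot product each.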

__global__ void attention_softmax_kernel1(float* att, float* preatt,
                                          int B, int T, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
        float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        // find maxval
        float maxval = -10000.0f; // TODO something better
        for (int t2 = 0; t2 <= t; t2++) {
            if (preatt_bth[t2] > maxval) {
                maxval = preatt_bth[t2];
            }
        }

        // calculate the exp and keep track of sum
        float expsum = 0.0f;
        for (int t2 = 0; t2 <= t; t2++) {
            float expv = expf(preatt_bth[t2] - maxval);
            expsum += expv;
            att_bth[t2] = expv;
        }
        float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

        // normalize to get the softmax
        for (int t2 = 0; t2 < T; t2++) {
            if (t2 <= t) {
                att_bth[t2] *= expsum_inv;
            } else {
                // causal attention mask. not strictly necessary to set to zero here
                // only doing this explicitly for debugging and checking against PyTorch
                att_bth[t2] = 0.0f;
            }
        }
    }
}
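
// Subtracting maxval before exponentiating is the standard numerically stable softmax:
// exp(x_i - m) / sum_j exp(x_j - m) equals exp(x_i) / sum_j exp(x_j) for any constant m,
// and choosing m = max keeps every exponent <= 0 so expf cannot overflow.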

__global__ void attention_value_kernel1(float* out, float* att, float* inp,
                                        int B, int T, int C, int NH) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    int total_threads = B * T * NH;

    if (idx < total_threads) {
        int h = idx % NH;
        int t = (idx / NH) % T;
        int b = idx / (NH * T);

        int C3 = C*3;
        int hs = C / NH; // head size

        float* out_bth = out + b * T * C + t * C + h * hs;
        float* att_bth = att + b*NH*T*T + h*T*T + t*T;

        for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
        for (int t2 = 0; t2 <= t; t2++) {
            float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
            float att_btht2 = att_bth[t2];
            for (int i = 0; i < hs; i++) {
                out_bth[i] += att_btht2 * value_t2[i];
            }
        }
    }
}
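
// Each of the B*T*NH threads owns a disjoint hs-sized slice of the output
// (out + b*T*C + t*C + h*hs), so the accumulation needs no atomics; the downside is
// that a single thread runs an O(t * hs) serial loop, which is what makes this naive.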

// ----------------------------------------------------------------------------
// kernel launcher

void attention_forward1(float* out, float* preatt, float* att,
                        float* inp,
                        int B, int T, int C, int NH,
                        const int block_size) {
    // attention calculation
    int total_threads = B * NH * T * T;
    int num_blocks = CEIL_DIV(total_threads, block_size);
    attention_query_key_kernel1<<<num_blocks, block_size>>>(preatt, inp, B, T, C, NH);
    // softmax and value accumulation
    total_threads = B * T * NH;
    num_blocks = CEIL_DIV(total_threads, block_size);
    attention_softmax_kernel1<<<num_blocks, block_size>>>(att, preatt, B, T, NH);
    attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);
}
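
// The launches above are not error-checked. A minimal, optional sketch reusing the
// existing cudaCheck macro would be to add after the last launch:
//     cudaCheck(cudaGetLastError());      // catches bad launch configurations
//     cudaCheck(cudaDeviceSynchronize()); // surfaces asynchronous execution errors
// (left out of the hot path here so it does not distort the timing loop in main).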

// kernel version dispatch
void attention_forward(int kernel_num,
                       float* out, float* preatt, float* att,
                       float* inp,
                       int B, int T, int C, int NH,
                       const int block_size) {
    switch (kernel_num) {
        case 1:
            attention_forward1(out, preatt, att, inp, B, T, C, NH, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------
// random utils

float* make_random_float(int N) {
    float* arr = (float*)malloc(N * sizeof(float));
    for (int i = 0; i < N; i++) {
        arr[i] = ((float)rand() / RAND_MAX) * 2.0 - 1.0; // uniform in [-1, 1]
    }
    return arr;
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;
    int NH = 12;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * C * sizeof(float));
    float* preatt = (float*)malloc(B * NH * T * T * sizeof(float));
    float* att = (float*)malloc(B * NH * T * T * sizeof(float));
    float* inp = make_random_float(B * T * 3 * C);

    // move to GPU
    float* d_out;
    float* d_preatt;
    float* d_att;
    float* d_inp;
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel
    attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);
    attention_forward(kernel_num, d_out, d_preatt, d_att, d_inp, B, T, C, NH, 256);

    float* out_gpu = (float*)malloc(B * T * C * sizeof(float));
    cudaCheck(cudaMemcpy(out_gpu, d_out, B * T * C * sizeof(float), cudaMemcpyDeviceToHost));
    for (int i = 0; i < B * T * C; i++) {
        // print the first few comparisons
        if (i < 5) {
            printf("%f %f\n", out[i], out_gpu[i]);
        }
        // ensure correctness for all elements
        if (fabs(out[i] - out_gpu[i]) > 1e-4) {
            printf("Mismatch at %d: %f vs %f\n", i, out[i], out_gpu[i]);
            exit(1);
        }
    }
    printf("Results match!\n");

    // time the kernel at different block sizes
    int block_sizes[] = {32, 64, 128, 256, 512};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 10;
        cudaEvent_t start, stop;
        cudaCheck(cudaEventCreate(&start));
        cudaCheck(cudaEventCreate(&stop));
        cudaCheck(cudaEventRecord(start, 0));
        for (int i = 0; i < repeat_times; i++) {
            attention_forward(kernel_num, d_out, d_preatt, d_att, d_inp, B, T, C, NH, block_size);
        }
        cudaCheck(cudaEventRecord(stop, 0));
        cudaCheck(cudaEventSynchronize(start));
        cudaCheck(cudaEventSynchronize(stop));
        float elapsed_time;
        cudaCheck(cudaEventElapsedTime(&elapsed_time, start, stop));

        printf("block_size %4d | time %f ms\n", block_size, elapsed_time);
    }

    // free memory
    free(out);
    free(preatt);
    free(att);
    free(inp);
    free(out_gpu);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_preatt));
    cudaCheck(cudaFree(d_att));
    cudaCheck(cudaFree(d_inp));

    return 0;
}