naive attention kernel only parallelizing over batch,time,heads. have to speed this up a lot
karpathy committed Apr 9, 2024
1 parent 41d5f56 commit 03f37cf
dev/cuda/attention_forward.cu: 340 additions, 0 deletions
@@ -0,0 +1,340 @@
/*
Kernels for attention forward pass.
Compile example:
nvcc -O3 --use_fast_math attention_forward.cu -o attention_forward
version 1 is a naive port of the CPU code to a kernel, parallelizing over batch, time, heads only
./attention_forward 1
*/
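
// Note: kernel 1 below splits the forward pass into three kernels that mirror the
// CPU reference: (1) query dot key products into preatt, (2) a row-wise causal
// softmax into att, (3) accumulation of the attention-weighted values into out.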

#include <stdio.h>
#include <stdlib.h>
#include <math.h>
#include <cuda_runtime.h>

// ----------------------------------------------------------------------------
// CUDA utils

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))

// error checking
void cudaCheck(cudaError_t error, const char *file, int line) {
if (error != cudaSuccess) {
printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
cudaGetErrorString(error));
exit(EXIT_FAILURE);
}
}
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

// ----------------------------------------------------------------------------
// CPU code reference

void attention_forward_cpu(float* out, float* preatt, float* att,
float* inp,
int B, int T, int C, int NH) {
// input is (B, T, 3C) Q,K,V
// preatt, att are (B, NH, T, T)
// output is (B, T, C)
int C3 = C*3;
int hs = C / NH; // head size
float scale = 1.0f / sqrtf(hs); // scaled dot-product attention: scale scores by 1/sqrt(head_size)

for (int b = 0; b < B; b++) {
for (int t = 0; t < T; t++) {
for (int h = 0; h < NH; h++) {
float* query_t = inp + b * T * C3 + t * C3 + h * hs;
float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
float* att_bth = att + b*NH*T*T + h*T*T + t*T;

// pass 1: calculate query dot key and maxval
float maxval = -10000.0f; // TODO something better
for (int t2 = 0; t2 <= t; t2++) {
float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

// (query_t) dot (key_t2)
float val = 0.0f;
for (int i = 0; i < hs; i++) {
val += query_t[i] * key_t2[i];
}
val *= scale;
if (val > maxval) {
maxval = val;
}

preatt_bth[t2] = val;
}

// pass 2: calculate the exp and keep track of sum
float expsum = 0.0f;
for (int t2 = 0; t2 <= t; t2++) {
float expv = expf(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

// pass 3: normalize to get the softmax
for (int t2 = 0; t2 < T; t2++) {
if (t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
// causal attention mask. not strictly necessary to set to zero here
// only doing this explicitly for debugging and checking against PyTorch
att_bth[t2] = 0.0f;
}
}

// pass 4: accumulate weighted values into the output of attention
float* out_bth = out + b * T * C + t * C + h * hs;
for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
for (int t2 = 0; t2 <= t; t2++) {
float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
float att_btht2 = att_bth[t2];
for (int i = 0; i < hs; i++) {
out_bth[i] += att_btht2 * value_t2[i];
}
}
}
}
}
}

// ----------------------------------------------------------------------------
// GPU kernels

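// query dot key scores: one thread per (b, h, t, t2) pair, i.e. B*NH*T*T threads in total.
// Threads with t2 > t return early, so only the lower (causal) triangle of preatt is
// written; the softmax kernel below reads exactly those entries.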
__global__ void attention_query_key_kernel1(float* preatt, float* inp,
int B, int T, int C, int NH) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total_threads = B * NH * T * T;

if (idx < total_threads) {
int t2 = idx % T;
int t = (idx / T) % T;
if (t2 > t) { return; } // autoregressive mask
int h = (idx / (T * T)) % NH;
int b = idx / (NH * T * T);

int C3 = C*3;
int hs = C / NH; // head size
float* query_t = inp + b * T * C3 + t * C3 + h * hs;
float* key_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C; // +C because it's key

// (query_t) dot (key_t2)
float val = 0.0f;
for (int i = 0; i < hs; i++) {
val += query_t[i] * key_t2[i];
}
val *= 1.0f / sqrtf(hs); // float literal avoids a double-precision multiply in device code

preatt[idx] = val;
}
}

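// row-wise causal softmax: one thread per (b, t, h) row, B*T*NH threads in total.
// Each thread subtracts the row max before exponentiating (for numerical stability),
// normalizes by the sum, and zeroes the masked positions t2 > t.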
__global__ void attention_softmax_kernel1(float* att, float* preatt,
int B, int T, int NH) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total_threads = B * T * NH;

if (idx < total_threads) {
int h = idx % NH;
int t = (idx / NH) % T;
int b = idx / (NH * T);

float* preatt_bth = preatt + b*NH*T*T + h*T*T + t*T;
float* att_bth = att + b*NH*T*T + h*T*T + t*T;

// find maxval
float maxval = -10000.0f; // TODO something better
for (int t2 = 0; t2 <= t; t2++) {
if (preatt_bth[t2] > maxval) {
maxval = preatt_bth[t2];
}
}

// calculate the exp and keep track of sum
float expsum = 0.0f;
for (int t2 = 0; t2 <= t; t2++) {
float expv = expf(preatt_bth[t2] - maxval);
expsum += expv;
att_bth[t2] = expv;
}
float expsum_inv = expsum == 0.0f ? 0.0f : 1.0f / expsum;

// normalize to get the softmax
for (int t2 = 0; t2 < T; t2++) {
if (t2 <= t) {
att_bth[t2] *= expsum_inv;
} else {
// causal attention mask. not strictly necessary to set to zero here
// only doing this explicitly for debugging and checking against PyTorch
att_bth[t2] = 0.0f;
}
}
}
}

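// value accumulation: one thread per (b, t, h) output vector. Each thread computes
// out_bth = sum over t2 <= t of att[b,h,t,t2] * value[b,t2,h,:], an hs-dimensional
// accumulation done serially by a single thread.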
__global__ void attention_value_kernel1(float* out, float* att, float* inp,
int B, int T, int C, int NH) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
int total_threads = B * T * NH;

if (idx < total_threads) {
int h = idx % NH;
int t = (idx / NH) % T;
int b = idx / (NH * T);

int C3 = C*3;
int hs = C / NH; // head size

float* out_bth = out + b * T * C + t * C + h * hs;
float* att_bth = att + b*NH*T*T + h*T*T + t*T;

for (int i = 0; i < hs; i++) { out_bth[i] = 0.0f; }
for (int t2 = 0; t2 <= t; t2++) {
float* value_t2 = inp + b * T * C3 + t2 * C3 + h * hs + C*2; // +C*2 because it's value
float att_btht2 = att_bth[t2];
for (int i = 0; i < hs; i++) {
out_bth[i] += att_btht2 * value_t2[i];
}
}
}
}

// ----------------------------------------------------------------------------
// kernel launcher

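// a rough sense of scale, assuming the defaults in main below (B=8, T=1024, C=768, NH=12):
// the query dot key kernel launches B*NH*T*T ~= 100M threads that each do one small dot
// product, while the softmax and value kernels launch only B*T*NH ~= 98K threads that each
// loop serially over up to T positions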
void attention_forward1(float* out, float* preatt, float* att,
float* inp,
int B, int T, int C, int NH,
const int block_size) {
// attention calculation
int total_threads = B * NH * T * T;
int num_blocks = CEIL_DIV(total_threads, block_size);
attention_query_key_kernel1<<<num_blocks, block_size>>>(preatt, inp, B, T, C, NH);
// softmax and value accumulation
total_threads = B * T * NH;
num_blocks = CEIL_DIV(total_threads, block_size);
attention_softmax_kernel1<<<num_blocks, block_size>>>(att, preatt, B, T, NH);
attention_value_kernel1<<<num_blocks, block_size>>>(out, att, inp, B, T, C, NH);
}

// kernel version dispatch
void attention_forward(int kernel_num,
float* out, float* preatt, float* att,
float* inp,
int B, int T, int C, int NH,
const int block_size) {
switch (kernel_num) {
case 1:
attention_forward1(out, preatt, att, inp, B, T, C, NH, block_size);
break;
default:
printf("Invalid kernel number\n");
exit(1);
}
}

// ----------------------------------------------------------------------------
// random utils

float* make_random_float(int N) {
float* arr = (float*)malloc(N * sizeof(float));
for (int i = 0; i < N; i++) {
arr[i] = ((float)rand() / RAND_MAX) * 2.0 - 1.0;
}
return arr;
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
srand(0);

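// benchmark size roughly matching GPT-2 (124M) attention: batch 8, sequence length 1024,
// 12 heads of size 64 (C = 768)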
int B = 8;
int T = 1024;
int C = 768;
int NH = 12;

int deviceIdx = 0;
cudaCheck(cudaSetDevice(deviceIdx));

// create host memory of random numbers
float* out = (float*)malloc(B * T * C * sizeof(float));
float* preatt = (float*)malloc(B * NH * T * T * sizeof(float));
float* att = (float*)malloc(B * NH * T * T * sizeof(float));
float* inp = make_random_float(B * T * 3 * C);

// move to GPU
float* d_out;
float* d_preatt;
float* d_att;
float* d_inp;
cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
cudaCheck(cudaMalloc(&d_preatt, B * NH * T * T * sizeof(float)));
cudaCheck(cudaMalloc(&d_att, B * NH * T * T * sizeof(float)));
cudaCheck(cudaMalloc(&d_inp, B * T * 3 * C * sizeof(float)));
cudaCheck(cudaMemcpy(d_inp, inp, B * T * 3 * C * sizeof(float), cudaMemcpyHostToDevice));

// read kernel_num from command line
int kernel_num = 1;
if (argc > 1) {
kernel_num = atoi(argv[1]);
}
printf("Using kernel %d\n", kernel_num);

// first check the correctness of the kernel
attention_forward_cpu(out, preatt, att, inp, B, T, C, NH);
attention_forward(kernel_num, d_out, d_preatt, d_att, d_inp, B, T, C, NH, 256);

float* out_gpu = (float*)malloc(B * T * C * sizeof(float));
cudaCheck(cudaMemcpy(out_gpu, d_out, B * T * C * sizeof(float), cudaMemcpyDeviceToHost));
for (int i = 0; i < B * T * C; i++) {
// print the first few comparisons
if (i < 5) {
printf("%f %f\n", out[i], out_gpu[i]);
}
// ensure correctness for all elements
if (fabs(out[i] - out_gpu[i]) > 1e-4) {
printf("Mismatch at %d: %f vs %f\n", i, out[i], out_gpu[i]);
exit(1);
}
}
printf("Results match!\n");

// time the kernel at different block sizes
int block_sizes[] = {32, 64, 128, 256, 512};

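// note: the CUDA events below bracket all repeat_times launches, so the reported
// time is the total across iterations, not a per-launch average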
for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
int block_size = block_sizes[j];

int repeat_times = 10;
cudaEvent_t start, stop;
cudaCheck(cudaEventCreate(&start));
cudaCheck(cudaEventCreate(&stop));
cudaCheck(cudaEventRecord(start, 0));
for (int i = 0; i < repeat_times; i++) {
attention_forward(kernel_num, d_out, d_preatt, d_att, d_inp, B, T, C, NH, block_size);
}
cudaCheck(cudaEventRecord(stop, 0));
cudaCheck(cudaEventSynchronize(start));
cudaCheck(cudaEventSynchronize(stop));
float elapsed_time;
cudaCheck(cudaEventElapsedTime(&elapsed_time, start, stop));

printf("block_size %4d | time %f ms\n", block_size, elapsed_time);
}

// free memory
free(out);
free(preatt);
free(att);
free(inp);
free(out_gpu);
cudaCheck(cudaFree(d_out));
cudaCheck(cudaFree(d_preatt));
cudaCheck(cudaFree(d_att));
cudaCheck(cudaFree(d_inp));

return 0;
}
