forked from karpathy/llm.c
Commit: might as well push the few kernels that i feel ok about so far
Showing 6 changed files with 1,248 additions and 0 deletions.
New file: dev/cuda/README (+3 lines)
# dev/cuda

This directory is scratch space for developing various versions of the needed CUDA kernels. Each file develops a kernel; see the top of each file for instructions on how to compile and run each one.
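For example, the gelu_forward kernel in this commit is compiled and run exactly as its file header describes:

```
nvcc -O3 --use_fast_math gelu_forward.cu -o gelu_forward
./gelu_forward 1
```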
New file: dev/cuda/gelu_forward.cu (+176 lines)
/*
Kernels for gelu forward pass.

Compile example:
nvcc -O3 --use_fast_math gelu_forward.cu -o gelu_forward

version 1 is naive port from CPU code to kernel
./gelu_forward 1
*/

#include <stdio.h>
#include <stdlib.h>
#include <math.h>           // for M_PI, sqrtf, tanhf, fabs
#include <cuda_runtime.h>

// ----------------------------------------------------------------------------
// CUDA utils

#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
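// e.g. CEIL_DIV(1000, 128) == (1000 + 127) / 128 == 8: the grid size is
// rounded up so a partial block still gets launched when N is not a multiple
// of the block size; the 8*128 - 1000 = 24 surplus threads are masked off by
// the i < N guard in the kernel below
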
// error checking
void cudaCheck(cudaError_t error, const char *file, int line) {
    if (error != cudaSuccess) {
        printf("[CUDA ERROR] at file %s:%d:\n%s\n", file, line,
               cudaGetErrorString(error));
        exit(EXIT_FAILURE);
    }
}
// the function-like macro shadows the function of the same name, so every
// cudaCheck(err) call site passes its own __FILE__ and __LINE__ automatically
#define cudaCheck(err) (cudaCheck(err, __FILE__, __LINE__))

// ----------------------------------------------------------------------------
// CPU code reference

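// the tanh approximation of GELU used by GPT-2:
//   GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))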
void gelu_forward_cpu(float* out, float* inp, int N) {
    float s = sqrtf(2.0f / M_PI);
    for (int i = 0; i < N; i++) {
        float x = inp[i];
        float cube = 0.044715f * x * x * x;
        out[i] = 0.5f * x * (1.0f + tanhf(s * (x + cube)));
    }
}

// ----------------------------------------------------------------------------
// GPU kernels

// elementwise ops are nice and ez
__global__ void gelu_kernel(float* out, const float* inp, int N) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    float s = sqrtf(2.0f / M_PI);
    if (i < N) {
        float xi = inp[i];
        float cube = 0.044715f * xi * xi * xi;
        out[i] = 0.5f * xi * (1.0f + tanhf(s * (xi + cube)));
    }
}

// ----------------------------------------------------------------------------
// kernel launcher

void gelu_forward1(float* out, float* inp, int N, const int block_size) {
    const int grid_size = CEIL_DIV(N, block_size);
    gelu_kernel<<<grid_size, block_size>>>(out, inp, N);
    cudaCheck(cudaGetLastError());
}

// kernel version dispatch
void gelu_forward(int kernel_num,
                  float* out,
                  float* inp,
                  int B, int T, int C,
                  int block_size) {
    switch (kernel_num) {
        case 1:
            gelu_forward1(out, inp, B * T * C, block_size);
            break;
        default:
            printf("Invalid kernel number\n");
            exit(1);
    }
}

// ----------------------------------------------------------------------------
// random utils

float* make_random_float(int N) {
    float* arr = (float*)malloc(N * sizeof(float));
    for (int i = 0; i < N; i++) {
        arr[i] = ((float)rand() / RAND_MAX) * 2.0f - 1.0f;  // uniform in [-1, 1]
    }
    return arr;
}

// ----------------------------------------------------------------------------

int main(int argc, char **argv) {
    srand(0);

    int B = 8;
    int T = 1024;
    int C = 768;

    int deviceIdx = 0;
    cudaCheck(cudaSetDevice(deviceIdx));

    // create host memory of random numbers
    float* out = (float*)malloc(B * T * C * sizeof(float));
    float* inp = make_random_float(B * T * C);

    // move to GPU
    float* d_out;
    float* d_inp;
    cudaCheck(cudaMalloc(&d_out, B * T * C * sizeof(float)));
    cudaCheck(cudaMalloc(&d_inp, B * T * C * sizeof(float)));
    cudaCheck(cudaMemcpy(d_inp, inp, B * T * C * sizeof(float), cudaMemcpyHostToDevice));

    // read kernel_num from command line
    int kernel_num = 1;
    if (argc > 1) {
        kernel_num = atoi(argv[1]);
    }
    printf("Using kernel %d\n", kernel_num);

    // first check the correctness of the kernel against the CPU reference
    gelu_forward_cpu(out, inp, B * T * C);
    gelu_forward(kernel_num, d_out, d_inp, B, T, C, 128);
    float* out_gpu = (float*)malloc(B * T * C * sizeof(float));
    // this cudaMemcpy also synchronizes, so the kernel has finished before we compare
    cudaCheck(cudaMemcpy(out_gpu, d_out, B * T * C * sizeof(float), cudaMemcpyDeviceToHost));
    for (int i = 0; i < B * T * C; i++) {
        // print the first few comparisons
        if (i < 5) {
            printf("%f %f\n", out[i], out_gpu[i]);
        }
        // ensure correctness for all elements, up to a tolerance that allows
        // for float rounding and --use_fast_math approximations of tanhf
        if (fabs(out[i] - out_gpu[i]) > 1e-5) {
            printf("Mismatch at %d: %f vs %f\n", i, out[i], out_gpu[i]);
            exit(1);
        }
    }
    printf("Results match!\n");

    // time the kernel at different block sizes
    // (1024 is the maximum number of threads per block on NVIDIA GPUs)
    int block_sizes[] = {32, 64, 128, 256, 512, 1024};

    for (int j = 0; j < sizeof(block_sizes) / sizeof(int); j++) {
        int block_size = block_sizes[j];

        int repeat_times = 1000;
        cudaEvent_t start, stop;
        cudaCheck(cudaEventCreate(&start));
        cudaCheck(cudaEventCreate(&stop));
        cudaCheck(cudaEventRecord(start, 0));
        for (int i = 0; i < repeat_times; i++) {
            gelu_forward(kernel_num, d_out, d_inp, B, T, C, block_size);
        }
        cudaCheck(cudaEventRecord(stop, 0));
        cudaCheck(cudaEventSynchronize(start));
        cudaCheck(cudaEventSynchronize(stop));
        float elapsed_time;  // milliseconds, summed over all repeat_times launches
        cudaCheck(cudaEventElapsedTime(&elapsed_time, start, stop));
        // destroy the events so they don't leak across loop iterations
        cudaCheck(cudaEventDestroy(start));
        cudaCheck(cudaEventDestroy(stop));

        // napkin math: estimate the memory bandwidth achieved
        // for each (B,T,C) output element, we do 1 read and 1 write, 4 bytes each
        // and e.g. A100 40GB PCIe is advertised at 1,555GB/s
        long memory_ops = (long)B * T * C * 2 * 4;  // cast avoids int overflow at larger sizes
        float memory_bandwidth = memory_ops / (elapsed_time / repeat_times) / 1e6;  // bytes/ms -> GB/s

        printf("block_size %4d | time %f ms | bandwidth %f GB/s\n", block_size, elapsed_time / repeat_times, memory_bandwidth);
    }
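    // sanity check on the napkin math: one launch moves 8*1024*768 * 2 * 4 =
    // 50,331,648 bytes (~50.3 MB), so at the advertised 1,555 GB/s an A100
    // cannot take much less than ~0.032 ms per launch
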
    // free memory
    free(out);
    free(inp);
    free(out_gpu);
    cudaCheck(cudaFree(d_out));
    cudaCheck(cudaFree(d_inp));

    return 0;
}