
Commit

Aaryan0404 committed May 13, 2024
2 parents fd32619 + 6884518 commit a562ed2
Showing 8 changed files with 179 additions and 488 deletions.
129 changes: 87 additions & 42 deletions README.md

Large diffs are not rendered by default.

Binary file added attn.png
60 changes: 35 additions & 25 deletions examples/attn/4090/4090_ker.cu
@@ -1,8 +1,11 @@
#include "../../../src/kittens.cuh"

#define NUM_WORKERS 16
using namespace kittens;
__global__ void attend_ker(int n, int d, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {
// this kernel is more of an example kernel to show some TK programming models, rather than a kernel we think you should put into production, though it is pretty fast!

#define NUM_WORKERS 16 // This kernel uses 16 workers in parallel per block, to help issue instructions more quickly.

using namespace kittens; // this kernel only handles headdim=64 for simplicity. Also n should be a multiple of 256 here.
__global__ void attend_ker64(int n, const bf16* __restrict__ __q__, const bf16* __restrict__ __k__, const bf16* __restrict__ __v__, bf16* __o__) {

auto warpid = kittens::warpid();
auto block_start = blockIdx.x*(n*64);
@@ -12,70 +15,77 @@ __global__ void attend_ker(int n, int d, const bf16* __restrict__ __q__, const b
extern __shared__ alignment_dummy __shm[]; // this is the CUDA shared memory
shared_allocator al((int*)&__shm[0]);

// K and V live in shared memory -- this is about all that will fit.
st_bf_1x4<ducks::st_layout::swizzle> (&k_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();
st_bf_1x4<ducks::st_layout::swizzle> (&v_smem)[NUM_WORKERS] = al.allocate<st_bf_1x4<ducks::st_layout::swizzle>, NUM_WORKERS>();

rt_bf_1x4<> q_reg, k_reg, v_reg;
// Initialize all of the register tiles.
rt_bf_1x4<> q_reg, k_reg, v_reg; // v_reg needs to be swapped into col layout
rt_fl_1x1<> att_block;
rt_bf_1x1<> att_block_mma;
rt_fl_1x4<> o_prev;
rt_fl_1x1<>::col_vec max_vec_last, max_vec;
rt_fl_1x1<>::col_vec norm_vec_last, norm_vec;
rt_fl_1x4<> o_reg;
rt_fl_1x1<>::col_vec max_vec_last, max_vec; // these are column vectors for the attention block
rt_fl_1x1<>::col_vec norm_vec_last, norm_vec; // these are column vectors for the attention block

int qo_blocks = n / (q_reg.rows*NUM_WORKERS), kv_blocks = n / (q_reg.rows*NUM_WORKERS);

for(auto q_blk = 0; q_blk < qo_blocks; q_blk++) {

// each warp loads its own Q tile of 16x64, and then multiplies by 1/sqrt(d)
load(q_reg, _q + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
mul(q_reg, q_reg, __float2bfloat16(0.125f)); // temperature adjustment

// zero flash attention L, M, and O registers.
neg_infty(max_vec); // zero registers for the Q chunk
zero(norm_vec);
zero(o_prev);
zero(o_reg);

// iterate over k, v for these q's that have been loaded
for(auto kv_idx = 0; kv_idx < kv_blocks; kv_idx++) {

// each warp loads its own chunk of k, v into shared memory
load(v_smem[warpid], _v + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
load(k_smem[warpid], _k + (kv_idx*NUM_WORKERS + warpid)*q_reg.num_elements, q_reg.cols);
__syncthreads(); // we need to make sure all memory is loaded before we can begin the compute phase

// now each warp goes through all of the subtiles, loads them, and then does the flash attention internal alg.
for(int subtile = 0; subtile < NUM_WORKERS; subtile++) {

load(k_reg, k_smem[subtile]);
load(k_reg, k_smem[subtile]); // load k from shared into registers

zero(att_block);
mma_ABt(att_block, q_reg, k_reg, att_block);
zero(att_block); // zero 16x16 attention tile
mma_ABt(att_block, q_reg, k_reg, att_block); // q @ k^T

copy(norm_vec_last, norm_vec);
copy(max_vec_last, max_vec);

row_max(max_vec, att_block, max_vec); // accumulate onto the max_vec
sub_row(att_block, att_block, max_vec);
exp(att_block, att_block);
sub_row(att_block, att_block, max_vec); // subtract max from attention -- now all <=0
exp(att_block, att_block); // exponentiate the block in-place.

sub(max_vec_last, max_vec_last, max_vec);
exp(max_vec_last, max_vec_last);
mul(norm_vec, norm_vec, max_vec_last);
sub(max_vec_last, max_vec_last, max_vec); // subtract new max from old max to find the new normalization.
exp(max_vec_last, max_vec_last); // exponentiate this vector -- this is what we need to normalize by.
mul(norm_vec, norm_vec, max_vec_last); // rescale the running norm vec by that factor.

row_sum(norm_vec, att_block, norm_vec); // accumulate onto the norm_vec
div_row(att_block, att_block, norm_vec);
row_sum(norm_vec, att_block, norm_vec); // accumulate the new attention block onto the now-rescaled norm_vec
div_row(att_block, att_block, norm_vec); // now the attention block is correctly normalized

mul(norm_vec_last, norm_vec_last, max_vec_last);
div(norm_vec_last, norm_vec_last, norm_vec);
mul(norm_vec_last, norm_vec_last, max_vec_last); // normalize the previous norm vec according to the new max
div(norm_vec_last, norm_vec_last, norm_vec); // normalize the previous norm vec according to the new norm

copy(att_block_mma, att_block); // convert to bf16 for mma_AB

load(v_reg, v_smem[subtile]);
load(v_reg, v_smem[subtile]); // load v from shared into registers.
rt_bf_1x4<ducks::rt_layout::col> &v_reg_col = swap_layout_inplace(v_reg); // this is a reference and the call has invalidated v_reg

mul_row(o_prev, o_prev, norm_vec_last); // normalize o_prev in advance of mma_AB'ing onto it
mma_AB(o_prev, att_block_mma, v_reg_col, o_prev);
mul_row(o_reg, o_reg, norm_vec_last); // normalize o_reg in advance of mma_AB'ing onto it
mma_AB(o_reg, att_block_mma, v_reg_col, o_reg); // mfma onto o_reg with the local attention@V matmul.
}
__syncthreads(); // we need to make sure all warps are done before we can start loading the next kv chunk
}

store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_prev, d); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/
store(_o + (q_blk*NUM_WORKERS + warpid)*q_reg.num_elements, o_reg, q_reg.cols); // write out o. compiler has an issue with register usage if d is made constexpr q_reg.rows :/
}
}

#include "harness.impl"
#include "harness.impl"
6 changes: 3 additions & 3 deletions examples/attn/4090/harness.impl
@@ -86,7 +86,7 @@ int main(int argc, char **argv) {
unsigned long mem_size = kittens::MAX_SHARED_MEMORY; // have the flag tell us

cudaFuncSetAttribute(
attend_ker,
attend_ker64,
cudaFuncAttributeMaxDynamicSharedMemorySize,
mem_size
);
@@ -95,7 +95,7 @@ int main(int argc, char **argv) {
std::cout << "Starting kernel\n";
const auto start = std::chrono::high_resolution_clock::now();
for(int i = 0; i < ITER; i++) {
attend_ker<<<ATTN_B*ATTN_H, BLOCK_SIZE, mem_size>>>(ATTN_N, 64, d_q, d_k, d_v, d_o);
attend_ker64<<<ATTN_B*ATTN_H, BLOCK_SIZE, mem_size>>>(ATTN_N, d_q, d_k, d_v, d_o);
}
cudaDeviceSynchronize();
const auto finish = std::chrono::high_resolution_clock::now();
@@ -135,4 +135,4 @@ int main(int argc, char **argv) {
delete[] q_bf, k_bf, v_bf, o_bf;

return 0;
}
}
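
Beyond renaming `attend_ker` to `attend_ker64` and dropping the explicit head-dimension argument, the harness keeps the usual pattern for launching a kernel that wants more dynamic shared memory than the default 48 KB per-block limit: set `cudaFuncAttributeMaxDynamicSharedMemorySize` on the kernel once, then pass the byte count as the third launch parameter. A generic sketch of that pattern (the kernel name and sizes here are placeholders, not the harness's):

```
// Generic pattern for a large dynamic shared-memory carveout.
__global__ void my_kernel(int n) { /* ... uses extern __shared__ memory ... */ }

void launch(int n, int blocks, int threads, size_t smem_bytes) {
    // Required once per kernel when smem_bytes exceeds the default 48 KB.
    cudaFuncSetAttribute(my_kernel,
                         cudaFuncAttributeMaxDynamicSharedMemorySize,
                         (int)smem_bytes);
    my_kernel<<<blocks, threads, smem_bytes>>>(n);
    cudaDeviceSynchronize();
}
```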
25 changes: 21 additions & 4 deletions examples/based/linear_attn_forward/README.md
@@ -2,11 +2,9 @@

Here we provide details to test and benchmark the Based kernel's *forward pass / inference prefill*. Note this kernel assumes feature dimension 16.

Using TK, we achieve a fast implementation of linear attention for the Based architecture! You can check out these resources to learn more about the Based architecture: [Code](https://github.com/HazyResearch/based), [Paper](https://arxiv.org/abs/2402.18668).


## Overview of kernel
Based introduces a fast implementation of linear attention.
Using TK, we achieve a fast implementation of linear attention for the Based architecture! You can check out these resources to learn more about the Based architecture: [Code](https://github.com/HazyResearch/based), [Paper](https://arxiv.org/abs/2402.18668).

Standard attention computes an $O(N^2)$ matrix of query and key interactions $\exp(q_i^Tk_j/\sqrt{d})$. The idea in [linear attention](https://arxiv.org/abs/2006.16236) is to remove the softmax around the query-key dot product:
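
For reference (the README's own derivation continues in the collapsed portion of this diff), the standard rearrangement from the linked paper replaces $\exp(q_i^Tk_j/\sqrt{d})$ with a feature-map product $\phi(q_i)\phi(k_j)^T$; treating $q_i, k_j, v_j$ as row vectors and ignoring the normalizer (which is handled analogously), the sum over positions can then be reassociated:

$$y_i = \sum_{j=1}^{i}\big(\phi(q_i)\phi(k_j)^T\big)v_j = \phi(q_i)\sum_{j=1}^{i}\phi(k_j)^Tv_j$$

so $\sum_j\phi(k_j)^Tv_j$ can be carried as a running state instead of being recomputed for every query.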

@@ -29,7 +27,26 @@ Details of this prefill kernel are provided in [Algorithm 1 of the Based paper](

$$y_i = (\phi(q_i)^T\phi(k_i))v_i + \phi(q_i)\sum_{j=1}^{i-1}\phi(k_j)^Tv_j$$

The left-hand term requires causal computation on the tile, but the right-hand term is a simple matrix multiply (causality has already been handled)! We partition across workers to store state $s$ in registers throughout! After streaming each chunk of tokens and computing chunks of output as shown above, we update the state.
The left-hand term computes the parallel view (multiplying queries and keys first), then applies causal masking on the tile, and multiplies by values. This is handled in the kernel as follows:
```
load(q, q_s[warpid]);
load(k, k_s[warpid]);
zero(local_attn);
mma_ABt(local_attn, q, k, local_attn);
make_causal(local_attn_bf, local_attn_bf, kittens::base_types::constants<bf16>::zero());
load(v, v_s[warpid]);
auto &v_col = swap_layout_inplace(v); // prepare for MMA
zero(o);
mma_AB(o, local_attn_bf, v_col, o);
```
The right-hand term is a simple matrix multiply (causality has already been handled from previous iterations over chunks of the sequence)! We partition across workers to store state $s$ in registers throughout! After streaming each chunk of tokens and computing chunks of output as shown above, we update the state:
```
// Updating the KV state using the keys and values for the current chunk
mma_AB(a2, kt, v_col, a2); // accumulate onto a2
```
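
In the README's notation, the `a2` accumulator above plays the role of the running state $s=\sum_{j\,\in\,\text{prior chunks}}\phi(k_j)^Tv_j$ held in registers: each query in the current chunk picks up its prior-chunk contribution $\phi(q_i)\,s$ with a single matrix multiply, and once the chunk's outputs are computed the state advances by

$$s \leftarrow s + \sum_{j\,\in\,\text{chunk}}\phi(k_j)^Tv_j,$$

which is what the `mma_AB(a2, kt, v_col, a2)` accumulation performs (the kernel's feature map and tile-layout handling are elided from this summary).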


## Baselines
19 changes: 19 additions & 0 deletions examples/hedgehog/README.md
@@ -0,0 +1,19 @@


This kernel is for the [Hedgehog linear attention architecture](https://arxiv.org/abs/2402.04347) prefill stage. The structure of this kernel is similar to the Based architecture kernel -- you can read the README of that example for more background!

You can test it out via our C++ harness:
```
python generate_tests.py randn
# Ensure that the hedgehog_fwd_tk function and its imports are commented out in hedgehog.cu
# Ensure harness.impl is being imported (line is uncommented)
make clean && make && ./hedgehog randn.txt
```

You can also try it with PyTorch:
```
# Ensure harness.impl is commented out
python setup.py install # ensure that you have run `source env.src` prior to this
python hedgehog_profile.py
```
