SelfAttention layer
Vermeille committed Mar 4, 2024
1 parent d276016 commit 8143f55
Showing 1 changed file with 50 additions and 0 deletions.
torchelie/nn/llm.py
@@ -1,4 +1,7 @@
import math
import torch
import torch.nn as nn
import torchelie.utils as tu


class Rotary(torch.nn.Module):
@@ -42,3 +45,50 @@ def rotate_half(self, x):
    def apply_rotary_pos_emb(self, q, k, v, cos, sin):
        return ((q * cos) + (self.rotate_half(q) * sin),
                (k * cos) + (self.rotate_half(k) * sin), v)
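
For context, the body of rotate_half is collapsed in this diff view. The conventional RoPE helper that apply_rotary_pos_emb relies on splits the head dimension into two halves and swaps them with a sign flip; a minimal sketch of that convention (an assumption about the collapsed code, not taken from this commit):

    def rotate_half(self, x):
        # rotate pairs of channels: (x1, x2) -> (-x2, x1)
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)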


class SelfAttention(nn.Module):
    """
    Self-attention layer.

    Expects input of shape (b, l, hidden_size). Uses scaled dot-product
    attention and rotary positional embeddings.

    Args:
        hidden_size (int): size of the hidden dimension
        num_heads (int): number of attention heads
        head_size (int): size of each head
        causal (bool, optional): whether to apply causal masking.
            Defaults to True.
    """

    def __init__(self, hidden_size, num_heads, head_size, causal=True):
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.qkv = tu.normal_init(
            nn.Linear(hidden_size, head_size * num_heads * 3, bias=False),
            math.sqrt(2 / (5 * hidden_size)))
        self.fc = tu.xavier(
            nn.Linear(head_size * num_heads, hidden_size, bias=False))
        self.rotary = Rotary(head_size)
        self.causal = causal

    def forward(self, x, kv_cache=None):
        b, l, h, d = x.shape[0], x.shape[1], self.num_heads, self.head_size
        # project to q, k, v and reshape: (b, l, hidden) -> 3 x (b, h, l, d)
        qkv = self.qkv(x).reshape(b, l, 3, h, d).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        if kv_cache is not None:
            # prepend the cached keys/values along the sequence dimension
            k = torch.cat([kv_cache[0], k], dim=2)
            v = torch.cat([kv_cache[1], v], dim=2)
            # update the cache in place
            kv_cache[:] = [k, v]
        q, k, v = self.rotary(q, k, v)
        # Only apply the causal mask when queries and keys cover the same
        # positions; a query decoded against a longer KV cache must attend
        # to every cached position.
        att = nn.functional.scaled_dot_product_attention(
            q, k, v, is_causal=self.causal and q.shape[2] == k.shape[2])
        # (b, h, l, d) -> (b, l, h * d)
        att = att.permute(0, 2, 1, 3).contiguous().reshape(b, l, h * d)
        return self.fc(att)

    def extra_repr(self):
        return (f"hidden_size={self.qkv.in_features}, "
                f"num_heads={self.num_heads}, head_size={self.head_size}, "
                f"causal={self.causal}")
