Skip to content

Commit

Permalink
update readme
Browse files Browse the repository at this point in the history
  • Loading branch information
drisspg committed Jul 31, 2024
1 parent 5d671da commit 791ad2c
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ from torch.nn.attention.flex_attention import flex_attention, create_block_mask
from attn_gym.masks import causal_mask

# Create a causal mask
block_mask: BlockMask = create_block_mask(causal_mask)
Q_LEN, KV_LEN = query.size(-2), key.size(-2)
block_mask: BlockMask = create_block_mask(causal_mask, 1, 1, Q_LEN, KV_LEN)

# Use FlexAttention with a causal mask modification
output = flex_attention(query, key, value, block_mask=causal_mask)
Expand Down

0 comments on commit 791ad2c

Please sign in to comment.