-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathflash_rerope.py
808 lines (731 loc) · 36.8 KB
/
flash_rerope.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
# --------------------------------------------------------
# The Triton flash-attention implementation below is based on:
# https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py
# The supported flash-attn version is 2.2.1.
# NOTE:
# The supported Triton version is 2.1.0.dev20231014192330,
# a nightly build that can be installed with the command below:
# pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
# --------------------------------------------------------
import math
import torch
import flash_attn.flash_attn_interface as fi
import triton
import triton.language as tl
### NOTE: this variant is correct; it computes the rotated q1, k1, q2 (and k2 = k) outside the kernel
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel_with_fused_rerope_outter(
    Q1, K1, Q2, K2, V,
    Bias, Out,
    Lse,  # shape: [bs, nh, q_len_round]
    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug, with # shape: [bs, nh, q_len_round]
    softmax_scale,
    stride_q1b, stride_q1h, stride_q1m,  # q_len
    stride_k1b, stride_k1h, stride_k1n,  # kv_len
    stride_q2b, stride_q2h, stride_q2m,  # q_len
    stride_k2b, stride_k2h, stride_k2n,  # kv_len
    stride_vb, stride_vh, stride_vn,  # kv_len
    stride_bb, stride_bh, stride_bm,  # q_len
    stride_ob, stride_oh, stride_om,  # q_len
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, WINDOW: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,  # usually d itself
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # ReRoPE flash-attention forward for one BLOCK_M tile of queries of one (batch, head) pair.
    # The rotary embedding is applied by the caller (_flash_attn_forward_with_fused_rerope_outter):
    #   Q1 / K1 : q, k rotated at their positions  -> used where |i - j| <  WINDOW
    #   Q2 / K2 : q rotated at position WINDOW, k  -> used where |i - j| >= WINDOW
    # i and j are raw row indices of the query / key tiles (offs_m, start_n + offs_n),
    # not position ids. Keys are streamed in BLOCK_N tiles with an online softmax;
    # the normalized output goes to Out and the per-row log-sum-exp to Lse.
    # The head dimension is addressed with unit stride (offs_d is added without a stride).
    ## get the start id in m and n dim
    start_m = tl.program_id(0)  # ith block of Q
    off_hb = tl.program_id(1)  # flattened (batch, head) index
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    ## initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # the ith block in q_len dim (outer loop)
    offs_n = tl.arange(0, BLOCK_N)  # the initial 0th block in kv_len dim (inner loop)
    offs_d = tl.arange(0, BLOCK_HEADDIM)  # the only one block in head_dim dim
    ## Initialize pointers to Q1, K1, Q2, K2 and V
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q1_ptrs = (
        Q1 + off_b * stride_q1b + off_h * stride_q1h +
        (offs_m[:, None] * stride_q1m + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    k1_ptrs = (
        K1 + off_b * stride_k1b + off_h * stride_k1h +
        (offs_n[:, None] * stride_k1n + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    q2_ptrs = (
        Q2 + off_b * stride_q2b + off_h * stride_q2h +
        (offs_m[:, None] * stride_q2m + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    k2_ptrs = (
        K2 + off_b * stride_k2b + off_h * stride_k2h +
        (offs_n[:, None] * stride_k2n + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    ## initialize bias pointers
    # "vector" bias: one value per key; "matrix" bias: one value per (query, key) pair.
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias + off_b * stride_bb + off_h * stride_bh +
            (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    ## initialize pointer to m, l, o
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m  # scratch row for the store/reload workaround
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running log-sum-exp per row, starts at -inf
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running row max, starts at -inf
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # output accumulator, shape: [block_sz_m, hd]
    ## load q1, q2 block
    # on different conditions whether the q_len / kv_len can be divided by block_sz_m / block_sz_n
    # and they will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q1 = tl.load(q1_ptrs)
            q2 = tl.load(q2_ptrs)
        else:
            q1 = tl.load(q1_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            q2 = tl.load(q2_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q1 = tl.load(q1_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
            q2 = tl.load(q2_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q1 = tl.load(q1_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
            q2 = tl.load(q2_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
    ## loop over k, v and update accumulator
    # With a causal mask, key tiles that start after this query tile's last row are skipped.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)  # jth block of K,V
        # -- load k1, k2 ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k1 = tl.load(k1_ptrs + start_n * stride_k1n)
                k2 = tl.load(k2_ptrs + start_n * stride_k2n)
            else:
                k1 = tl.load(k1_ptrs + start_n * stride_k1n, mask=offs_d[None, :] < headdim, other=0.0)
                k2 = tl.load(k2_ptrs + start_n * stride_k2n, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k1 = tl.load(k1_ptrs + start_n * stride_k1n, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
                k2 = tl.load(k2_ptrs + start_n * stride_k2n, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            else:
                k1 = tl.load(k1_ptrs + start_n * stride_k1n, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
                k2 = tl.load(k2_ptrs + start_n * stride_k2n, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
        # -- compute qk1, qk2 ----
        qk1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # shape: [block_sz_m, block_sz_n]
        qk2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # shape: [block_sz_m, block_sz_n]
        # qk += tl.dot(q, k, trans_b=True) # Q * K^T => get wrong using triton-nightly 2.1.0
        qk1 += tl.dot(q1, tl.trans(k1))  # Q1 * K1^T
        qk2 += tl.dot(q2, tl.trans(k2))  # Q2 * K2^T
        # -- apply rectified mask to get qk --
        # Pairs whose index distance is inside the window keep the positional scores (qk1);
        # all others use the window-clamped scores (qk2).
        reM = tl.abs(offs_m[:, None] - (start_n + offs_n)[None, :]) < WINDOW
        qk = tl.where(reM, qk1, qk2)
        # -- apply causal mask --
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        # -- compute p and update mij, lij with bias applied if exists --
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            # new reference max is bounded below by the running lse (as in the upstream kernel)
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)  # temporary lij as rowsum(Pij)
        # -- scale acc_o --
        # re-base the accumulator from the previous reference max m_i to the new one m_ij
        acc_o_scale = tl.exp(m_i - m_ij)
        tl.store(t_ptrs, acc_o_scale)  # BUG: have to store
        acc_o_scale = tl.load(t_ptrs)  # BUG: and immediately load
        acc_o = acc_o * acc_o_scale[:, None]
        # -- load v --
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # -- update acc_o by PV --
        p = p.to(v.dtype)  # cast probabilities to V's dtype for the dot product
        acc_o += tl.dot(p, v)
        # -- update mi and li --
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)  # log(sum of exp) over all key tiles seen so far
    ## scale the Oi
    # acc_o is relative to m_i; exp(m_i - lse_i) divides by the softmax denominator
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    ## store Li
    # re-materialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    # Lse rows are padded to seqlen_q_rounded, so this store is unmasked.
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    ## initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out + off_b * stride_ob + off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    ## store Oi
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
def _flash_attn_forward_with_fused_rerope_outter(
        q, k, v,
        cos, sin, position_ids, window_size,
        bias=None, causal=False, softmax_scale=None):
    """Run the ReRoPE flash-attention forward pass, applying the rotary embedding in PyTorch.

    q1 / k1 (RoPE at ``position_ids``) and q2 (RoPE at the constant position
    ``window_size``) are materialized here; k2 is ``k`` itself. The Triton kernel
    then uses ``q1 . k1`` for query/key index pairs with ``|i - j| < window_size``
    and ``q2 . k2`` for all other pairs.

    Args:
        q: (batch, seqlen_q, nheads, d) fp16/bf16 CUDA tensor.
        k, v: (batch, seqlen_k, nheads, d), same dtype as ``q``; requires
            ``seqlen_k <= seqlen_q``.
        cos, sin: (max_position_len, d) rotary tables, same dtype as ``q``.
        position_ids: (batch, seqlen_q) indices into ``cos``/``sin``. Query i is
            rotated by ``position_ids[:, i]`` and key j by ``position_ids[:, j]``.
        window_size: ReRoPE window; must satisfy ``window_size <= seqlen_k`` and
            ``window_size < max_position_len`` (it is used as a row of ``cos``/``sin``).
        bias: optional 4-D bias whose last two dims are (1, seqlen_k) or (seqlen_q, seqlen_k).
        causal: whether the kernel applies a causal mask.
        softmax_scale: score scale; ``1 / sqrt(d)`` when None.

    Returns:
        (o, lse, softmax_scale): ``o`` has the shape and dtype of ``q``; ``lse`` is
        float32 with shape (batch, nheads, seqlen_q rounded up to a multiple of 128).

    Raises:
        AssertionError: a shape, dtype, device or index constraint is violated.
        RuntimeError: ``bias`` has an unsupported shape.
    """
    ## check constraints in shape, dtype, device and index boundary
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    max_position_len, _ = cos.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert position_ids.shape == (batch, seqlen_q)
    # Key positions are taken from position_ids[:, :seqlen_k], so there must be
    # at least as many query positions as keys.
    assert seqlen_k <= seqlen_q, "seqlen_k must be <= seqlen_q: key positions are taken from position_ids"
    assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype == v.dtype == cos.dtype == sin.dtype, "All tensors must have the same type"
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda and v.is_cuda and cos.is_cuda and sin.is_cuda and position_ids.is_cuda
    assert position_ids.max() < max_position_len
    assert window_size <= seqlen_k
    # cos/sin are indexed at row `window_size` below, so it must be a valid row.
    assert window_size < max_position_len

    ## prepare rerope q1, k1, q2, k2
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    rhq, rhk = rotate_half(q), rotate_half(k)
    # tables at position_ids, broadcast over heads: [bs, q_len, nh, d]
    cos1 = cos[position_ids].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    sin1 = sin[position_ids].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    # tables at the constant position window_size: [bs, q_len, nh, d]
    cos2 = cos[position_ids * 0 + window_size].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    sin2 = sin[position_ids * 0 + window_size].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    q1 = q * cos1 + rhq * sin1
    # key j is rotated by position_ids[:, j]; slicing keeps shapes consistent when seqlen_k < seqlen_q
    k1 = k * cos1[:, :seqlen_k] + rhk * sin1[:, :seqlen_k]
    q2 = q * cos2 + rhq * sin2
    k2 = k  # keys outside the window are left un-rotated
    ## prepare scaling factor
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)
    ## prepare bias
    has_bias = bias is not None
    bias_type = "none"
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = "vector"
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = "matrix"
        else:
            raise RuntimeError(
                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
            )
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
    ## prepare O, L, tmp
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128  # q_len rounded up to a multiple of 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)
    ## prepare block / warp / grid size
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)  # usually d itself
    ## set block size
    BLOCK = 128
    # BLOCK = 256 # FIXME: out of resource
    ## set num of warps
    # num_warps = 4 if d <= 64 else 8
    num_warps = 4  # FIXME: when d = 128, not all close if num_warps = 8 as the line above
    ## set grid split: one program per (query block, batch * head)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    ## apply forward kernel with fused rope func for all blocks in grid
    _fwd_kernel_with_fused_rerope_outter[grid](
        q1, k1, q2, k2, v,
        bias,  # shape: [bs, nh, q_len, kv_len]
        o,  # shape: [bs, q_len, nh, hd]
        lse, tmp,  # shape: [bs, nh, q_len_round]
        softmax_scale,
        q1.stride(0), q1.stride(2), q1.stride(1),  # bs, nh, sq
        k1.stride(0), k1.stride(2), k1.stride(1),  # bs, nh, sk
        q2.stride(0), q2.stride(2), q2.stride(1),  # bs, nh, sq
        k2.stride(0), k2.stride(2), k2.stride(1),  # bs, nh, sk
        v.stride(0), v.stride(2), v.stride(1),  # bs, nh, sv
        # bias
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),  # bs, nh, sq
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type, causal, window_size,  # rerope window size
        BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        num_warps=num_warps, num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated
#### NOTE: this variant is also correct; it applies the rotary embedding (q1, k1, q2; k2 = k) inside the kernel
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
@triton.jit
def _fwd_kernel_with_fused_rerope_inner(
    Q, K, V, rhQ, rhK,
    Cos1, Sin1, Cos2, Sin2,
    Bias, Out,
    Lse,  # shape: [bs, nh, q_len_round]
    TMP,  # NOTE: TMP is a scratchpad buffer to workaround a compiler bug, with # shape: [bs, nh, q_len_round]
    softmax_scale,
    stride_qb, stride_qh, stride_qm,  # q_len
    stride_kb, stride_kh, stride_kn,  # kv_len
    stride_vb, stride_vh, stride_vn,  # kv_len
    stride_rhqb, stride_rhqh, stride_rhqm,  # q_len
    stride_rhkb, stride_rhkh, stride_rhkn,  # kv_len
    stride_qc1b, stride_qc1h, stride_qc1m,  # q_len
    stride_qs1b, stride_qs1h, stride_qs1m,  # q_len
    stride_kc1b, stride_kc1h, stride_kc1n,  # kv_len
    stride_ks1b, stride_ks1h, stride_ks1n,  # kv_len
    stride_qc2b, stride_qc2h, stride_qc2m,  # q_len
    stride_qs2b, stride_qs2h, stride_qs2m,  # q_len
    stride_bb, stride_bh, stride_bm,  # q_len
    stride_ob, stride_oh, stride_om,  # q_len
    nheads, seqlen_q, seqlen_k, seqlen_q_rounded, headdim,
    CACHE_KEY_SEQLEN_Q, CACHE_KEY_SEQLEN_K,
    BIAS_TYPE: tl.constexpr, IS_CAUSAL: tl.constexpr, WINDOW: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,  # usually d itself
    EVEN_M: tl.constexpr, EVEN_N: tl.constexpr, EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr,
):
    # ReRoPE flash-attention forward for one BLOCK_M tile of queries of one (batch, head) pair,
    # with the rotary embedding applied in-kernel:
    #   q1 = q * Cos1 + rot_half(q) * Sin1   (tables read at query rows)
    #   k1 = k * Cos1 + rot_half(k) * Sin1   (the same tables read at key rows)
    #   q2 = q * Cos2 + rot_half(q) * Sin2,  k2 = k
    # q1.k1 is used where |i - j| < WINDOW (raw row indices), q2.k2 elsewhere.
    # Because Cos1/Sin1 are read at key rows, the caller must keep seqlen_k <= rows of Cos1/Sin1.
    # Keys are streamed in BLOCK_N tiles with an online softmax; the normalized output goes
    # to Out and the per-row log-sum-exp to Lse. The head dimension is addressed with unit stride.
    ## get the start id in m and n dim
    start_m = tl.program_id(0)  # ith block of Q
    off_hb = tl.program_id(1)  # flattened (batch, head) index
    off_b = off_hb // nheads
    off_h = off_hb % nheads
    ## initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)  # the ith block in q_len dim (outer loop)
    offs_n = tl.arange(0, BLOCK_N)  # the initial 0th block in kv_len dim (inner loop)
    offs_d = tl.arange(0, BLOCK_HEADDIM)  # the only one block in head_dim dim
    ## Initialize pointers to Q, K, V, and rotate-half Q, K
    # Adding parenthesis around indexing might use int32 math instead of int64 math?
    # https://github.com/openai/triton/issues/741
    # I'm seeing a tiny bit of difference (5-7us)
    q_ptrs = (
        Q + off_b * stride_qb + off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    k_ptrs = (
        K + off_b * stride_kb + off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    v_ptrs = (
        V + off_b * stride_vb + off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    rhq_ptrs = (
        rhQ + off_b * stride_rhqb + off_h * stride_rhqh +
        (offs_m[:, None] * stride_rhqm + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    rhk_ptrs = (
        rhK + off_b * stride_rhkb + off_h * stride_rhkh +
        (offs_n[:, None] * stride_rhkn + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    ## Initialize pointers to Cos1(q/k), Sin1(q/k), Cos2(q), Sin2(q)
    q_cos1_ptrs = (
        Cos1 + off_b * stride_qc1b + off_h * stride_qc1h +
        (offs_m[:, None] * stride_qc1m + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    q_sin1_ptrs = (
        Sin1 + off_b * stride_qs1b + off_h * stride_qs1h +
        (offs_m[:, None] * stride_qs1m + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    # Same Cos1/Sin1 tensors, indexed by key row
    k_cos1_ptrs = (
        Cos1 + off_b * stride_kc1b + off_h * stride_kc1h +
        (offs_n[:, None] * stride_kc1n + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    k_sin1_ptrs = (
        Sin1 + off_b * stride_ks1b + off_h * stride_ks1h +
        (offs_n[:, None] * stride_ks1n + offs_d[None, :])  # shape of [block_sz_n, hd]
    )
    q_cos2_ptrs = (
        Cos2 + off_b * stride_qc2b + off_h * stride_qc2h +
        (offs_m[:, None] * stride_qc2m + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    q_sin2_ptrs = (
        Sin2 + off_b * stride_qs2b + off_h * stride_qs2h +
        (offs_m[:, None] * stride_qs2m + offs_d[None, :])  # shape of [block_sz_m, hd]
    )
    ## initialize bias pointers
    # "vector" bias: one value per key; "matrix" bias: one value per (query, key) pair.
    if BIAS_TYPE == "vector":
        b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n
    elif BIAS_TYPE == "matrix":
        b_ptrs = (
            Bias + off_b * stride_bb + off_h * stride_bh +
            (offs_m[:, None] * stride_bm + offs_n[None, :])
        )
    ## initialize pointer to m, l, o
    t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m  # scratch row for the store/reload workaround
    lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running log-sum-exp per row, starts at -inf
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")  # running row max, starts at -inf
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)  # output accumulator, shape: [block_sz_m, hd]
    ## load q block and its rotate_half with cos/sin rope
    # on different conditions whether the q_len / kv_len can be divided by block_sz_m / block_sz_n
    # and they will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs)
            rhq = tl.load(rhq_ptrs)
            q_cos1 = tl.load(q_cos1_ptrs)
            q_sin1 = tl.load(q_sin1_ptrs)
            q_cos2 = tl.load(q_cos2_ptrs)
            q_sin2 = tl.load(q_sin2_ptrs)
        else:
            q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            rhq = tl.load(rhq_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            q_cos1 = tl.load(q_cos1_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            q_sin1 = tl.load(q_sin1_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            q_cos2 = tl.load(q_cos2_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
            q_sin2 = tl.load(q_sin2_ptrs, mask=offs_d[None, :] < headdim, other=0.0)
    else:
        if EVEN_HEADDIM:
            q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
            # Mask by query row like the other loads here: a head-dim mask is always true
            # when EVEN_HEADDIM, which left this load reading past the last query row.
            rhq = tl.load(rhq_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
            q_cos1 = tl.load(q_cos1_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
            q_sin1 = tl.load(q_sin1_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
            q_cos2 = tl.load(q_cos2_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
            q_sin2 = tl.load(q_sin2_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0)
        else:
            q = tl.load(q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
            rhq = tl.load(rhq_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
            q_cos1 = tl.load(q_cos1_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
            q_sin1 = tl.load(q_sin1_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
            q_cos2 = tl.load(q_cos2_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
            q_sin2 = tl.load(q_sin2_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0)
    ## apply rope to q1, q2
    q1 = q * q_cos1 + rhq * q_sin1
    q2 = q * q_cos2 + rhq * q_sin2
    ## loop over k, v and update accumulator
    # With a causal mask, key tiles that start after this query tile's last row are skipped.
    end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
    for start_n in range(0, end_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)  # jth block of K,V
        # -- load k and its rotate_half with cos/sin rope ----
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn)
                rhk = tl.load(rhk_ptrs + start_n * stride_rhkn)
                k_cos1 = tl.load(k_cos1_ptrs + start_n * stride_kc1n)
                k_sin1 = tl.load(k_sin1_ptrs + start_n * stride_ks1n)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0)
                rhk = tl.load(rhk_ptrs + start_n * stride_rhkn, mask=offs_d[None, :] < headdim, other=0.0)
                k_cos1 = tl.load(k_cos1_ptrs + start_n * stride_kc1n, mask=offs_d[None, :] < headdim, other=0.0)
                k_sin1 = tl.load(k_sin1_ptrs + start_n * stride_ks1n, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
                rhk = tl.load(rhk_ptrs + start_n * stride_rhkn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
                k_cos1 = tl.load(k_cos1_ptrs + start_n * stride_kc1n, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
                k_sin1 = tl.load(k_sin1_ptrs + start_n * stride_ks1n, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
                rhk = tl.load(rhk_ptrs + start_n * stride_rhkn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
                k_cos1 = tl.load(k_cos1_ptrs + start_n * stride_kc1n, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
                k_sin1 = tl.load(k_sin1_ptrs + start_n * stride_ks1n, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
        # -- apply rope to k1, k2 ----
        k1 = k * k_cos1 + rhk * k_sin1
        k2 = k  # keys outside the window are left un-rotated
        # -- compute qk1, qk2 ----
        qk1 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # shape: [block_sz_m, block_sz_n]
        qk2 = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)  # shape: [block_sz_m, block_sz_n]
        # qk += tl.dot(q, k, trans_b=True) # Q * K^T => get wrong using triton-nightly 2.1.0
        qk1 += tl.dot(q1, tl.trans(k1))  # Q1 * K1^T
        qk2 += tl.dot(q2, tl.trans(k2))  # Q2 * K2^T
        # -- apply rectified mask to get qk --
        # Pairs whose index distance is inside the window keep the positional scores (qk1);
        # all others use the window-clamped scores (qk2).
        reM = tl.abs(offs_m[:, None] - (start_n + offs_n)[None, :]) < WINDOW
        qk = tl.where(reM, qk1, qk2)
        # -- apply causal mask --
        # Trying to combine the two masks seem to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))
        if IS_CAUSAL:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
        # -- compute p and update mij, lij with bias applied if exists --
        if BIAS_TYPE != "none":
            if BIAS_TYPE == "vector":
                if EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0).to(tl.float32)
                bias = bias[None, :]
            elif BIAS_TYPE == "matrix":
                if EVEN_M & EVEN_N:
                    bias = tl.load(b_ptrs + start_n).to(tl.float32)
                else:
                    bias = tl.load(b_ptrs + start_n, mask=(offs_m[:, None] < seqlen_q) & ((start_n + offs_n)[None, :] < seqlen_k), other=0.0).to(tl.float32)
            # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler
            # can then fuse the mult and add into an fma instruction. But if we have bias we need to
            # to multiply with softmax_scale here.
            qk = qk * softmax_scale + bias
            m_ij = tl.maximum(tl.max(qk, 1), lse_i)
            p = tl.exp(qk - m_ij[:, None])
        else:
            # new reference max is bounded below by the running lse (as in the upstream kernel)
            m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i)
            p = tl.exp(qk * softmax_scale - m_ij[:, None])
        l_ij = tl.sum(p, 1)  # temporary lij as rowsum(Pij)
        # -- scale acc_o --
        # re-base the accumulator from the previous reference max m_i to the new one m_ij
        acc_o_scale = tl.exp(m_i - m_ij)
        tl.store(t_ptrs, acc_o_scale)  # BUG: have to store
        acc_o_scale = tl.load(t_ptrs)  # BUG: and immediately load
        acc_o = acc_o * acc_o_scale[:, None]
        # -- load v --
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(v_ptrs + start_n * stride_vn)
            else:
                v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0)
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        # -- update acc_o by PV --
        p = p.to(v.dtype)  # cast probabilities to V's dtype for the dot product
        acc_o += tl.dot(p, v)
        # -- update mi and li --
        m_i = m_ij
        l_i_new = tl.exp(lse_i - m_ij) + l_ij
        lse_i = m_ij + tl.log(l_i_new)  # log(sum of exp) over all key tiles seen so far
    ## scale the Oi
    # acc_o is relative to m_i; exp(m_i - lse_i) divides by the softmax denominator
    o_scale = tl.exp(m_i - lse_i)
    # BUG: have to store and immediately load
    tl.store(t_ptrs, o_scale)
    o_scale = tl.load(t_ptrs)
    acc_o = acc_o * o_scale[:, None]
    ## store Li
    # re-materialize offsets to save registers
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    # write back l and m
    # Lse rows are padded to seqlen_q_rounded, so this store is unmasked.
    lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m
    tl.store(lse_ptrs, lse_i)
    ## initialize pointers to output
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out + off_b * stride_ob + off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    ## store Oi
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o)
        else:
            tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim)
    else:
        if EVEN_HEADDIM:
            tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q)
        else:
            tl.store(
                out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )
def _flash_attn_forward_with_fused_rerope_inner(q, k, v,
        cos, sin, position_ids, window_size,
        bias=None, causal=False, softmax_scale=None):
    """Run the ReRoPE flash-attention forward pass, applying the rotary embedding inside the kernel.

    The gathered cos/sin tables and the rotate-half of q and k are prepared here.
    The Triton kernel builds q1 / k1 (RoPE at ``position_ids``) and q2 (RoPE at
    the constant position ``window_size``), with k2 = k. It uses ``q1 . k1`` for
    query/key index pairs with ``|i - j| < window_size`` and ``q2 . k2`` for all
    other pairs.

    Args:
        q: (batch, seqlen_q, nheads, d) fp16/bf16 CUDA tensor.
        k, v: (batch, seqlen_k, nheads, d), same dtype as ``q``; requires
            ``seqlen_k <= seqlen_q`` because the kernel reads the query-position
            cos/sin tables at key rows.
        cos, sin: (max_position_len, d) rotary tables, same dtype as ``q``.
        position_ids: (batch, seqlen_q) indices into ``cos``/``sin``. Query i is
            rotated by ``position_ids[:, i]`` and key j by ``position_ids[:, j]``.
        window_size: ReRoPE window; must satisfy ``window_size <= seqlen_k`` and
            ``window_size < max_position_len`` (it is used as a row of ``cos``/``sin``).
        bias: optional 4-D bias whose last two dims are (1, seqlen_k) or (seqlen_q, seqlen_k).
        causal: whether the kernel applies a causal mask.
        softmax_scale: score scale; ``1 / sqrt(d)`` when None.

    Returns:
        (o, lse, softmax_scale): ``o`` has the shape and dtype of ``q``; ``lse`` is
        float32 with shape (batch, nheads, seqlen_q rounded up to a multiple of 128).

    Raises:
        AssertionError: a shape, dtype, device or index constraint is violated.
        RuntimeError: ``bias`` has an unsupported shape.
    """
    ## check constraints in shape, dtype, device and index boundary
    batch, seqlen_q, nheads, d = q.shape
    _, seqlen_k, _, _ = k.shape
    max_position_len, _ = cos.shape
    assert k.shape == (batch, seqlen_k, nheads, d)
    assert v.shape == (batch, seqlen_k, nheads, d)
    assert position_ids.shape == (batch, seqlen_q)
    # The kernel reads cos1/sin1 (seqlen_q rows) at key rows; more keys than
    # queries would read past the end of those tables.
    assert seqlen_k <= seqlen_q, "seqlen_k must be <= seqlen_q: key positions are taken from position_ids"
    assert d <= 128, "FlashAttention only support head dimensions up to 128"
    assert q.dtype == k.dtype == v.dtype == cos.dtype == sin.dtype, "All tensors must have the same type"
    assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.is_cuda and k.is_cuda and v.is_cuda and cos.is_cuda and sin.is_cuda and position_ids.is_cuda
    assert position_ids.max() < max_position_len
    assert window_size <= seqlen_k
    # cos/sin are indexed at row `window_size` below, so it must be a valid row.
    assert window_size < max_position_len
    ## prepare rerope
    # 1. first set: tables gathered at position_ids, broadcast over heads -> [bs, q_len, nheads, dim]
    #    (read by the kernel at query rows for q1 and at key rows for k1)
    # 2. second set: tables at the constant position window_size -> [bs, q_len, nheads, dim] (for q2)
    cos1 = cos[position_ids].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    sin1 = sin[position_ids].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    cos2 = cos[position_ids * 0 + window_size].unsqueeze(2).expand((batch, seqlen_q, nheads, d))
    sin2 = sin[position_ids * 0 + window_size].unsqueeze(2).expand((batch, seqlen_q, nheads, d))

    ## prepare rotate_half q,k
    def rotate_half(x):
        """Rotates half the hidden dims of the input."""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    rhq, rhk = rotate_half(q), rotate_half(k)
    # the kernel addresses the head dimension with unit stride
    cos1, sin1, cos2, sin2, rhq, rhk = [x if x.stride(-1) == 1 else x.contiguous() for x in [cos1, sin1, cos2, sin2, rhq, rhk]]
    ## prepare scaling factor
    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(d)
    ## prepare bias
    has_bias = bias is not None
    bias_type = "none"
    if has_bias:
        assert bias.dtype in [q.dtype, torch.float]
        assert bias.is_cuda
        assert bias.dim() == 4
        if bias.stride(-1) != 1:
            bias = bias.contiguous()
        if bias.shape[2:] == (1, seqlen_k):
            bias_type = "vector"
        elif bias.shape[2:] == (seqlen_q, seqlen_k):
            bias_type = "matrix"
        else:
            raise RuntimeError(
                "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)"
            )
        bias = bias.expand(batch, nheads, seqlen_q, seqlen_k)
    bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0)
    ## prepare O, L, tmp
    seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128  # q_len rounded up to a multiple of 128
    lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32)
    o = torch.empty_like(q)
    ## prepare block / warp / grid size
    BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16)  # usually d itself
    ## set block size
    BLOCK = 128
    # BLOCK = 256 # FIXME: out of resource
    ## set num of warps
    # num_warps = 4 if d <= 64 else 8
    num_warps = 4  # FIXME: when d = 128, not all close if num_warps = 8 as the line above
    ## set grid split: one program per (query block, batch * head)
    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
    ## apply forward kernel with fused rope func for all blocks in grid
    _fwd_kernel_with_fused_rerope_inner[grid](
        q, k, v, rhq, rhk,
        cos1, sin1, cos2, sin2,
        bias,  # shape: [bs, nh, q_len, kv_len]
        o,  # shape: [bs, q_len, nh, hd]
        lse, tmp,  # shape: [bs, nh, q_len_round]
        softmax_scale,
        q.stride(0), q.stride(2), q.stride(1),  # bs, nh, sq
        k.stride(0), k.stride(2), k.stride(1),  # bs, nh, sk
        v.stride(0), v.stride(2), v.stride(1),  # bs, nh, sv
        rhq.stride(0), rhq.stride(2), rhq.stride(1),  # bs, nh, sq
        rhk.stride(0), rhk.stride(2), rhk.stride(1),  # bs, nh, sk
        # for q1
        cos1.stride(0), cos1.stride(2), cos1.stride(1),  # bs, nh, sq
        sin1.stride(0), sin1.stride(2), sin1.stride(1),  # bs, nh, sq
        # for k1 (same tables, walked with key rows)
        cos1.stride(0), cos1.stride(2), cos1.stride(1),  # bs, nh, sk
        sin1.stride(0), sin1.stride(2), sin1.stride(1),  # bs, nh, sk
        # for q2
        cos2.stride(0), cos2.stride(2), cos2.stride(1),  # bs, nh, sq
        sin2.stride(0), sin2.stride(2), sin2.stride(1),  # bs, nh, sq
        # bias
        *bias_strides,
        o.stride(0), o.stride(2), o.stride(1),  # bs, nh, sq
        nheads, seqlen_q, seqlen_k, seqlen_q_rounded, d,
        seqlen_q // 32, seqlen_k // 32,  # key for triton cache (limit number of compilations)
        # Can't use kwargs here because triton autotune expects key to be args, not kwargs
        # IS_CAUSAL=causal, BLOCK_HEADDIM=d,
        bias_type, causal, window_size,
        BLOCK_HEADDIM, BLOCK_M=BLOCK, BLOCK_N=BLOCK,
        num_warps=num_warps, num_stages=1,
    )
    return o, lse, softmax_scale  # softmax_scale could have been updated
### interface func
def flash_attn_func_with_fused_rerope(
    q, k, v,
    cos, sin,
    position_ids,
    window_size,
    bias=None,
    causal=False,
    softmax_scale=None,
    inner=True):
    """Compute ReRoPE attention with a fused Triton flash-attention forward kernel.

    Args:
        q: (batch_size, seqlen_q, nheads, headdim)
        k, v: (batch_size, seqlen_k, nheads, headdim)
        cos, sin: (max_seq_len, headdim) rotary tables
        position_ids: (batch_size, seqlen_q)
        window_size: the inner window size used as the ReRoPE boundary
        bias: optional, shape broadcastible to (batch, nheads, seqlen_q, seqlen_k).
            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
        causal: forwarded to the selected forward implementation.
        softmax_scale: forwarded to the selected forward implementation.
        inner: if True (default), dispatch to the implementation that applies the
            rotary embedding inside the kernel; otherwise to the one that
            precomputes the rotated q/k tensors with PyTorch.

    Returns:
        The attention output tensor only (the log-sum-exp and scale returned by
        the forward implementation are discarded).
    """
    def _unit_last_stride(t):
        # The kernels index the head dimension with unit stride.
        if t.stride(-1) == 1:
            return t
        return t.contiguous()

    q = _unit_last_stride(q)
    k = _unit_last_stride(k)
    v = _unit_last_stride(v)
    cos = _unit_last_stride(cos)
    sin = _unit_last_stride(sin)

    if inner:
        run_forward = _flash_attn_forward_with_fused_rerope_inner
    else:
        run_forward = _flash_attn_forward_with_fused_rerope_outter

    out, _lse, _scale = run_forward(
        q, k, v,
        cos, sin, position_ids, window_size,
        bias=bias, causal=causal, softmax_scale=softmax_scale,
    )
    return out