@@ -439,7 +439,6 @@ def forward(
439
439
"""
440
440
# 1) Project key, value, and query.
441
441
# Reminder: during training, layer_cache[0] remains False.
442
- key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
443
442
if self .layer_cache [0 ]:
444
443
# Retrieve keys and values from the KV cache (decoding mode only).
445
444
if self .attn_type == "self" :
@@ -484,6 +483,16 @@ def forward(
484
483
key = key [:, :, 1 :, :]
485
484
value = value [:, :, 1 :, :]
486
485
486
+ if step == 0 :
487
+ key_pad_mask = self .layer_cache [1 ].get ("key_pad_mask" , None )
488
+ if key_pad_mask is not None :
489
+ x = key_pad_mask .expand (
490
+ - 1 , self .head_count // self .parallel_gpu , - 1
491
+ )
492
+ x = x .unsqueeze (3 )
493
+ x = x .expand (- 1 , - 1 , - 1 , value .size (3 ))
494
+ value = value .masked_fill (x , 0 )
495
+
487
496
self .layer_cache [1 ]["keys" ] = key
488
497
self .layer_cache [1 ]["values" ] = value
489
498
@@ -565,19 +574,6 @@ def forward(
565
574
self .layer_cache [1 ]["keys" ] = key
566
575
self .layer_cache [1 ]["values" ] = value
567
576
568
- if key_pad_mask is not None :
569
- # Increase the cached key pad mask by concatenation.
570
- # For decoding only.
571
- if step > 0 :
572
- y = torch .zeros (
573
- (key_pad_mask .size (0 ), key_pad_mask .size (1 ), 1 ),
574
- dtype = torch .bool ,
575
- device = key_pad_mask .device ,
576
- )
577
- self .layer_cache [1 ]["key_pad_mask" ] = torch .cat (
578
- (key_pad_mask , y ), 2
579
- )
580
- key_pad_mask = self .layer_cache [1 ]["key_pad_mask" ]
581
577
else :
582
578
# Retrieve keys and values from linear layers (training mode).
583
579
key = self .maybe_ckpt (self .linear_keys , key )
@@ -706,8 +702,6 @@ def forward(
706
702
scores = self .alibi (scores )
707
703
708
704
scores = scores .float ()
709
- if key_pad_mask is not None and mask is None :
710
- mask = key_pad_mask .unsqueeze (1 )
711
705
712
706
if mask is not None :
713
707
# Not strictly necessary, but expand the mask to the number of heads.
@@ -727,10 +721,6 @@ def forward(
727
721
attn_output .add_ (relative_matmul (drop_attn , relations_values , False ))
728
722
729
723
context = unshape (attn_output )
730
- if key_pad_mask is not None :
731
- if key_pad_mask .size (0 ) > 1 and context .size (1 ) > 1 :
732
- x = key_pad_mask .squeeze (1 ).unsqueeze (2 ).expand (- 1 , - 1 , context .size (2 ))
733
- context = context .masked_fill (x , 0 )
734
724
735
725
if self .layer_cache [0 ]:
736
726
attn_output = self .final_linear (context )
0 commit comments