fix key_padding_mask bug in nnf_multi_head_attention_forward (#1208)
MaximilianPi authored Nov 11, 2024
1 parent e7897ae commit 6d277a8
Showing 3 changed files with 10 additions and 4 deletions.
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -11,6 +11,7 @@ Authors@R: c(
     person("Krzysztof", "Joachimiak", role = c("ctb")),
     person("Hamada S.", "Badr", role = c("ctb")),
     person("Sebastian", "Fischer", role = c("ctb")),
+    person("Maximilian", "Pichler", role = c("ctb")),
     person(family = "RStudio", role = c("cph"))
     )
 Description: Provides functionality to define and train neural networks similar to
1 change: 1 addition & 0 deletions NEWS.md
@@ -3,6 +3,7 @@
 ## Bug fixes
 
 - `torch_iinfo()` now support all integer dtypes (#1190 @cregouby)
+- Fixed float key_padding_mask in `nnf_multi_head_attention_forward()` (#1205)
 
 # torch 0.13.0
 
12 changes: 8 additions & 4 deletions R/nnf-activation.R
@@ -728,10 +728,14 @@ nnf_multi_head_attention_forward <- function(query, # type: Tensor
 
   if (!is.null(key_padding_mask)) {
     attn_output_weights <- attn_output_weights$view(c(bsz, num_heads, tgt_len, src_len))
-    attn_output_weights <- attn_output_weights$masked_fill(
-      key_padding_mask$unsqueeze(2)$unsqueeze(3),
-      -Inf
-    )
+    if (key_padding_mask$dtype == torch_bool()) {
+      attn_output_weights <- attn_output_weights$masked_fill(
+        key_padding_mask$unsqueeze(2)$unsqueeze(3),
+        -Inf
+      )
+    } else {
+      attn_output_weights <- attn_output_weights + key_padding_mask$unsqueeze(2)$unsqueeze(3)
+    }
     attn_output_weights <- attn_output_weights$view(c(
       bsz * num_heads,
       tgt_len,
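For context on the two mask conventions this patch distinguishes: a boolean key_padding_mask flags padded key positions, which masked_fill() sets to -Inf before the softmax, whereas a float mask is interpreted as an additive bias on the attention logits. The stand-alone sketch below is not part of the commit; shapes, the padding pattern, and variable names are invented for illustration, and it only mirrors the branching introduced above.

library(torch)

# Invented shapes for illustration: 2 sequences, 4 heads, 3 queries, 5 keys.
bsz <- 2; num_heads <- 4; tgt_len <- 3; src_len <- 5
attn_output_weights <- torch_randn(bsz, num_heads, tgt_len, src_len)

# Boolean convention: TRUE marks padded key positions that must be ignored.
pad <- matrix(c(FALSE, FALSE, FALSE, TRUE, TRUE,
                FALSE, FALSE, FALSE, FALSE, TRUE),
              nrow = bsz, byrow = TRUE)
bool_mask <- torch_tensor(pad)                                  # dtype: bool
# Float convention: an additive bias (0 for visible keys, -Inf for padding).
float_mask <- torch_zeros(bsz, src_len)$masked_fill(bool_mask, -Inf)

# Boolean branch of the patch: fill masked positions with -Inf.
weights_bool <- attn_output_weights$masked_fill(
  bool_mask$unsqueeze(2)$unsqueeze(3), -Inf
)
# Float branch of the patch: add the mask as a bias.
weights_float <- attn_output_weights + float_mask$unsqueeze(2)$unsqueeze(3)

# Under either convention, padded keys receive zero weight after the softmax.
probs_bool  <- nnf_softmax(weights_bool, dim = -1)
probs_float <- nnf_softmax(weights_float, dim = -1)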
