implement dot_product_attention (#455)
* add dot_product_attention
* run tests
* docs
* address some review comments
* fix tests
* fix fdrop
* additional method
* bias is positional argument
* test bias
* fix tests on julia 1.6
* typos
* improve docs
* remove :causal
* Update src/attention.jl
* add function barrier
1 parent 2ef2daa, commit 1203b21
Showing 7 changed files with 249 additions and 3 deletions.
src/attention.jl
@@ -0,0 +1,144 @@
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.

Returns the attention output array of size `(v_dim, q_len, batch_size...)`
and the attention scores of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
  Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
    batch_size = size(q)[3:end]
    batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
    q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

    x, α = dot_product_attention(q, k, v, args...; kws...)

    x = reshape(x, size(x, 1), size(x, 2), batch_size...)
    α = reshape(α, size(α)[1:3]..., batch_size...)
    return x, α
end

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
            fdrop=identity, mask=nothing, nheads=1)

    (size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
    size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
    size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))

    # Multihead attention. TODO create fastpath for singlehead attention.
    q, k, v = split_heads.((q, k, v), nheads)
    x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
    return join_heads(x), α
end

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
    # [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
    # [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
    # [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

    α = dot_product_attention_scores(q, k, bias; fdrop, mask)
    # [α] = [kv_len, q_len, nheads, batch_size]

    # The following permutedims and batched_mul are equivalent to
    # @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
    vt = permutedims(v, (1, 3, 2, 4))
    x = batched_mul(vt, α)
    x = permutedims(x, (1, 3, 2, 4))
    # [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
    return x, α
end

"""
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
            fdrop=identity, mask=nothing) where T

    # The following permutedims and batched_mul are equivalent to
    # @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
    kt = permutedims(k, (3, 1, 2, 4))
    qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
    logits = batched_mul(kt, qt)
    # [logits] = [kv_len, q_len, nheads, batch_size]

    logits = apply_attn_bias(logits, bias)
    logits = apply_attn_mask(logits, mask)

    α = softmax(logits, dims=1)
    return fdrop(α)
end

apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias

apply_attn_mask(logits, mask::Nothing) = logits

function apply_attn_mask(logits, mask)
    neginf = typemin(eltype(logits))
    ifelse.(mask, logits, neginf)
end

"""
    make_causal_mask(x; dims=2)

Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
    len = size(x, dims)
    mask = triu(trues_like(x, (len, len)))
    return mask
end

trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)

@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)
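To give a feel for the new API, here is a usage sketch (not part of the commit diff): it assumes `dot_product_attention` and `make_causal_mask` are available from NNlib as defined above, and the array sizes, the bias, and the `fdrop` closure are made up for illustration, following the `(features, seq_len, batch...)` layout from the docstring.

```julia
using NNlib, Random

# Query/key/value in (features, seq_len, batch) layout, as documented above.
q = rand(Float32, 8, 6, 2)    # (qk_dim, q_len,  batch)
k = rand(Float32, 8, 10, 2)   # (qk_dim, kv_len, batch)
v = rand(Float32, 12, 10, 2)  # (v_dim,  kv_len, batch)

# Optional additive bias, broadcastable to (kv_len, q_len, nheads, batch).
bias = randn(Float32, 10, 6)

# Dropout on the attention scores, applied right after the softmax.
fdrop(x) = (rand!(similar(x)) .> 0.1f0) .* x ./ 0.9f0

y, α = dot_product_attention(q, k, v, bias; nheads=2, fdrop)
@assert size(y) == (12, 6, 2)     # (v_dim, q_len, batch)
@assert size(α) == (10, 6, 2, 2)  # (kv_len, q_len, nheads, batch)

# Self-attention with a causal mask built from the input itself.
x = rand(Float32, 8, 6, 2)
mask = make_causal_mask(x)  # 6×6 boolean matrix with m[i, j] == i ≤ j
yc, αc = dot_product_attention(x, x, x; mask, nheads=2)
```

Passing `bias` positionally while `mask`, `fdrop`, and `nheads` stay keywords matches the method signatures introduced in this file.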
@@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings

             end

-            C
+            return C
         end
     end
 end
@@ -0,0 +1,76 @@
@testset "different batchsizes" begin | ||
n = 15 | ||
lenq = 3 | ||
lenkv = 4 | ||
for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5] | ||
q = rand(Float32, n, lenq, batch_size...) | ||
k = rand(Float32, n, lenkv, batch_size...) | ||
v = rand(Float32, n, lenkv, batch_size...) | ||
y, α = dot_product_attention(q, k, v; nheads) | ||
@test y isa Array{Float32} | ||
@test size(y) == (n, lenq, batch_size...) | ||
@test size(α) == (lenkv, lenq, nheads, batch_size...) | ||
@test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...) | ||
end | ||
end | ||
|
||
@testset "dot_product_attention_scores" begin | ||
q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24 | ||
α = dot_product_attention_scores(q, k) | ||
q2, k2 = reshape.((q, k), 8, 3, 1) | ||
y, α2 = dot_product_attention(q2, k2, k2; nheads=2) | ||
@test α ≈ α2 | ||
end | ||
|
||
@testset "specific results" begin | ||
q = k = v = reshape([1:12;], 4, 3, 1) ./ 12 | ||
y, α = dot_product_attention(q, k, v; nheads=2) | ||
ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788] | ||
ytrue = reshape(ytrue, 4, 3, 1) | ||
αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801] | ||
αtrue = reshape(αtrue, 3, 3, 2, 1) | ||
@test y ≈ ytrue atol=1e-5 | ||
@test α ≈ αtrue atol=1e-5 | ||
end | ||
|
||
@testset "mask" begin | ||
q = rand(4, 2, 3, 1) | ||
k = rand(4, 2, 5, 1) | ||
|
||
mask = rand(Bool, (5, 3)) | ||
α = dot_product_attention_scores(q, k; mask) | ||
@test all((α[:,:,1,1].> 0) .== mask) | ||
@test all((α[:,:,2,1].> 0) .== mask) | ||
|
||
@testset "causal" begin | ||
x = rand(4, 2, 3, 1) | ||
mask = make_causal_mask(x, dims=3) | ||
α = dot_product_attention_scores(x, x; mask) | ||
@test all((α[:,:,1,1].> 0) .== mask) | ||
@test all((α[:,:,2,1].> 0) .== mask) | ||
end | ||
end | ||
|
||
@testset "dropout" begin | ||
q = k = v = rand(10, 10, 10) | ||
fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p) | ||
y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5)) | ||
@test 0.6 > mean(>(0), α) > 0.4 | ||
end | ||
|
||
@testset "bias" begin | ||
q = rand(4, 5, 1) | ||
k = v = rand(4, 3, 1) | ||
bias = randn(3, 5) | ||
y, α = dot_product_attention(q, k, v, bias; nheads=2) | ||
@test size(α) == (3, 5, 2, 1) | ||
@test size(y) == (4, 5, 1) | ||
end | ||
|
||
@testset "gradient" begin | ||
q = rand(4, 5, 1) | ||
k = v = rand(4, 3, 1) | ||
bias = randn(3, 5) | ||
y, α = dot_product_attention(q, k, v, bias; nheads=2) | ||
gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias) | ||
end |
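As an extra sanity check beyond the committed tests, the `nheads = 1` result can be compared against a naive dense reference that mirrors the `@tullio` comments in `_dot_product_attention` and `dot_product_attention_scores`. This is an illustrative sketch, not part of the test suite; `naive_attention` is a hypothetical helper defined here.

```julia
using NNlib, Test

# Naive single-head, single-batch reference:
#   logits[j, i] = ⟨k[:, j], q[:, i]⟩ / √d,  α = softmax(logits, dims=1),  x = v * α.
function naive_attention(q::AbstractMatrix, k::AbstractMatrix, v::AbstractMatrix)
    d = size(q, 1)
    logits = k' * q ./ sqrt(d)      # (kv_len, q_len)
    α = softmax(logits; dims=1)
    return v * α, α                 # (v_dim, q_len), (kv_len, q_len)
end

q, k, v = rand(Float32, 8, 5, 1), rand(Float32, 8, 7, 1), rand(Float32, 6, 7, 1)
y, α = dot_product_attention(q, k, v)              # nheads = 1 by default
y_ref, α_ref = naive_attention(q[:, :, 1], k[:, :, 1], v[:, :, 1])

@test y[:, :, 1] ≈ y_ref
@test α[:, :, 1, 1] ≈ α_ref
```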