implement dot_product_attention (#455)
* add dot_product_attention

* run tests

* docs

* address some review comments

* fix tests

* fix fdrop

* additional method

* bias is positional argument

* test bias

* fix tests on julia 1.6

* typos

* improve docs

* remove :causal

* Update src/attention.jl

* add function barrier
CarloLucibello authored Feb 3, 2023
1 parent 2ef2daa commit 1203b21
Showing 7 changed files with 249 additions and 3 deletions.
8 changes: 8 additions & 0 deletions docs/src/reference.md
@@ -33,6 +33,14 @@ tanhshrink
trelu
```

## Attention

```@docs
dot_product_attention
dot_product_attention_scores
make_causal_mask
```

## Softmax

`Flux`'s `logitcrossentropy` uses `NNlib.softmax` internally.
3 changes: 3 additions & 0 deletions src/NNlib.jl
@@ -41,6 +41,9 @@ for f in ACTIVATIONS
end
export sigmoid, hardsigmoid, logsigmoid, thresholdrelu # Aliases

include("attention.jl")
export dot_product_attention, dot_product_attention_scores, make_causal_mask

include("dropout.jl")
export dropout, dropout!

144 changes: 144 additions & 0 deletions src/attention.jl
@@ -0,0 +1,144 @@
const AA3{T} = AbstractArray{T,3}
const AA4{T} = AbstractArray{T,4}
const AA{N,T} = AbstractArray{T,N}

"""
    dot_product_attention(query, key, value, [bias]; [fdrop, mask, nheads])

Multihead dot product attention used in transformer architectures.

The input arrays must have the first two dimensions given by the number of features
and the sequence length, then an arbitrary number of batch dimensions or none.

Returns the attention output array of size `(v_dim, q_len, batch_size...)` and the attention scores
of size `(kv_len, q_len, nheads, batch_size...)`.

See also [`dot_product_attention_scores`](@ref) if you only need the attention scores.

# Arguments

- `query`: Query array of size `(qk_dim, q_len, batch_size...)`.
- `key`: Key array of size `(qk_dim, kv_len, batch_size...)`.
- `value`: Value array of size `(v_dim, kv_len, batch_size...)`.
- `bias`: Either `nothing` or an array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  It will be added to the attention scores before applying the softmax. Default `nothing`.
- `fdrop`: A dropout function or layer to be applied on the attention scores right after the softmax.
  Default `identity` (no dropout).
- `mask`: Either `nothing` or a boolean array broadcastable to size `(kv_len, q_len, nheads, batch_size)`.
  The mask is applied to the attention scores just before the softmax.
  See [`make_causal_mask`](@ref) for creating causal masks. Default `nothing`.
- `nheads`: Number of heads to split the input arrays into. Default `1`.

# Examples

```julia
q, k, v = rand(10, 20, 2), rand(10, 30, 2), rand(20, 30, 2)
y, α = dot_product_attention(q, k, v)
```
"""
function dot_product_attention(q::AA{N}, k::AA{N}, v::AA{N}, args...; kws...) where N
batch_size = size(q)[3:end]
batch_size == size(k)[3:end] == size(v)[3:end] || throw(ArgumentError("Batch dimensions have to be the same."))
q, k, v = map(x -> reshape(x, size(x, 1), size(x, 2), :), (q, k, v))

x, α = dot_product_attention(q, k, v, args...; kws...)

x = reshape(x, size(x, 1), size(x, 2), batch_size...)
α = reshape(α, size(α)[1:3]..., batch_size...)
return x, α
end
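
As a quick sanity check of the batch-flattening above (trailing batch dimensions are collapsed to one, attention runs on 3D arrays, and the original shape is restored on the way out), a minimal sketch with arbitrary sizes:

```julia
using NNlib

q = rand(Float32, 8, 5, 2, 3)   # (qk_dim, q_len, batch dims 2 and 3)
k = rand(Float32, 8, 7, 2, 3)   # (qk_dim, kv_len, 2, 3)
v = rand(Float32, 6, 7, 2, 3)   # (v_dim, kv_len, 2, 3)

y, α = dot_product_attention(q, k, v; nheads = 2)
size(y)  # (6, 5, 2, 3)    == (v_dim, q_len, batch_size...)
size(α)  # (7, 5, 2, 2, 3) == (kv_len, q_len, nheads, batch_size...)
```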

function dot_product_attention(q::AA3, k::AA3, v::AA3, bias=nothing;
fdrop=identity, mask=nothing, nheads=1)

(size(q, 3) == size(k, 3) == size(v, 3)) || throw(ArgumentError("Batch dimensions have to be the same."))
size(q, 1) == size(k, 1) || throw(ArgumentError("First dimension in query and key has to be the same."))
size(k, 2) == size(v, 2) || throw(ArgumentError("Second dimension in key and value has to be the same."))

# Multihead attention. TODO create fastpath for singlehead attention.
q, k, v = split_heads.((q, k, v), nheads)
x, α = _dot_product_attention(q, k, v, bias, fdrop, mask)
return join_heads(x), α
end
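
For the masked path, a hedged usage sketch of causal self-attention (array sizes are arbitrary):

```julia
using NNlib

x = rand(Float32, 16, 10, 4)      # (features, seq_len, batch)
mask = make_causal_mask(x)        # 10×10 Bool, mask[i, j] == (i <= j)
y, α = dot_product_attention(x, x, x; nheads = 4, mask)

# For the first query position only the first key may be attended to,
# so the remaining scores are zero.
all(iszero, α[2:end, 1, 1, 1])    # true
```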

function _dot_product_attention(q::AA4, k::AA4, v::AA4, bias, fdrop, mask)
# [q] = [qk_dim ÷ nheads, nheads, q_len, batch_size]
# [k] = [qk_dim ÷ nheads, nheads, kv_len, batch_size]
# [v] = [v_dim ÷ nheads, nheads, kv_len, batch_size]

α = dot_product_attention_scores(q, k, bias; fdrop, mask)
# [α] = [kv_len, q_len, nheads, batch_size]

# The following permutedims and batched_mul are equivalent to
# @tullio x[d, h, i, b] := α[j, i, h, b] * v[d, h, j, b]
vt = permutedims(v, (1, 3, 2, 4))
x = batched_mul(vt, α)
x = permutedims(x, (1, 3, 2, 4))
# [x] = [v_dim ÷ nheads, nheads, q_len, batch_size]
return x, α
end
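
The permutedims/batched_mul pair above computes the head-wise weighted sum of values. A loop-based reference for that contraction (the `@tullio` expression in the comment), useful only for sanity checking; `naive_weighted_sum` is a hypothetical helper, not part of NNlib:

```julia
# Hypothetical reference: x[d, h, i, b] = Σⱼ α[j, i, h, b] * v[d, h, j, b]
function naive_weighted_sum(α, v)
    d, h, kv_len, b = size(v)     # v is (v_dim ÷ nheads, nheads, kv_len, batch)
    q_len = size(α, 2)            # α is (kv_len, q_len, nheads, batch)
    x = zeros(promote_type(eltype(α), eltype(v)), d, h, q_len, b)
    for bi in 1:b, i in 1:q_len, hi in 1:h, j in 1:kv_len, di in 1:d
        x[di, hi, i, bi] += α[j, i, hi, bi] * v[di, hi, j, bi]
    end
    return x
end
```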

"""
    dot_product_attention_scores(query, key, [bias]; [fdrop, mask])

Return the attention scores for the [`dot_product_attention`](@ref).
Input arrays must have dimensions
`(num_features ÷ nheads, nheads, sequence_length, batch_size)`.

See [`dot_product_attention`](@ref) for more details.
"""
function dot_product_attention_scores(q::AA4{T}, k::AA4{T}, bias=nothing;
fdrop=identity, mask=nothing) where T

# The following permutedims and batched_mul are equivalent to
# @tullio logits[j, i, h, b] := q[d, h, i, b] * k[d, h, j, b] / √T(qk_dim)
kt = permutedims(k, (3, 1, 2, 4))
qt = permutedims(q, (1, 3, 2, 4)) ./ √T(size(q, 1))
logits = batched_mul(kt, qt)
# [logits] = [kv_len, q_len, nheads, batch_size]

logits = apply_attn_bias(logits, bias)
logits = apply_attn_mask(logits, mask)

α = softmax(logits, dims=1)
return fdrop(α)
end
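
A small single-head, single-batch check of these scores against a direct computation (a sketch; `q0` and `k0` are arbitrary matrices):

```julia
using NNlib

d, q_len, kv_len = 4, 3, 5
q0, k0 = rand(Float32, d, q_len), rand(Float32, d, kv_len)

# Reshape to the (features ÷ nheads, nheads, len, batch) layout expected here.
q4 = reshape(q0, d, 1, q_len, 1)
k4 = reshape(k0, d, 1, kv_len, 1)

α = dot_product_attention_scores(q4, k4)
αref = softmax(k0' * q0 ./ sqrt(Float32(d)); dims = 1)
α[:, :, 1, 1] ≈ αref   # expected to hold
```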

apply_attn_bias(logits, bias::Nothing) = logits

apply_attn_bias(logits, bias) = logits .+ bias


apply_attn_mask(logits, mask::Nothing) = logits

function apply_attn_mask(logits, mask)
neginf = typemin(eltype(logits))
ifelse.(mask, logits, neginf)
end


"""
    make_causal_mask(x, dims=2)

Return a boolean square matrix `m` of the same type as `x` and of side `size(x, dims)`.
Its elements are set such that `m[i, j] == i ≤ j`.

Can be used to mask the attention scores in [`dot_product_attention`](@ref).
"""
function make_causal_mask(x::AbstractArray; dims::Int=2)
len = size(x, dims)
mask = triu(trues_like(x, (len, len)))
return mask
end
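
For a length-3 sequence, for example, the resulting mask keeps only keys at or before each query position:

```julia
using NNlib

x = rand(4, 3)        # any array with size(x, 2) == 3
make_causal_mask(x)
# 3×3 Matrix{Bool}:
#  1  1  1
#  0  1  1
#  0  0  1
```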

trues_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), true)
falses_like(x::AbstractArray, sz=size(x)) = fill!(similar(x, Bool, sz), false)

split_heads(x, nheads) = reshape(x, size(x, 1) ÷ nheads, nheads, size(x)[2:end]...)
join_heads(x) = reshape(x, :, size(x)[3:end]...)
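
`split_heads` and `join_heads` are internal helpers (not exported); they move between the flat feature layout and the per-head layout. A small sketch:

```julia
using NNlib

x = rand(Float32, 8, 5, 3)      # (features, len, batch)
xh = NNlib.split_heads(x, 2)    # size (4, 2, 5, 3): (features ÷ nheads, nheads, len, batch)
NNlib.join_heads(xh) == x       # true, join_heads undoes the reshape
```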

@non_differentiable make_causal_mask(::Any...)
@non_differentiable trues_like(::Any...)
@non_differentiable falses_like(::Any...)

15 changes: 13 additions & 2 deletions src/batched/batchedmul.jl
@@ -5,8 +5,10 @@ _unbatch(A::BatchedAdjOrTrans) = parent(A)
batched_mul(A, B) -> C
A ⊠ B # \\boxtimes
-Batched matrix multiplication. Result has `C[:,:,k] == A[:,:,k] * B[:,:,k]` for all `k`.
-If `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
+Batched matrix multiplication. Result has `C[:,:,k...] == A[:,:,k...] * B[:,:,k...]` where `k...` represent
+any indices in the last dimensions.
+If `ndims(A) == ndims(B) == 3` and `size(B,3) == 1` then instead `C[:,:,k] == A[:,:,k] * B[:,:,1]`, and similarly for `A`.
To transpose each matrix, apply `batched_transpose` to the array,
or `batched_adjoint` for conjugate-transpose:
@@ -42,6 +44,15 @@ This will be copied, as doing so is faster than `batched_mul_generic!`.
Both this `copy` and `batched_mul_generic!` produce `@debug` messages,
and setting for instance `ENV["JULIA_DEBUG"] = NNlib` will display them.
"""
function batched_mul(x::AbstractArray{T1,N}, y::AbstractArray{T2,N}) where {T1,T2,N}
batch_size = size(x)[3:end]
@assert batch_size == size(y)[3:end] "batch size has to be the same for the two arrays."
x2 = reshape(x, size(x, 1), size(x, 2), :)
y2 = reshape(y, size(y, 1), size(y, 2), :)
z = batched_mul(x2, y2)
return reshape(z, size(z, 1), size(z, 2), batch_size...)
end
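
The new method above flattens matching trailing batch dimensions, multiplies in 3D, and restores the original shape. A usage sketch:

```julia
using NNlib

A = rand(Float32, 2, 3, 4, 5)
B = rand(Float32, 3, 6, 4, 5)
C = batched_mul(A, B)                            # size (2, 6, 4, 5)
C[:, :, 2, 3] ≈ A[:, :, 2, 3] * B[:, :, 2, 3]    # true
```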

function batched_mul(A::AbstractArray{T1, 3}, B::AbstractArray{T2, 3}) where {T1, T2}
size(A, 3) == size(B, 3) || size(A, 3) == 1 || size(B, 3) == 1 ||
throw(DimensionMismatch("batch size mismatch: A != B"))
2 changes: 1 addition & 1 deletion src/gemm.jl
@@ -138,7 +138,7 @@ for (gemm, elt) in gemm_datatype_mappings

end

-C
+return C
end
end
end
76 changes: 76 additions & 0 deletions test/attention.jl
@@ -0,0 +1,76 @@
@testset "different batchsizes" begin
n = 15
lenq = 3
lenkv = 4
for batch_size in [(), 1, 2, (2,1,3)], nheads in [1, 3, 5]
q = rand(Float32, n, lenq, batch_size...)
k = rand(Float32, n, lenkv, batch_size...)
v = rand(Float32, n, lenkv, batch_size...)
y, α = dot_product_attention(q, k, v; nheads)
@test y isa Array{Float32}
@test size(y) == (n, lenq, batch_size...)
@test size(α) == (lenkv, lenq, nheads, batch_size...)
@test sum(α, dims=1) ≈ ones(1, lenq, nheads, batch_size...)
end
end

@testset "dot_product_attention_scores" begin
q = k = reshape([1:24;], 4, 2, 3, 1) ./ 24
α = dot_product_attention_scores(q, k)
q2, k2 = reshape.((q, k), 8, 3, 1)
y, α2 = dot_product_attention(q2, k2, k2; nheads=2)
@test α ≈ α2
end

@testset "specific results" begin
q = k = v = reshape([1:12;], 4, 3, 1) ./ 12
y, α = dot_product_attention(q, k, v; nheads=2)
ytrue = [0.429754, 0.513087, 0.613791, 0.697125, 0.46431, 0.547644, 0.647876, 0.73121, 0.49773, 0.581064, 0.680455, 0.763788]
ytrue = reshape(ytrue, 4, 3, 1)
αtrue = [0.313896, 0.332948, 0.353157, 0.264431, 0.328206, 0.407362, 0.219215, 0.31838, 0.462405, 0.288691, 0.331243, 0.380066, 0.241239, 0.323893, 0.434868, 0.198438, 0.311761, 0.489801]
αtrue = reshape(αtrue, 3, 3, 2, 1)
@test y ≈ ytrue atol=1e-5
@test α ≈ αtrue atol=1e-5
end

@testset "mask" begin
q = rand(4, 2, 3, 1)
k = rand(4, 2, 5, 1)

mask = rand(Bool, (5, 3))
α = dot_product_attention_scores(q, k; mask)
@test all((α[:,:,1,1].> 0) .== mask)
@test all((α[:,:,2,1].> 0) .== mask)

@testset "causal" begin
x = rand(4, 2, 3, 1)
mask = make_causal_mask(x, dims=3)
α = dot_product_attention_scores(x, x; mask)
@test all((α[:,:,1,1].> 0) .== mask)
@test all((α[:,:,2,1].> 0) .== mask)
end
end

@testset "dropout" begin
q = k = v = rand(10, 10, 10)
fdrop(x, p) = (rand!(similar(x)) .> p) .* x ./ (1-p)
y, α = dot_product_attention(q, k, v; nheads=2, fdrop = x -> fdrop(x, 0.5))
@test 0.6 > mean(>(0), α) > 0.4
end

@testset "bias" begin
q = rand(4, 5, 1)
k = v = rand(4, 3, 1)
bias = randn(3, 5)
y, α = dot_product_attention(q, k, v, bias; nheads=2)
@test size(α) == (3, 5, 2, 1)
@test size(y) == (4, 5, 1)
end

@testset "gradient" begin
q = rand(4, 5, 1)
k = v = rand(4, 3, 1)
bias = randn(3, 5)
y, α = dot_product_attention(q, k, v, bias; nheads=2)
gradtest((x...) -> dot_product_attention(x...; nheads=2)[1], q, k, v, bias)
end
4 changes: 4 additions & 0 deletions test/runtests.jl
@@ -40,6 +40,10 @@ include("test_utils.jl")
include("activations.jl")
end

@testset "Attention" begin
include("attention.jl")
end

@testset "Batched Multiplication" begin
include("batchedmul.jl")
end
