Tweaks - I
theabhirath committed Jun 27, 2022
1 parent a038ff8 commit de079bc
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/convnets/inception.jl
@@ -340,7 +340,7 @@ struct Inceptionv4
 end
 
 function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
-    layers = inceptionv4(; inchannels, dropout, nclasses)
+    layers = inceptionv4(; inchannels, drop_rate, nclasses)
     pretrain && loadpretrain!(layers, "Inceptionv4")
     return Inceptionv4(layers)
 end
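
Note on the inception.jl change: the constructor's keyword is named drop_rate, but the body forwarded a name, dropout, that does not match any keyword of inceptionv4, so the supplied rate could not be passed through as intended. Julia's `; name` shorthand forwards a local variable as a keyword argument of the same name, so the names on both sides must match. A minimal sketch of that forwarding rule, with hypothetical stand-ins (inner_model, outer_model) for the Metalhead functions:

# Hypothetical names illustrating the keyword-forwarding shorthand used in the diff:
# `; inchannels, drop_rate, nclasses` passes each local variable as a keyword argument
# of the same name, so the forwarded names must match the callee's keywords exactly.
inner_model(; inchannels = 3, drop_rate = 0.0, nclasses = 1000) = (inchannels, drop_rate, nclasses)

function outer_model(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
    # Forwarding `drop_rate` (not a stale `dropout` name) keeps the rate flowing through.
    return inner_model(; inchannels, drop_rate, nclasses)
end

outer_model(; drop_rate = 0.2)  # (3, 0.2, 1000)
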
8 changes: 4 additions & 4 deletions src/layers/attention.jl
@@ -1,5 +1,5 @@
 """
-    MHAttention(nheads::Integer, qkv_layer, attn_drop, projection)
+    MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection)
 
 Multi-head self-attention layer.
 
@@ -34,9 +34,9 @@ function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = fals
                     attn_drop_rate = 0.0, proj_drop_rate = 0.0)
     @assert planes % nheads==0 "planes should be divisible by nheads"
     qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
-    attn_drop = Dropout(attn_drop_rate)
+    attn_drop_rate = Dropout(attn_drop_rate)
     proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate))
-    return MHAttention(nheads, qkv_layer, attn_drop, proj)
+    return MHAttention(nheads, qkv_layer, attn_drop_rate, proj)
 end
 
 @functor MHAttention
@@ -52,7 +52,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
                            seq_len * batch_size)
     query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
                              m.nheads, seq_len * batch_size)
-    attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
+    attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
     value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
                              m.nheads, seq_len * batch_size)
     pre_projection = reshape(batched_mul(attention, value_reshaped),
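
Note on the attention.jl hunks: the constructor stores a Dropout layer (built from attn_drop_rate) alongside the qkv Dense and the projection Chain, and the forward pass applies that Dropout to the softmax of the scaled, batched query-key product; the rename only changes which field name holds the layer. A self-contained, single-head sketch of that score pathway (illustration only, not the Metalhead implementation; toy_attention and its shapes are assumptions):

using Flux
using NNlib: batched_mul, batched_transpose

# Single-head sketch of "scaled scores -> softmax -> Dropout -> weighted values".
# Arrays use the (features, sequence length, batch) layout seen in the diff.
function toy_attention(q, k, v, attn_drop)
    scale = convert(eltype(q), inv(sqrt(size(q, 1))))
    # (keys, queries, batch) score map, normalised over the key dimension
    scores = softmax(batched_mul(batched_transpose(k), q) .* scale; dims = 1)
    # The Dropout layer only perturbs the scores in training mode; at inference it is a no-op.
    return batched_mul(v, attn_drop(scores))
end

attn_drop = Dropout(0.1)
q = k = v = rand(Float32, 64, 16, 2)
size(toy_attention(q, k, v, attn_drop))  # (64, 16, 2)

The version in the diff additionally splits the features into nheads chunks before the batched multiplications, which is what the nfeatures ÷ m.nheads reshapes above accomplish.
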
2 changes: 1 addition & 1 deletion src/layers/normalise.jl
@@ -24,4 +24,4 @@ function ChannelLayerNorm(sz::Integer, λ = identity; ϵ = 1.0f-5)
     return ChannelLayerNorm(diag, ϵ)
 end
 
-(m::ChannelLayerNorm)(x) = m.diag(MLUtils.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))
+(m::ChannelLayerNorm)(x) = m.diag(Flux.normalise(x; dims = ndims(x) - 1, ϵ = m.ϵ))
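
Note on the normalise.jl change: Flux.normalise standardises an array along the requested dimensions, so normalising over ndims(x) - 1 (the channel dimension of a WHCN array) and then applying the learned diag gives the layer its channel-wise LayerNorm behaviour. A rough sketch of that computation under those assumptions, using plain Statistics calls and explicit scale/shift vectors in place of the layer's diag:

using Statistics  # mean, std

# Rough sketch (not the Metalhead layer): standardise over the channel dimension of a
# WHCN array, following the same dims = ndims(x) - 1 convention as above, then apply a
# per-channel scale γ and shift β standing in for the layer's learned `diag`.
function channel_layernorm_sketch(x::AbstractArray{T, 4}, γ, β; ϵ = T(1.0f-5)) where {T}
    dims = ndims(x) - 1                       # channel dimension for W x H x C x N input
    μ = mean(x; dims)
    σ = std(x; dims, mean = μ, corrected = false)
    xnorm = (x .- μ) ./ (σ .+ ϵ)
    return xnorm .* reshape(γ, 1, 1, :, 1) .+ reshape(β, 1, 1, :, 1)
end

x = rand(Float32, 7, 7, 32, 4)
size(channel_layernorm_sketch(x, ones(Float32, 32), zeros(Float32, 32)))  # (7, 7, 32, 4)
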
