Skip to content

Commit

Permalink
Formatting, and some tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Jul 30, 2022
1 parent 99eb25a commit ced84a4
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 78 deletions.
1 change: 1 addition & 0 deletions src/convnets/mobilenet/mobilenetv1.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ function mobilenetv1(width_mult, config;
Dense(inchannels, nclasses)))
end

# Layer configurations for MobileNetv1
const MOBILENETV1_CONFIGS = [
# dw, c, s, r
(false, 32, 2, 1),
Expand Down
3 changes: 1 addition & 2 deletions src/convnets/mobilenet/mobilenetv2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ end

# Layer configurations for MobileNetv2
const MOBILENETV2_CONFIGS = [
# t, c, n, s, a
# t, c, n, s, a
(1, 16, 1, 1, relu6),
(6, 24, 2, 2, relu6),
(6, 32, 3, 2, relu6),
Expand All @@ -57,7 +57,6 @@ const MOBILENETV2_CONFIGS = [
(6, 320, 1, 1, relu6),
]

# Model definition for MobileNetv2
struct MobileNetv2
layers::Any
end
Expand Down
67 changes: 33 additions & 34 deletions src/convnets/mobilenet/mobilenetv3.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,41 +52,40 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla
Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier))
end

# Configurations for small and large mode for MobileNetv3
MOBILENETV3_CONFIGS = Dict(:small => [
# k, t, c, SE, a, s
(3, 1, 16, 4, relu, 2),
(3, 4.5, 24, nothing, relu, 2),
(3, 3.67, 24, nothing, relu, 1),
(5, 4, 40, 4, hardswish, 2),
(5, 6, 40, 4, hardswish, 1),
(5, 6, 40, 4, hardswish, 1),
(5, 3, 48, 4, hardswish, 1),
(5, 3, 48, 4, hardswish, 1),
(5, 6, 96, 4, hardswish, 2),
(5, 6, 96, 4, hardswish, 1),
(5, 6, 96, 4, hardswish, 1),
],
:large => [
# k, t, c, SE, a, s
(3, 1, 16, nothing, relu, 1),
(3, 4, 24, nothing, relu, 2),
(3, 3, 24, nothing, relu, 1),
(5, 3, 40, 4, relu, 2),
(5, 3, 40, 4, relu, 1),
(5, 3, 40, 4, relu, 1),
(3, 6, 80, nothing, hardswish, 2),
(3, 2.5, 80, nothing, hardswish, 1),
(3, 2.3, 80, nothing, hardswish, 1),
(3, 2.3, 80, nothing, hardswish, 1),
(3, 6, 112, 4, hardswish, 1),
(3, 6, 112, 4, hardswish, 1),
(5, 6, 160, 4, hardswish, 2),
(5, 6, 160, 4, hardswish, 1),
(5, 6, 160, 4, hardswish, 1),
])
# Layer configurations for small and large models for MobileNetv3
const MOBILENETV3_CONFIGS = Dict(:small => [
# k, t, c, SE, a, s
(3, 1, 16, 4, relu, 2),
(3, 4.5, 24, nothing, relu, 2),
(3, 3.67, 24, nothing, relu, 1),
(5, 4, 40, 4, hardswish, 2),
(5, 6, 40, 4, hardswish, 1),
(5, 6, 40, 4, hardswish, 1),
(5, 3, 48, 4, hardswish, 1),
(5, 3, 48, 4, hardswish, 1),
(5, 6, 96, 4, hardswish, 2),
(5, 6, 96, 4, hardswish, 1),
(5, 6, 96, 4, hardswish, 1),
],
:large => [
# k, t, c, SE, a, s
(3, 1, 16, nothing, relu, 1),
(3, 4, 24, nothing, relu, 2),
(3, 3, 24, nothing, relu, 1),
(5, 3, 40, 4, relu, 2),
(5, 3, 40, 4, relu, 1),
(5, 3, 40, 4, relu, 1),
(3, 6, 80, nothing, hardswish, 2),
(3, 2.5, 80, nothing, hardswish, 1),
(3, 2.3, 80, nothing, hardswish, 1),
(3, 2.3, 80, nothing, hardswish, 1),
(3, 6, 112, 4, hardswish, 1),
(3, 6, 112, 4, hardswish, 1),
(5, 6, 160, 4, hardswish, 2),
(5, 6, 160, 4, hardswish, 1),
(5, 6, 160, 4, hardswish, 1),
])

# Model definition for MobileNetv3
struct MobileNetv3
layers::Any
end
Expand Down
12 changes: 8 additions & 4 deletions src/convnets/resnets/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,8 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer
drop_block = DropBlock(blockschedule[schedule_idx])
block = basicblock(inplanes, planes; stride, reduction_factor, activation,
norm_layer, revnorm, attn_fn, drop_path, drop_block)
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm)
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
revnorm)
# inplanes increases by expansion after each block
inplanes = planes * expansion
return block, downsample
Expand Down Expand Up @@ -248,7 +249,8 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer
block = bottleneck(inplanes, planes; stride, cardinality, base_width,
reduction_factor, activation, norm_layer, revnorm,
attn_fn, drop_path, drop_block)
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm)
downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer,
revnorm)
# inplanes increases by expansion after each block
inplanes = planes * expansion
return block, downsample
Expand Down Expand Up @@ -298,13 +300,15 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
activation, norm_layer, revnorm, attn_fn,
drop_block_rate, drop_path_rate,
stride_fn = resnet_stride, planes_fn = resnet_planes,
stride_fn = resnet_stride,
planes_fn = resnet_planes,
downsample_tuple = downsample_opt)
elseif block_type == :bottleneck
get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
reduction_factor, activation, norm_layer,
revnorm, attn_fn, drop_block_rate, drop_path_rate,
stride_fn = resnet_stride, planes_fn = resnet_planes,
stride_fn = resnet_stride,
planes_fn = resnet_planes,
downsample_tuple = downsample_opt)
else
# TODO: write better message when we have link to dev docs for resnet
Expand Down
20 changes: 10 additions & 10 deletions src/layers/Layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ include("../utilities.jl")
include("attention.jl")
export MHAttention

include("conv.jl")
export conv_norm, depthwise_sep_conv_norm, invertedresidual

include("drop.jl")
export DropBlock, DropPath

include("embeddings.jl")
export PatchEmbedding, ViPosEmbedding, ClassTokens

Expand All @@ -25,19 +31,13 @@ export mlp_block, gated_mlp_block, create_fc, create_classifier
include("normalise.jl")
export prenorm, ChannelLayerNorm

include("conv.jl")
export conv_norm, depthwise_sep_conv_norm, invertedresidual

include("drop.jl")
export DropBlock, DropPath, droppath_rates

include("selayers.jl")
export squeeze_excite, effective_squeeze_excite
include("pool.jl")
export AdaptiveMeanMaxPool

include("scale.jl")
export LayerScale, inputscale

include("pool.jl")
export AdaptiveMeanMaxPool
include("selayers.jl")
export squeeze_excite, effective_squeeze_excite

end
39 changes: 13 additions & 26 deletions src/layers/attention.jl
Original file line number Diff line number Diff line change
@@ -1,46 +1,33 @@
"""
MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection)
MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_rate = 0., proj_dropout_rate = 0.)
Multi-head self-attention layer.
# Arguments
- `nheads`: Number of heads
- `qkv_layer`: layer to be used for getting the query, key and value
- `attn_drop_rate`: dropout rate after the self-attention layer
- `projection`: projection layer to be used after self-attention
- `planes`: number of input channels
- `nheads`: number of heads
- `qkv_bias`: whether to use bias in the layer to get the query, key and value
- `attn_dropout_rate`: dropout rate after the self-attention layer
- `proj_dropout_rate`: dropout rate after the projection layer
"""
struct MHAttention{P, Q, R}
nheads::Int
qkv_layer::P
attn_drop_rate::Q
attn_drop::Q
projection::R
end
@functor MHAttention

"""
MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.)
Multi-head self-attention layer.
# Arguments
- `planes`: number of input channels
- `nheads`: number of heads
- `qkv_bias`: whether to use bias in the layer to get the query, key and value
- `attn_drop_rate`: dropout rate after the self-attention layer
- `proj_drop_rate`: dropout rate after the projection layer
"""
function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
attn_drop_rate = 0.0, proj_drop_rate = 0.0)
attn_dropout_rate = 0.0, proj_dropout_rate = 0.0)
@assert planes % nheads==0 "planes should be divisible by nheads"
qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
attn_drop_rate = Dropout(attn_drop_rate)
proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate))
return MHAttention(nheads, qkv_layer, attn_drop_rate, proj)
attn_drop = Dropout(attn_dropout_rate)
proj = Chain(Dense(planes, planes), Dropout(proj_dropout_rate))
return MHAttention(nheads, qkv_layer, attn_drop, proj)
end

@functor MHAttention

function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
nfeatures, seq_len, batch_size = size(x)
x_reshaped = reshape(x, nfeatures, seq_len * batch_size)
Expand All @@ -52,7 +39,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T}
seq_len * batch_size)
query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
m.nheads, seq_len * batch_size)
attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale))
value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads,
m.nheads, seq_len * batch_size)
pre_projection = reshape(batched_mul(attention, value_reshaped),
Expand Down
3 changes: 1 addition & 2 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu
return revnorm ? reverse(layers) : layers
end

function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes,
activation = identity; kwargs...)
function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...)
inplanes, outplanes = ch
return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...)
end
Expand Down
4 changes: 4 additions & 0 deletions src/layers/mlp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,10 @@ function create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1,
"Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used"
end
flatten_in_pool = !use_conv && pool_layer !== identity
if use_conv
@assert pool_layer === identity
"`pool_layer` must be identity if `use_conv` is true"
end
global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer
# Fully-connected layer
fc = use_conv ? Conv((1, 1), inplanes => nclasses) : Dense(inplanes => nclasses)
Expand Down

0 comments on commit ced84a4

Please sign in to comment.