From ced84a46253fe1090c6b597b6d2ccb1fd6ee9fec Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 30 Jul 2022 11:48:14 +0530 Subject: [PATCH] Formatting, and some tweaks --- src/convnets/mobilenet/mobilenetv1.jl | 1 + src/convnets/mobilenet/mobilenetv2.jl | 3 +- src/convnets/mobilenet/mobilenetv3.jl | 67 +++++++++++++-------------- src/convnets/resnets/core.jl | 12 +++-- src/layers/Layers.jl | 20 ++++---- src/layers/attention.jl | 39 ++++++---------- src/layers/conv.jl | 3 +- src/layers/mlp.jl | 4 ++ 8 files changed, 71 insertions(+), 78 deletions(-) diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index fe075d5ef..fffa93a4d 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -45,6 +45,7 @@ function mobilenetv1(width_mult, config; Dense(inchannels, nclasses))) end +# Layer configurations for MobileNetv1 const MOBILENETV1_CONFIGS = [ # dw, c, s, r (false, 32, 2, 1), diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index dd9eda012..a97e7dda1 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -47,7 +47,7 @@ end # Layer configurations for MobileNetv2 const MOBILENETV2_CONFIGS = [ - # t, c, n, s, a + # t, c, n, s, a (1, 16, 1, 1, relu6), (6, 24, 2, 2, relu6), (6, 32, 3, 2, relu6), @@ -57,7 +57,6 @@ const MOBILENETV2_CONFIGS = [ (6, 320, 1, 1, relu6), ] -# Model definition for MobileNetv2 struct MobileNetv2 layers::Any end diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 5a06f6be5..d8666c5f3 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -52,41 +52,40 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla Chain(AdaptiveMeanPool((1, 1)), MLUtils.flatten, classifier)) end -# Configurations for small and large mode for MobileNetv3 -MOBILENETV3_CONFIGS = Dict(:small => [ - # k, t, c, SE, a, s - (3, 1, 16, 4, relu, 2), - (3, 4.5, 24, nothing, relu, 2), - (3, 3.67, 24, nothing, relu, 1), - (5, 4, 40, 4, hardswish, 2), - (5, 6, 40, 4, hardswish, 1), - (5, 6, 40, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 3, 48, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 2), - (5, 6, 96, 4, hardswish, 1), - (5, 6, 96, 4, hardswish, 1), - ], - :large => [ - # k, t, c, SE, a, s - (3, 1, 16, nothing, relu, 1), - (3, 4, 24, nothing, relu, 2), - (3, 3, 24, nothing, relu, 1), - (5, 3, 40, 4, relu, 2), - (5, 3, 40, 4, relu, 1), - (5, 3, 40, 4, relu, 1), - (3, 6, 80, nothing, hardswish, 2), - (3, 2.5, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 2.3, 80, nothing, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (3, 6, 112, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 2), - (5, 6, 160, 4, hardswish, 1), - (5, 6, 160, 4, hardswish, 1), - ]) +# Layer configurations for small and large models for MobileNetv3 +const MOBILENETV3_CONFIGS = Dict(:small => [ + # k, t, c, SE, a, s + (3, 1, 16, 4, relu, 2), + (3, 4.5, 24, nothing, relu, 2), + (3, 3.67, 24, nothing, relu, 1), + (5, 4, 40, 4, hardswish, 2), + (5, 6, 40, 4, hardswish, 1), + (5, 6, 40, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 3, 48, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 2), + (5, 6, 96, 4, hardswish, 1), + (5, 6, 96, 4, hardswish, 1), + ], + :large => [ + # k, t, c, SE, a, s + (3, 1, 16, nothing, relu, 1), + (3, 4, 24, nothing, relu, 2), + (3, 3, 24, nothing, 
relu, 1), + (5, 3, 40, 4, relu, 2), + (5, 3, 40, 4, relu, 1), + (5, 3, 40, 4, relu, 1), + (3, 6, 80, nothing, hardswish, 2), + (3, 2.5, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 2.3, 80, nothing, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (3, 6, 112, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 2), + (5, 6, 160, 4, hardswish, 1), + (5, 6, 160, 4, hardswish, 1), + ]) -# Model definition for MobileNetv3 struct MobileNetv3 layers::Any end diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index aa7309a9b..03d96d6db 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -215,7 +215,8 @@ function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer drop_block = DropBlock(blockschedule[schedule_idx]) block = basicblock(inplanes, planes; stride, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) # inplanes increases by expansion after each block inplanes = planes * expansion return block, downsample @@ -248,7 +249,8 @@ function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer block = bottleneck(inplanes, planes; stride, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_path, drop_block) - downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, revnorm) + downsample = downsample_fn(inplanes, planes * expansion; stride, norm_layer, + revnorm) # inplanes increases by expansion after each block inplanes = planes * expansion return block, downsample @@ -298,13 +300,15 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer}; get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, + planes_fn = resnet_planes, downsample_tuple = downsample_opt) elseif block_type == :bottleneck get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width, reduction_factor, activation, norm_layer, revnorm, attn_fn, drop_block_rate, drop_path_rate, - stride_fn = resnet_stride, planes_fn = resnet_planes, + stride_fn = resnet_stride, + planes_fn = resnet_planes, downsample_tuple = downsample_opt) else # TODO: write better message when we have link to dev docs for resnet diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index 3db3a2ccd..04be476ff 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -16,6 +16,12 @@ include("../utilities.jl") include("attention.jl") export MHAttention +include("conv.jl") +export conv_norm, depthwise_sep_conv_norm, invertedresidual + +include("drop.jl") +export DropBlock, DropPath + include("embeddings.jl") export PatchEmbedding, ViPosEmbedding, ClassTokens @@ -25,19 +31,13 @@ export mlp_block, gated_mlp_block, create_fc, create_classifier include("normalise.jl") export prenorm, ChannelLayerNorm -include("conv.jl") -export conv_norm, depthwise_sep_conv_norm, invertedresidual - -include("drop.jl") -export DropBlock, DropPath, droppath_rates - -include("selayers.jl") -export squeeze_excite, effective_squeeze_excite +include("pool.jl") +export AdaptiveMeanMaxPool include("scale.jl") export LayerScale, inputscale -include("pool.jl") -export AdaptiveMeanMaxPool +include("selayers.jl") 
+export squeeze_excite, effective_squeeze_excite end diff --git a/src/layers/attention.jl b/src/layers/attention.jl index 7d8ee776d..e2276aa01 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -1,46 +1,33 @@ """ - MHAttention(nheads::Integer, qkv_layer, attn_drop_rate, projection) + MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_dropout_rate = 0., proj_dropout_rate = 0.) Multi-head self-attention layer. # Arguments - - `nheads`: Number of heads - - `qkv_layer`: layer to be used for getting the query, key and value - - `attn_drop_rate`: dropout rate after the self-attention layer - - `projection`: projection layer to be used after self-attention + - `planes`: number of input channels + - `nheads`: number of heads + - `qkv_bias`: whether to use bias in the layer to get the query, key and value + - `attn_dropout_rate`: dropout rate after the self-attention layer + - `proj_dropout_rate`: dropout rate after the projection layer """ struct MHAttention{P, Q, R} nheads::Int qkv_layer::P - attn_drop_rate::Q + attn_drop::Q projection::R end +@functor MHAttention -""" - MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.) - -Multi-head self-attention layer. - -# Arguments - - - `planes`: number of input channels - - `nheads`: number of heads - - `qkv_bias`: whether to use bias in the layer to get the query, key and value - - `attn_drop_rate`: dropout rate after the self-attention layer - - `proj_drop_rate`: dropout rate after the projection layer -""" function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, - attn_drop_rate = 0.0, proj_drop_rate = 0.0) + attn_dropout_rate = 0.0, proj_dropout_rate = 0.0) @assert planes % nheads==0 "planes should be divisible by nheads" qkv_layer = Dense(planes, planes * 3; bias = qkv_bias) - attn_drop_rate = Dropout(attn_drop_rate) - proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate)) - return MHAttention(nheads, qkv_layer, attn_drop_rate, proj) + attn_drop = Dropout(attn_dropout_rate) + proj = Chain(Dense(planes, planes), Dropout(proj_dropout_rate)) + return MHAttention(nheads, qkv_layer, attn_drop, proj) end -@functor MHAttention - function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} nfeatures, seq_len, batch_size = size(x) x_reshaped = reshape(x, nfeatures, seq_len * batch_size) @@ -52,7 +39,7 @@ function (m::MHAttention)(x::AbstractArray{T, 3}) where {T} seq_len * batch_size) query_reshaped = reshape(permutedims(query, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) - attention = m.attn_drop_rate(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) + attention = m.attn_drop(softmax(batched_mul(query_reshaped, key_reshaped) .* scale)) value_reshaped = reshape(permutedims(value, (1, 2, 3, 4)), nfeatures ÷ m.nheads, m.nheads, seq_len * batch_size) pre_projection = reshape(batched_mul(attention, value_reshaped), diff --git a/src/layers/conv.jl b/src/layers/conv.jl index f5d94fbcb..02d80d67a 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -52,8 +52,7 @@ function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu return revnorm ? reverse(layers) : layers end -function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes, - activation = identity; kwargs...) +function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, activation = identity; kwargs...) 
     inplanes, outplanes = ch
     return conv_norm(kernel_size, inplanes, outplanes, activation; kwargs...)
 end
diff --git a/src/layers/mlp.jl b/src/layers/mlp.jl
index 4a623c977..a3bdb0fb5 100644
--- a/src/layers/mlp.jl
+++ b/src/layers/mlp.jl
@@ -69,6 +69,10 @@ function create_classifier(inplanes, nclasses; pool_layer = AdaptiveMeanPool((1,
                "Pooling can only be disabled if classifier is also removed or a convolution-based classifier is used"
     end
     flatten_in_pool = !use_conv && pool_layer !== identity
+    if use_conv
+        # `use_conv` requires pooling to be disabled (`pool_layer = identity`)
+        @assert pool_layer===identity "`pool_layer` must be identity if `use_conv` is true"
+    end
     global_pool = flatten_in_pool ? Chain(pool_layer, MLUtils.flatten) : pool_layer
     # Fully-connected layer
     fc = use_conv ? Conv((1, 1), inplanes => nclasses) : Dense(inplanes => nclasses)
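
# --- Usage sketch (illustrative only, not part of the patch) ---
# A minimal example of two of the APIs touched above. Assumptions: the internal
# `Layers` submodule is reachable as `Metalhead.Layers`, Flux supplies `Chain`
# and `relu`, and the sizes below are arbitrary.

using Flux
using Metalhead

# self-attention over a (features, sequence, batch) array, using the renamed
# `attn_dropout_rate` / `proj_dropout_rate` keyword arguments
attn = Metalhead.Layers.MHAttention(768, 12; qkv_bias = true,
                                    attn_dropout_rate = 0.1, proj_dropout_rate = 0.1)
x = rand(Float32, 768, 196, 4)
y = attn(x)    # `y` should keep the (features, sequence, batch) layout of `x`

# the two-argument `conv_norm` method now takes an `inplanes => outplanes` Pair;
# it returns a vector of layers, so splat it into a `Chain`
block = Chain(Metalhead.Layers.conv_norm((3, 3), 16 => 32, relu)...)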