diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl
index 9c832d7f0..76d1ef036 100644
--- a/src/convnets/inception.jl
+++ b/src/convnets/inception.jl
@@ -279,7 +279,7 @@ function inceptionv4_c()
 end
 
 """
-    inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
 
 Create an Inceptionv4 model.
 ([reference](https://arxiv.org/abs/1602.07261))
@@ -287,10 +287,10 @@ Create an Inceptionv4 model.
 # Arguments
 
   - `inchannels`: number of input channels.
-  - `dropout`: rate of dropout in classifier head.
+  - `drop_rate`: rate of dropout in classifier head.
   - `nclasses`: the number of output classes.
 """
-function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
     body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
                  conv_bn((3, 3), 32, 32)...,
                  conv_bn((3, 3), 32, 64; pad = 1)...,
@@ -313,12 +313,13 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
                 inceptionv4_c(),
                 inceptionv4_c(),
                 inceptionv4_c())
-    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses))
+    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
+                 Dense(1536, nclasses))
     return Chain(body, head)
 end
 
 """
-    Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
+    Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
 
 Creates an Inceptionv4 model.
 ([reference](https://arxiv.org/abs/1602.07261))
@@ -327,7 +328,7 @@ Creates an Inceptionv4 model.
 
   - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
   - `inchannels`: number of input channels.
-  - `dropout`: rate of dropout in classifier head.
+  - `drop_rate`: rate of dropout in classifier head.
   - `nclasses`: the number of output classes.
 
 !!! warning
@@ -338,7 +339,7 @@ struct Inceptionv4
     layers::Any
 end
-function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
-    layers = inceptionv4(; inchannels, dropout, nclasses)
+function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
+    layers = inceptionv4(; inchannels, drop_rate, nclasses)
     pretrain && loadpretrain!(layers, "Inceptionv4")
     return Inceptionv4(layers)
 end
@@ -419,7 +420,7 @@ function block8(scale = 1.0f0; activation = identity)
 end
 
 """
-    inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
 
 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))
@@ -427,10 +428,10 @@ Creates an InceptionResNetv2 model.
 # Arguments
 
   - `inchannels`: number of input channels.
-  - `dropout`: rate of dropout in classifier head.
+  - `drop_rate`: rate of dropout in classifier head.
   - `nclasses`: the number of output classes.
 """
-function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
     body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
                  conv_bn((3, 3), 32, 32)...,
                  conv_bn((3, 3), 32, 64; pad = 1)...,
@@ -446,12 +447,13 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
                  [block8(0.20f0) for _ in 1:9]...,
                  block8(; activation = relu),
                  conv_bn((1, 1), 2080, 1536)...)
-    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses))
+    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
+                 Dense(1536, nclasses))
     return Chain(body, head)
 end
 
 """
-    InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
+    InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
 
 Creates an InceptionResNetv2 model.
 ([reference](https://arxiv.org/abs/1602.07261))
@@ -460,7 +462,7 @@ Creates an InceptionResNetv2 model.
 
   - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
   - `inchannels`: number of input channels.
-  - `dropout`: rate of dropout in classifier head.
+  - `drop_rate`: rate of dropout in classifier head.
   - `nclasses`: the number of output classes.
 
 !!! warning
@@ -471,9 +473,9 @@ struct InceptionResNetv2
     layers::Any
 end
 
-function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0,
+function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0,
                            nclasses = 1000)
-    layers = inceptionresnetv2(; inchannels, dropout, nclasses)
+    layers = inceptionresnetv2(; inchannels, drop_rate, nclasses)
     pretrain && loadpretrain!(layers, "InceptionResNetv2")
     return InceptionResNetv2(layers)
 end
@@ -533,7 +535,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
 end
 
 """
-    xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+    xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
 
 Creates an Xception model.
 ([reference](https://arxiv.org/abs/1610.02357))
@@ -541,10 +543,10 @@ Creates an Xception model.
 # Arguments
 
   - `inchannels`: number of input channels.
-  - `dropout`: rate of dropout in classifier head.
+  - `drop_rate`: rate of dropout in classifier head.
   - `nclasses`: the number of output classes.
 """
-function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
+function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
     body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)...,
                  conv_bn((3, 3), 32, 64; bias = false)...,
                  xception_block(64, 128, 2; stride = 2, start_with_relu = false),
@@ -554,7 +556,8 @@ function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
                  xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
                  depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
                  depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
-    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses))
+    head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
+                 Dense(2048, nclasses))
     return Chain(body, head)
 end
 
@@ -563,7 +566,7 @@ struct Xception
 end
 
 """
-    Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
+    Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
 
 Creates an Xception model.
 ([reference](https://arxiv.org/abs/1610.02357))
@@ -572,15 +575,15 @@ Creates an Xception model.
 
   - `pretrain`: set to `true` to load the pre-trained weights for ImageNet.
   - `inchannels`: number of input channels.
-  - `dropout`: rate of dropout in classifier head.
+  - `drop_rate`: rate of dropout in classifier head.
   - `nclasses`: the number of output classes.
 
 !!! warning
 
     `Xception` does not currently support pretrained weights.
""" -function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000) - layers = xception(; inchannels, dropout, nclasses) +function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) + layers = xception(; inchannels, drop_rate, nclasses) pretrain && loadpretrain!(layers, "xception") return Xception(layers) end diff --git a/src/convnets/resnet.jl b/src/convnets/resnet.jl index 2b78f09ad..d025d1d5d 100644 --- a/src/convnets/resnet.jl +++ b/src/convnets/resnet.jl @@ -1,9 +1,42 @@ +function drop_blocks(drop_prob = 0.0) + return [ + identity, + identity, + DropBlock(drop_prob, 5, 0.25), + DropBlock(drop_prob, 3, 1.00), + ] +end + +function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size + first_dilation = kernel_size[1] > 1 ? + (!isnothing(first_dilation) ? first_dilation : dilation) : 1 + pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 + return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, + dilation = first_dilation, bias = false), + norm_layer(out_channels)) +end + +function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, + first_dilation = nothing, norm_layer = BatchNorm) + avg_stride = dilation == 1 ? stride : 1 + if stride == 1 && dilation == 1 + pool = identity + else + pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 + pool = avg_pool_fn((2, 2); stride = avg_stride, pad) + end + return Chain(pool, + Conv((1, 1), in_channels => out_channels; bias = false), + norm_layer(out_channels)) +end + function basicblock(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, - reduce_first = 1, dilation = 1, first_dilation = nothing, - act_layer = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity) - expansion = 1 + expansion = expansion_factor(basicblock) @assert cardinality==1 "BasicBlock only supports cardinality of 1" @assert base_width==64 "BasicBlock does not support changing base width" first_planes = planes ÷ reduce_first @@ -17,16 +50,16 @@ function basicblock(inplanes, planes; stride = 1, downsample = identity, cardina dilation = dilation, bias = false), norm_layer(outplanes)) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_path)), - act_layer) + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_path)), + activation) end +expansion_factor(::typeof(basicblock)) = 1 function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1, - base_width = 64, - reduce_first = 1, dilation = 1, first_dilation = nothing, - act_layer = relu, norm_layer = BatchNorm, + base_width = 64, reduce_first = 1, dilation = 1, + first_dilation = nothing, activation = relu, norm_layer = BatchNorm, drop_block = identity, drop_path = identity) - expansion = 4 + expansion = expansion_factor(bottleneck) width = floor(Int, planes * (base_width / 64)) * cardinality first_planes = width ÷ reduce_first outplanes = planes * expansion @@ -39,55 +72,25 @@ function bottleneck(inplanes, planes; stride = 1, downsample = identity, cardina drop_block = drop_block === identity ? 
identity : drop_block() conv_bn3 = Chain(Conv((1, 1), width => outplanes; bias = false), norm_layer(outplanes)) return Chain(Parallel(+, downsample, - Chain(conv_bn1, drop_block, act_layer, conv_bn2, drop_block, - act_layer, conv_bn3, drop_path)), - act_layer) + Chain(conv_bn1, drop_block, activation, conv_bn2, drop_block, + activation, conv_bn3, drop_path)), + activation) end +expansion_factor(::typeof(bottleneck)) = 4 -function drop_blocks(drop_prob = 0.0) - return [identity, identity, - drop_prob == 0.0 ? DropBlock(drop_prob, 5, 0.25) : identity, - drop_prob == 0.0 ? DropBlock(drop_prob, 3, 1.00) : identity] -end - -function downsample_conv(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - kernel_size = stride == 1 && dilation == 1 ? 1 : kernel_size - first_dilation = kernel_size[1] > 1 ? - (!isnothing(first_dilation) ? first_dilation : dilation) : 1 - pad = ((stride - 1) + dilation * (kernel_size[1] - 1)) ÷ 2 - return Chain(Conv(kernel_size, in_channels => out_channels; stride, pad, - dilation = first_dilation, bias = false), - norm_layer(out_channels)) -end - -function downsample_avg(kernel_size, in_channels, out_channels; stride = 1, dilation = 1, - first_dilation = nothing, norm_layer = BatchNorm) - avg_stride = dilation == 1 ? stride : 1 - if stride == 1 && dilation == 1 - pool = identity - else - pad = avg_stride == 1 && dilation > 1 ? SamePad() : 0 - pool = avg_pool_fn((2, 2); stride = avg_stride, pad) - end - - return Chain(pool, - Conv((1, 1), in_channels => out_channels; stride = 1, pad = 0, - bias = false), - norm_layer(out_channels)) -end - -function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, - reduce_first = 1, output_stride = 32, - down_kernel_size = 1, avg_down = false, drop_block_rate = 0.0, - drop_path_rate = 0.0, kwargs...) +function make_blocks(block_fn, channels, block_repeats, inplanes; + reduce_first = 1, output_stride = 32, down_kernel_size = 1, + avg_down = false, drop_block_rate = 0.0, drop_path_rate = 0.0, + kwargs...) + expansion = expansion_factor(block_fn) kwarg_dict = Dict(kwargs...) stages = [] net_block_idx = 1 net_stride = 4 dilation = prev_dilation = 1 - for (stage_idx, (planes, num_blocks, db)) in enumerate(zip(channels, block_repeats, - drop_blocks(drop_block_rate))) + for (stage_idx, (planes, num_blocks, drop_block)) in enumerate(zip(channels, + block_repeats, + drop_blocks(drop_block_rate))) stride = stage_idx == 1 ? 1 : 2 if net_stride >= output_stride dilation *= stride @@ -95,6 +98,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, else net_stride *= stride end + # first block needs to be handled differently for downsampling downsample = identity if stride != 1 || inplanes != planes * expansion downsample = avg_down ? @@ -106,7 +110,7 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, norm_layer = kwarg_dict[:norm_layer]) end block_kwargs = Dict(:reduce_first => reduce_first, :dilation => dilation, - :drop_block => db, kwargs...) + :drop_block => drop_block, kwargs...) blocks = [] for block_idx in 1:num_blocks downsample = block_idx == 1 ? 
downsample : identity @@ -127,15 +131,13 @@ function make_blocks(block_fn, channels, block_repeats, inplanes; expansion = 1, end function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride = 32, - expansion = 1, cardinality = 1, base_width = 64, stem_width = 64, stem_type = :default, - replace_stem_pool = false, reduce_first = 1, - down_kernel_size = (1, 1), avg_down = false, act_layer = relu, - norm_layer = BatchNorm, + replace_stem_pool = false, reduce_first = 1, down_kernel_size = (1, 1), + avg_down = false, activation = relu, norm_layer = BatchNorm, drop_rate = 0.0, drop_path_rate = 0.0, drop_block_rate = 0.0, block_kwargs...) - @assert output_stride in (8, 16, 32) - @assert stem_type in [:default, :deep, :deep_tiered] + @assert output_stride in (8, 16, 32) "Invalid `output_stride`. Must be one of (8, 16, 32)" + @assert stem_type in [:default, :deep, :deep_tiered] "Stem type must be one of [:default, :deep, :deep_tiered]" # Stem inplanes = stem_type == :deep ? stem_width * 2 : 64 if stem_type == :deep @@ -145,40 +147,40 @@ function resnet(block, layers; num_classes = 1000, inchannels = 3, output_stride end conv1 = Chain(Conv((3, 3), inchannels => stem_channels[0]; stride = 2, pad = 1, bias = false), - norm_layer(stem_channels[1]), - act_layer(), - Conv((3, 3), stem_channels[1] => stem_channels[1]; stride = 1, - pad = 1, bias = false), - norm_layer(stem_channels[2]), - act_layer(), - Conv((3, 3), stem_channels[2] => inplanes; stride = 1, pad = 1, - bias = false)) + norm_layer(stem_channels[1], activation), + Conv((3, 3), stem_channels[1] => stem_channels[1]; pad = 1, + bias = false), + norm_layer(stem_channels[2], activation), + Conv((3, 3), stem_channels[2] => inplanes; pad = 1, bias = false)) else conv1 = Conv((7, 7), inchannels => inplanes; stride = 2, pad = 3, bias = false) end - bn1 = norm_layer(inplanes) - act1 = act_layer + bn1 = norm_layer(inplanes, activation) # Stem pooling if replace_stem_pool stempool = Chain(Conv((3, 3), inplanes => inplanes; stride = 2, pad = 1, bias = false), - norm_layer(inplanes), - act_layer) + norm_layer(inplanes, activation)) else stempool = MaxPool((3, 3); stride = 2, pad = 1) end - stem = Chain(conv1, bn1, act1, stempool) - + stem = Chain(conv1, bn1, stempool) # Feature Blocks channels = [64, 128, 256, 512] stage_blocks = make_blocks(block, channels, layers, inplanes; cardinality, base_width, output_stride, reduce_first, avg_down, - down_kernel_size, act_layer, norm_layer, + down_kernel_size, activation, norm_layer, drop_block_rate, drop_path_rate, block_kwargs...) - # Head (Pooling and Classifier) + expansion = expansion_factor(block) num_features = 512 * expansion classifier = Chain(GlobalMeanPool(), Dropout(drop_rate), MLUtils.flatten, Dense(num_features, num_classes)) return Chain(Chain(stem, stage_blocks), classifier) end + +struct ResNet + layers::Any +end + +function ResNet(depth::Integer = 50; pretrain = false, nclasses = 1000) end diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 48aca7365..ced9a992a 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -52,7 +52,7 @@ function vgg_convolutional_layers(config, batchnorm, inchannels) end """ - vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) Create VGG classifier (fully connected) layers ([reference](https://arxiv.org/abs/1409.1556v6)). 
@@ -63,19 +63,19 @@ Create VGG classifier (fully connected) layers the convolution layers (see [`Metalhead.vgg_convolutional_layers`](#)) - `nclasses`: number of output classes - `fcsize`: input and output size of the intermediate fully connected layer - - `dropout`: the dropout level between each fully connected layer + - `drop_rate`: the dropout level between each fully connected layer """ -function vgg_classifier_layers(imsize, nclasses, fcsize, dropout) +function vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) return Chain(MLUtils.flatten, Dense(Int(prod(imsize)), fcsize, relu), - Dropout(dropout), + Dropout(drop_rate), Dense(fcsize, fcsize, relu), - Dropout(dropout), + Dropout(drop_rate), Dense(fcsize, nclasses)) end """ - vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) Create a VGG model ([reference](https://arxiv.org/abs/1409.1556v6)). @@ -90,12 +90,12 @@ Create a VGG model - `nclasses`: number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `drop_rate`: dropout level between fully connected layers """ -function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) +function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) conv = vgg_convolutional_layers(config, batchnorm, inchannels) imsize = outputsize(conv, (imsize..., inchannels); padbatch = true)[1:3] - class = vgg_classifier_layers(imsize, nclasses, fcsize, dropout) + class = vgg_classifier_layers(imsize, nclasses, fcsize, drop_rate) return Chain(Chain(conv), class) end @@ -114,7 +114,7 @@ struct VGG end """ - VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, dropout) + VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, drop_rate) Construct a VGG model with the specified input image size. Typically, the image size is `(224, 224)`. @@ -126,17 +126,11 @@ Construct a VGG model with the specified input image size. Typically, the image - `nclasses`::Integer : number of output classes - `fcsize`: intermediate fully connected layer size (see [`Metalhead.vgg_classifier_layers`](#)) - - `dropout`: dropout level between fully connected layers + - `drop_rate`: dropout level between fully connected layers """ -function VGG(imsize::Dims{2}; - config, inchannels, batchnorm = false, nclasses, fcsize, dropout) - layers = vgg(imsize; config = config, - inchannels = inchannels, - batchnorm = batchnorm, - nclasses = nclasses, - fcsize = fcsize, - dropout = dropout) - +function VGG(imsize::Dims{2}; config, inchannels, batchnorm = false, nclasses, fcsize, + drop_rate) + layers = vgg(imsize; config, inchannels, batchnorm, nclasses, fcsize, drop_rate) return VGG(layers) end @@ -169,7 +163,7 @@ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses batchnorm = batchnorm, nclasses = nclasses, fcsize = 4096, - dropout = 0.5) + drop_rate = 0.5) if pretrain && !batchnorm loadpretrain!(model, string("vgg", depth)) elseif pretrain diff --git a/src/layers/attention.jl b/src/layers/attention.jl index a1244a033..b6e7b7678 100644 --- a/src/layers/attention.jl +++ b/src/layers/attention.jl @@ -7,18 +7,18 @@ Multi-head self-attention layer. 
   - `nheads`: Number of heads
   - `qkv_layer`: layer to be used for getting the query, key and value
-  - `attn_drop`: dropout rate after the self-attention layer
+  - `attn_drop`: dropout layer used after the self-attention layer
   - `projection`: projection layer to be used after self-attention
 """
 struct MHAttention{P, Q, R}
     nheads::Int
     qkv_layer::P
     attn_drop::Q
     projection::R
 end
 
 """
-    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop = 0., proj_drop = 0.)
+    MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false, attn_drop_rate = 0., proj_drop_rate = 0.)
 
 Multi-head self-attention layer.
 
@@ -27,15 +27,15 @@ Multi-head self-attention layer.
 
   - `planes`: number of input channels
   - `nheads`: number of heads
   - `qkv_bias`: whether to use bias in the layer to get the query, key and value
-  - `attn_drop`: dropout rate after the self-attention layer
-  - `proj_drop`: dropout rate after the projection layer
+  - `attn_drop_rate`: dropout rate after the self-attention layer
+  - `proj_drop_rate`: dropout rate after the projection layer
 """
 function MHAttention(planes::Integer, nheads::Integer = 8; qkv_bias::Bool = false,
-                     attn_drop = 0.0, proj_drop = 0.0)
+                     attn_drop_rate = 0.0, proj_drop_rate = 0.0)
     @assert planes % nheads==0 "planes should be divisible by nheads"
     qkv_layer = Dense(planes, planes * 3; bias = qkv_bias)
-    attn_drop = Dropout(attn_drop)
-    proj = Chain(Dense(planes, planes), Dropout(proj_drop))
+    attn_drop = Dropout(attn_drop_rate)
+    proj = Chain(Dense(planes, planes), Dropout(proj_drop_rate))
     return MHAttention(nheads, qkv_layer, attn_drop, proj)
 end
diff --git a/src/layers/drop.jl b/src/layers/drop.jl
index b3f9a8719..8e6202085 100644
--- a/src/layers/drop.jl
+++ b/src/layers/drop.jl
@@ -54,11 +54,12 @@ end
 """
     DropPath(p)
 
-Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` ≥ 0.
+Implements Stochastic Depth - equivalent to `Dropout(p; dims = 4)` when `p` > 0 and
+`identity` otherwise.
 ([reference](https://arxiv.org/abs/1603.09382))
 
 # Arguments
 
   - `p`: rate of Stochastic Depth.
 """
-DropPath(p) = p ≥ 0 ? Dropout(p; dims = 4) : identity
+DropPath(p) = p > 0 ? Dropout(p; dims = 4) : identity
diff --git a/src/layers/mlp-linear.jl b/src/layers/mlp-linear.jl
index e282e2632..550c2ad22 100644
--- a/src/layers/mlp-linear.jl
+++ b/src/layers/mlp-linear.jl
@@ -15,7 +15,7 @@ end
 
 """
     mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes;
-              dropout = 0., activation = gelu)
+              drop_rate = 0., activation = gelu)
 
 Feedforward block used in many MLPMixer-like and vision-transformer models.
 
@@ -24,18 +24,18 @@ Feedforward block used in many MLPMixer-like and vision-transformer models.
 # Arguments
 
   - `inplanes`: Number of dimensions in the input.
   - `hidden_planes`: Number of dimensions in the intermediate layer.
   - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`.
-  - `dropout`: Dropout rate.
+  - `drop_rate`: Dropout rate.
   - `activation`: Activation function to use.
""" function mlp_block(inplanes::Integer, hidden_planes::Integer, outplanes::Integer = inplanes; - dropout = 0.0, activation = gelu) - return Chain(Dense(inplanes, hidden_planes, activation), Dropout(dropout), - Dense(hidden_planes, outplanes), Dropout(dropout)) + drop_rate = 0.0, activation = gelu) + return Chain(Dense(inplanes, hidden_planes, activation), Dropout(drop_rate), + Dense(hidden_planes, outplanes), Dropout(drop_rate)) end """ gated_mlp(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0., activation = gelu) + outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) Feedforward block based on the implementation in the paper "Pay Attention to MLPs". ([reference](https://arxiv.org/abs/2105.08050)) @@ -46,16 +46,16 @@ Feedforward block based on the implementation in the paper "Pay Attention to MLP - `inplanes`: Number of dimensions in the input. - `hidden_planes`: Number of dimensions in the intermediate layer. - `outplanes`: Number of dimensions in the output - by default it is the same as `inplanes`. - - `dropout`: Dropout rate. + - `drop_rate`: Dropout rate. - `activation`: Activation function to use. """ function gated_mlp_block(gate_layer, inplanes::Integer, hidden_planes::Integer, - outplanes::Integer = inplanes; dropout = 0.0, activation = gelu) + outplanes::Integer = inplanes; drop_rate = 0.0, activation = gelu) @assert hidden_planes % 2==0 "`hidden_planes` must be even for gated MLP" return Chain(Dense(inplanes, hidden_planes, activation), - Dropout(dropout), + Dropout(drop_rate), gate_layer(hidden_planes), Dense(hidden_planes ÷ 2, outplanes), - Dropout(dropout)) + Dropout(drop_rate)) end gated_mlp_block(::typeof(identity), args...; kwargs...) = mlp_block(args...; kwargs...) diff --git a/src/other/mlpmixer.jl b/src/other/mlpmixer.jl index 942abc823..ed4c47af3 100644 --- a/src/other/mlpmixer.jl +++ b/src/other/mlpmixer.jl @@ -1,6 +1,6 @@ """ mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block, - dropout = 0., drop_path_rate = 0., activation = gelu) + drop_rate =0., drop_path_rate = 0., activation = gelu) Creates a feedforward block for the MLPMixer architecture. ([reference](https://arxiv.org/pdf/2105.01601)) @@ -12,20 +12,22 @@ Creates a feedforward block for the MLPMixer architecture. - `mlp_ratio`: number(s) that determine(s) the number of hidden channels in the token mixing MLP and/or the channel mixing MLP as a ratio to the number of planes in the block. 
   - `mlp_layer`: the MLP layer to use
-  - `dropout`: the dropout rate to use in the MLP blocks
+  - `drop_rate`: the dropout rate to use in the MLP blocks
   - `drop_path_rate`: Stochastic depth rate
   - `activation`: the activation function to use in the MLP blocks
 """
 function mixerblock(planes, npatches; mlp_ratio = (0.5, 4.0), mlp_layer = mlp_block,
-                    dropout = 0.0, drop_path_rate = 0.0, activation = gelu)
+                    drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu)
     tokenplanes, channelplanes = [Int(r * planes) for r in mlp_ratio]
     return Chain(SkipConnection(Chain(LayerNorm(planes),
                                       swapdims((2, 1, 3)),
-                                      mlp_layer(npatches, tokenplanes; activation, dropout),
+                                      mlp_layer(npatches, tokenplanes; activation,
+                                                drop_rate),
                                       swapdims((2, 1, 3)),
                                       DropPath(drop_path_rate)), +),
                  SkipConnection(Chain(LayerNorm(planes),
-                                      mlp_layer(planes, channelplanes; activation, dropout),
+                                      mlp_layer(planes, channelplanes; activation,
+                                                drop_rate),
                                       DropPath(drop_path_rate)), +))
 end
@@ -113,7 +115,7 @@ backbone(m::MLPMixer) = m.layers[1]
 classifier(m::MLPMixer) = m.layers[2]
 
 """
-    resmixerblock(planes, npatches; dropout = 0., drop_path_rate = 0., mlp_ratio = 4.0,
+    resmixerblock(planes, npatches; drop_rate = 0., drop_path_rate = 0., mlp_ratio = 4.0,
                   activation = gelu, λ = 1e-4)
 
 Creates a block for the ResMixer architecture.
@@ -126,13 +128,13 @@ Creates a block for the ResMixer architecture.
   - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number
     of planes in the block
   - `mlp_layer`: the MLP block to use
-  - `dropout`: the dropout rate to use in the MLP blocks
+  - `drop_rate`: the dropout rate to use in the MLP blocks
   - `drop_path_rate`: Stochastic depth rate
   - `activation`: the activation function to use in the MLP blocks
   - `λ`: initialisation constant for the LayerScale
 """
 function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block,
-                       dropout = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4)
+                       drop_rate = 0.0, drop_path_rate = 0.0, activation = gelu, λ = 1e-4)
     return Chain(SkipConnection(Chain(Flux.Scale(planes),
                                       swapdims((2, 1, 3)),
                                       Dense(npatches, npatches),
@@ -140,7 +142,7 @@ function resmixerblock(planes, npatches; mlp_ratio = 4.0, mlp_layer = mlp_block,
                                       LayerScale(planes, λ),
                                       DropPath(drop_path_rate)), +),
                  SkipConnection(Chain(Flux.Scale(planes),
-                                      mlp_layer(planes, Int(mlp_ratio * planes); dropout,
+                                      mlp_layer(planes, Int(mlp_ratio * planes); drop_rate,
                                                 activation),
                                       LayerScale(planes, λ),
                                       DropPath(drop_path_rate)), +))
@@ -230,7 +232,7 @@ end
 
 """
     spatial_gating_block(planes, npatches; mlp_ratio = 4.0, mlp_layer = gated_mlp_block,
-                         norm_layer = LayerNorm, dropout = 0.0, drop_path_rate = 0.,
+                         norm_layer = LayerNorm, drop_rate = 0.0, drop_path_rate = 0.0,
                          activation = gelu)
 
 Creates a feedforward block based on the gMLP model architecture described in the paper.
@@ -243,18 +245,19 @@ Creates a feedforward block based on the gMLP model architecture described in th
   - `mlp_ratio`: ratio of the number of hidden channels in the channel mixing MLP to the number
     of planes in the block
   - `norm_layer`: the normalisation layer to use
-  - `dropout`: the dropout rate to use in the MLP blocks
+  - `drop_rate`: the dropout rate to use in the MLP blocks
   - `drop_path_rate`: Stochastic depth rate
   - `activation`: the activation function to use in the MLP blocks
 """
 function spatial_gating_block(planes, npatches; mlp_ratio = 4.0, norm_layer = LayerNorm,
-                              mlp_layer = gated_mlp_block, dropout = 0.0,
+                              mlp_layer = gated_mlp_block, drop_rate = 0.0,
                               drop_path_rate = 0.0, activation = gelu)
     channelplanes = Int(mlp_ratio * planes)
     sgu = inplanes -> SpatialGatingUnit(inplanes, npatches; norm_layer)
     return SkipConnection(Chain(norm_layer(planes),
-                                mlp_layer(sgu, planes, channelplanes; activation, dropout),
+                                mlp_layer(sgu, planes, channelplanes; activation,
+                                          drop_rate),
                                 DropPath(drop_path_rate)), +)
 end
diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl
index dffc93ccf..4b479052a 100644
--- a/src/vit-based/vit.jl
+++ b/src/vit-based/vit.jl
@@ -1,5 +1,5 @@
 """
-transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.)
+transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate = 0.)
 
 Transformer as used in the base ViT architecture.
 ([reference](https://arxiv.org/abs/2010.11929)).
@@ -10,23 +10,24 @@ Transformer as used in the base ViT architecture.
 
   - `planes`: number of input channels
   - `depth`: number of attention blocks
   - `nheads`: number of attention heads
   - `mlp_ratio`: ratio of MLP layers to the number of input channels
-  - `dropout`: dropout rate
+  - `drop_rate`: dropout rate
 """
-function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, dropout = 0.0)
+function transformer_encoder(planes, depth, nheads; mlp_ratio = 4.0, drop_rate = 0.0)
     layers = [Chain(SkipConnection(prenorm(planes,
-                                           MHAttention(planes, nheads; attn_drop = dropout,
-                                                       proj_drop = dropout)), +),
+                                           MHAttention(planes, nheads;
+                                                       attn_drop_rate = drop_rate,
+                                                       proj_drop_rate = drop_rate)), +),
                     SkipConnection(prenorm(planes,
                                            mlp_block(planes, floor(Int, mlp_ratio * planes);
-                                                     dropout)), +))
+                                                     drop_rate)), +))
               for _ in 1:depth]
     return Chain(layers)
 end
 
 """
     vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16),
-        embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1,
-        emb_dropout = 0.1, pool = :class, nclasses = 1000)
+        embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1,
+        emb_drop_rate = 0.1, pool = :class, nclasses = 1000)
 
 Creates a Vision Transformer (ViT) model.
 ([reference](https://arxiv.org/abs/2010.11929)).
@@ -40,22 +41,23 @@ Creates a Vision Transformer (ViT) model.
   - `depth`: number of blocks in the transformer
   - `nheads`: number of attention heads in the transformer
   - `mlpplanes`: number of hidden channels in the MLP block in the transformer
-  - `dropout`: dropout rate
-  - `emb_dropout`: dropout rate for the positional embedding layer
+  - `drop_rate`: dropout rate
+  - `emb_drop_rate`: dropout rate for the positional embedding layer
   - `pool`: pooling type, either :class or :mean
   - `nclasses`: number of classes in the output
 """
 function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = (16, 16),
-             embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, dropout = 0.1,
-             emb_dropout = 0.1, pool = :class, nclasses = 1000)
+             embedplanes = 768, depth = 6, nheads = 16, mlp_ratio = 4.0, drop_rate = 0.1,
+             emb_drop_rate = 0.1, pool = :class, nclasses = 1000)
     @assert pool in [:class, :mean] "Pool type must be either :class (class token) or :mean (mean pooling)"
     npatches = prod(imsize .÷ patch_size)
     return Chain(Chain(PatchEmbedding(imsize; inchannels, patch_size, embedplanes),
                        ClassTokens(embedplanes),
                        ViPosEmbedding(embedplanes, npatches + 1),
-                       Dropout(emb_dropout),
-                       transformer_encoder(embedplanes, depth, nheads; mlp_ratio, dropout),
+                       Dropout(emb_drop_rate),
+                       transformer_encoder(embedplanes, depth, nheads; mlp_ratio,
+                                           drop_rate),
                        (pool == :class) ? x -> x[:, 1, :] : seconddimmean),
                  Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
 end
@@ -97,7 +99,6 @@ function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3,
     @assert mode in keys(vit_configs) "`mode` must be one of $(keys(vit_configs))"
     kwargs = vit_configs[mode]
     layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...)
-
     return ViT(layers)
 end
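Not part of the patch above — a minimal smoke-test sketch a reviewer could run against this branch to exercise the renamed keywords. It only uses constructors defined in this diff (`Metalhead.xception`, `Metalhead.vit`); the specific rates, class counts, and image sizes below are illustrative, not values from the PR.

```julia
using Flux, Metalhead  # assumes the branch with the `drop_rate` renames is checked out

# Functional constructor: the classifier-head dropout keyword is now `drop_rate`.
backbone = Metalhead.xception(; inchannels = 3, drop_rate = 0.2, nclasses = 10)
@assert size(backbone(rand(Float32, 299, 299, 3, 1))) == (10, 1)

# ViT: classifier dropout and positional-embedding dropout are now
# `drop_rate` and `emb_drop_rate`, respectively (small config just for the check).
encoder = Metalhead.vit((224, 224); patch_size = (16, 16), embedplanes = 64,
                        depth = 2, nheads = 4, drop_rate = 0.1,
                        emb_drop_rate = 0.1, nclasses = 10)
@assert size(encoder(rand(Float32, 224, 224, 3, 1))) == (10, 1)

# The wrapper structs forward the same keyword, e.g. Xception(; drop_rate = 0.5).
```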