Skip to content


Fifth refactor is a charm
Browse files Browse the repository at this point in the history
Also, we aren't using the skips anymore
  • Loading branch information
theabhirath committed Jul 29, 2022
1 parent aa2a9ef commit b143b95
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 146 deletions.
212 changes: 113 additions & 99 deletions src/convnets/resnets/core.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
basicblock(inplanes, planes; stride = 1, downsample = identity,
reduction_factor = 1, dilation = 1, first_dilation = dilation,
activation = relu, connection = addact\$activation,
norm_layer = BatchNorm, drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
norm_layer = BatchNorm, prenorm = false,
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
Creates a basic ResNet block.
Expand All @@ -12,24 +11,19 @@ Creates a basic ResNet block.
- `inplanes`: number of input feature maps
- `planes`: number of feature maps for the block
- `stride`: the stride of the block
- `downsample`: the downsampling function to use
- `reduction_factor`: the factor by which the input feature maps are reduced before the first convolution.
- `dilation`: the dilation of the second convolution.
- `first_dilation`: the dilation of the first convolution.
- `activation`: the activation function to use.
- `connection`: the function applied to the output of residual and skip paths in
a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
- `norm_layer`: the normalization layer to use.
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
function basicblock(inplanes, planes; stride = 1, reduction_factor = 1, activation = relu,
norm_layer = BatchNorm, prenorm = false,
function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
reduction_factor::Integer = 1, activation = relu,
norm_layer = BatchNorm, prenorm::Bool = false,
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
first_planes = planes ÷ reduction_factor
Expand All @@ -45,11 +39,11 @@ end
# Expansion factor associated with `basicblock`; this method always returns 1.
function expansion_factor(::typeof(basicblock))
    return 1
end

bottleneck(inplanes, planes; stride = 1, downsample = identity, cardinality = 1,
base_width = 64, reduction_factor = 1, first_dilation = 1,
activation = relu, connection = addact\$activation,
norm_layer = BatchNorm, drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
reduction_factor = 1, activation = relu,
norm_layer = BatchNorm, prenorm = false,
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
Creates a bottleneck ResNet block.
Expand All @@ -58,26 +52,22 @@ Creates a bottleneck ResNet block.
- `inplanes`: number of input feature maps
- `planes`: number of feature maps for the block
- `stride`: the stride of the block
- `downsample`: the downsampling function to use
- `cardinality`: the number of groups in the convolution.
- `base_width`: the number of output feature maps for each convolutional group.
- `reduction_factor`: the factor by which the input feature maps are reduced before the first convolution.
- `first_dilation`: the dilation of the 3x3 convolution.
- `activation`: the activation function to use.
- `connection`: the function applied to the output of residual and skip paths in
a block. See [`addact`](#) and [`actadd`](#) for an example. Note that this uses
PartialFunctions.jl to pass in the activation function with the notation `addact\$activation`.
- `norm_layer`: the normalization layer to use.
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
function bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
reduction_factor = 1, activation = relu,
norm_layer = BatchNorm, prenorm = false,
function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
cardinality::Integer = 1, base_width::Integer = 64,
reduction_factor::Integer = 1, activation = relu,
norm_layer = BatchNorm, prenorm::Bool = false,
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
width = floor(Int, planes * (base_width / 64)) * cardinality
Expand Down Expand Up @@ -113,6 +103,7 @@ end

# Downsample layer which is an identity projection. Uses max pooling
# when the output size is more than the input size.
# TODO - figure out how to make this work when outplanes < inplanes
function downsample_identity(inplanes::Integer, outplanes::Integer; kwargs...)
if outplanes > inplanes
return Chain(MaxPool((1, 1); stride = 2),
Expand Down Expand Up @@ -174,8 +165,8 @@ on how to use this function.
- `activation`: The activation function used in the stem.
function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
replace_pool::Bool = false, norm_layer = BatchNorm, prenorm = false,
activation = relu)
replace_pool::Bool = false, activation = relu,
norm_layer = BatchNorm, prenorm::Bool = false)
@assert stem_type in [:default, :deep, :deep_tiered]
"Stem type must be one of [:default, :deep, :deep_tiered]"
# Main stem
Expand Down Expand Up @@ -203,65 +194,70 @@ function resnet_stem(stem_type::Symbol = :default; inchannels::Integer = 3,
stride = 2, pad = 1, bias = false)...) :
MaxPool((3, 3); stride = 2, pad = 1)
return Chain(conv1, bn1, stempool), inplanes

# Templating builders for the blocks and the downsampling layers
function template_builder(block_fn; kwargs...)
function (inplanes, planes; _kwargs...)
return block_fn(inplanes, planes; kwargs..., _kwargs...)

function template_builder(::typeof(basicblock); reduction_factor::Integer = 1,
activation = relu, norm_layer = BatchNorm, prenorm::Bool = false,
attn_fn = planes -> identity, kargs...)
return (args...; kwargs...) -> basicblock(args...; kwargs..., reduction_factor,
activation, norm_layer, prenorm, attn_fn)
return Chain(conv1, bn1, stempool)

function template_builder(::typeof(bottleneck); cardinality::Integer = 1,
base_width::Integer = 64,
reduction_factor::Integer = 1, activation = relu,
norm_layer = BatchNorm, prenorm::Bool = false,
attn_fn = planes -> identity, kargs...)
return (args...; kwargs...) -> bottleneck(args...; kwargs..., cardinality, base_width,
reduction_factor, activation,
norm_layer, prenorm, attn_fn)
# Number of feature maps (`planes`) for the stage at 1-based index `stage_idx`:
# 64 for the first stage, doubling with every subsequent stage.
function resnet_planes(stage_idx::Integer)
    base_planes = 64
    doublings = stage_idx - 1
    return base_planes * 2^doublings
end

function template_builder(downsample_fn::Union{typeof(downsample_conv),
norm_layer = BatchNorm, prenorm = false)
return (args...; kwargs...) -> downsample_fn(args...; kwargs..., norm_layer, prenorm)
function basicblock_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64,
reduction_factor::Integer = 1, expansion::Integer = 1,
norm_layer = BatchNorm, prenorm::Bool = false,
activation = relu, attn_fn = planes -> identity,
drop_block_rate = 0.0, drop_path_rate = 0.0,
stride_fn = get_stride, planes_fn = resnet_planes,
downsample_tuple = (downsample_conv, downsample_identity))
pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
# closure over `inplanes`, which is reassigned after every block is built
function get_layers(stage_idx::Integer, block_idx::Integer)
planes = planes_fn(stage_idx)
# `get_stride` is a callback that the user can tweak to change the stride of the
# blocks. It defaults to the standard behaviour as in the paper
stride = stride_fn(stage_idx, block_idx)
downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
downsample_tuple[1] : downsample_tuple[2]
# DropBlock, DropPath both take in rates based on a linear scaling schedule
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
drop_path = DropPath(pathschedule[schedule_idx])
drop_block = DropBlock(blockschedule[schedule_idx])
block = basicblock(inplanes, planes; stride, reduction_factor, activation,
norm_layer, prenorm, attn_fn, drop_path, drop_block)
downsample = downsample_fn(inplanes, planes * expansion; stride)
# inplanes increases by expansion after each block
inplanes = planes * expansion
return block, downsample
return get_layers

# Number of feature maps (`planes`) for the stage at 1-based index `stage_idx`:
# 64 for the first stage, doubling with every subsequent stage.
function resnet_planes(stage_idx::Integer)
    base_planes = 64
    doublings = stage_idx - 1
    return base_planes * 2^doublings
end

function configure_resnet_block(block_template, expansion, block_repeats::Vector{<:Integer};
stride_fn = get_stride, plane_fn = resnet_planes,
downsample_templates::NTuple{2, Any},
inplanes::Integer = 64,
drop_path_rate = 0.0, drop_block_rate = 0.0, kwargs...)
function bottleneck_builder(block_repeats::Vector{<:Integer}; inplanes::Integer = 64,
cardinality::Integer = 1, base_width::Integer = 64,
reduction_factor::Integer = 1, expansion::Integer = 4,
norm_layer = BatchNorm, prenorm::Bool = false,
activation = relu, attn_fn = planes -> identity,
drop_block_rate = 0.0, drop_path_rate = 0.0,
stride_fn = get_stride, planes_fn = resnet_planes,
downsample_tuple = (downsample_conv, downsample_identity))
pathschedule = linear_scheduler(drop_path_rate; depth = sum(block_repeats))
blockschedule = linear_scheduler(drop_block_rate; depth = sum(block_repeats))
# closure over `inplanes`, which is reassigned after every block is built
function get_layers(stage_idx::Integer, block_idx::Integer)
planes = plane_fn(stage_idx)
planes = planes_fn(stage_idx)
# `get_stride` is a callback that the user can tweak to change the stride of the
# blocks. It defaults to the standard behaviour as in the paper
stride = stride_fn(stage_idx, block_idx)
downsample_template = (stride != 1 || inplanes != planes * expansion) ?
downsample_templates[1] : downsample_templates[2]
downsample_fn = (stride != 1 || inplanes != planes * expansion) ?
downsample_tuple[1] : downsample_tuple[2]
# DropBlock, DropPath both take in rates based on a linear scaling schedule
schedule_idx = sum(block_repeats[1:(stage_idx - 1)]) + block_idx
drop_path = DropPath(pathschedule[schedule_idx])
drop_block = DropBlock(blockschedule[schedule_idx])
block = block_template(inplanes, planes; stride, drop_path, drop_block)
downsample = downsample_template(inplanes, planes * expansion; stride)
block = bottleneck(inplanes, planes; stride, cardinality, base_width,
reduction_factor, activation, norm_layer, prenorm,
attn_fn, drop_path, drop_block)
downsample = downsample_fn(inplanes, planes * expansion; stride)
# inplanes increases by expansion after each block
inplanes = (planes * expansion)
inplanes = planes * expansion
return block, downsample
return get_layers
Expand All @@ -283,41 +279,59 @@ function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
return Chain(stages...)

function resnet(connection, get_layers, block_repeats::Vector{<:Integer}, stem, classifier)
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
return Chain(Chain(stem, stage_blocks), classifier)
function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity),
cardinality::Integer = 1, base_width::Integer = 64, inplanes::Integer = 64,
reduction_factor::Integer = 1, imsize::Dims{2} = (256, 256),
inchannels::Integer = 3, stem_fn = resnet_stem,
connection = addact, activation = relu, norm_layer = BatchNorm,
prenorm::Bool = false, attn_fn = planes -> identity,
pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
drop_block_rate = 0.0, drop_path_rate = 0.0, dropout_rate = 0.0,
nclasses::Integer = 1000)
# Build stem
stem = stem_fn(; inchannels)
# Block builder
if block_type == :basicblock
@assert cardinality==1 "Cardinality must be 1 for `basicblock`"
@assert base_width==64 "Base width must be 64 for `basicblock`"
get_layers = basicblock_builder(block_repeats; inplanes, reduction_factor,
activation, norm_layer, prenorm, attn_fn,
drop_block_rate, drop_path_rate,
stride_fn = get_stride, planes_fn = resnet_planes,
downsample_tuple = downsample_opt)
elseif block_type == :bottleneck
get_layers = bottleneck_builder(block_repeats; inplanes, cardinality, base_width,
reduction_factor, activation, norm_layer,
prenorm, attn_fn, drop_block_rate, drop_path_rate,
stride_fn = get_stride, planes_fn = resnet_planes,
downsample_tuple = downsample_opt)
throw(ArgumentError("Unknown block type $block_type"))
classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
pool_layer, use_conv)
return resnet((imsize..., inchannels), stem, connection$activation, get_layers,
block_repeats, classifier_fn)
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)

function resnet(block_fn, block_repeats::Vector{<:Integer},
downsample_opt::NTuple{2, Any} = (downsample_conv, downsample_identity);
imsize::Dims{2} = (256, 256), inchannels::Integer = 3,
stem = first(resnet_stem(; inchannels)), inplanes::Integer = 64,
connection = addact, activation = relu,
pool_layer = AdaptiveMeanPool((1, 1)), use_conv::Bool = false,
dropout_rate = 0.0, nclasses::Integer = 1000, kwargs...)
# Configure downsample templates
downsample_templates = map(template_builder, downsample_opt)
# Configure block templates
block_template = template_builder(block_fn; kwargs...)
get_layers = configure_resnet_block(block_template, expansion_factor(block_fn),
block_repeats; inplanes, downsample_templates,
function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer},
# Build stages of the ResNet
stage_blocks = resnet_stages(get_layers, block_repeats, connection$activation)
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
backbone = Chain(stem, stage_blocks)
# Build the classifier head
nfeaturemaps = Flux.outputsize(backbone, (imsize..., inchannels); padbatch = true)[3]
classifier = create_classifier(nfeaturemaps, nclasses; dropout_rate, pool_layer,
nfeaturemaps = Flux.outputsize(backbone, img_dims; padbatch = true)[3]
classifier = classifier_fn(nfeaturemaps)
return Chain(backbone, classifier)
function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt], kwargs...)

# block-layer configurations for ResNet-like models
# Each entry maps a model size (presumably the ResNet depth) to a tuple of
# (block constructor, number of blocks in each stage).
const resnet_configs = Dict(
    18 => (basicblock, [2, 2, 2, 2]),
    34 => (basicblock, [3, 4, 6, 3]),
    50 => (bottleneck, [3, 4, 6, 3]),
    101 => (bottleneck, [3, 4, 23, 3]),
    152 => (bottleneck, [3, 8, 36, 3]),
)
# Each entry maps a model size (presumably the ResNet depth) to a tuple of
# (block type symbol, number of blocks in each stage).
const resnet_configs = Dict(
    18 => (:basicblock, [2, 2, 2, 2]),
    34 => (:basicblock, [3, 4, 6, 3]),
    50 => (:bottleneck, [3, 4, 6, 3]),
    101 => (:bottleneck, [3, 4, 23, 3]),
    152 => (:bottleneck, [3, 8, 36, 3]),
)
2 changes: 1 addition & 1 deletion src/layers/Layers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ include("normalise.jl")
export prenorm, ChannelLayerNorm

export conv_norm, depthwise_sep_conv_bn, invertedresidual, skip_identity, skip_projection
export conv_norm, depthwise_sep_conv_bn, invertedresidual

export DropBlock, DropPath, droppath_rates
Expand Down

0 comments on commit b143b95

Please sign in to comment.