Skip to content

Commit

Permalink
Cleanup - docs and code
Browse files Browse the repository at this point in the history
Co-authored-by: Kyle Daruwalla <[email protected]>
  • Loading branch information
theabhirath and darsnack authored Jul 29, 2022
1 parent b143b95 commit ccb54da
Showing 1 changed file with 13 additions and 19 deletions.
32 changes: 13 additions & 19 deletions src/convnets/resnets/core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,19 @@
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
Creates a basic ResNet block.
Creates a basic residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
# Arguments
- `inplanes`: number of input feature maps
- `planes`: number of feature maps for the block
- `stride`: the stride of the block
- `reduction_factor`: the reduction factor that the input feature maps are reduced by before the first
convolution.
- `reduction_factor`: the factor by which the input feature maps
are reduced before the first convolution.
- `activation`: the activation function to use.
- `norm_layer`: the normalization layer to use.
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `drop_block`: the drop block layer
- `drop_path`: the drop path layer
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
"""
function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
Expand All @@ -36,7 +34,6 @@ function basicblock(inplanes::Integer, planes::Integer; stride::Integer = 1,
drop_path]
return Chain(filter!(!=(identity), layers)...)
end
expansion_factor(::typeof(basicblock)) = 1

"""
bottleneck(inplanes, planes; stride = 1, cardinality = 1, base_width = 64,
Expand All @@ -45,7 +42,7 @@ expansion_factor(::typeof(basicblock)) = 1
drop_block = identity, drop_path = identity,
attn_fn = planes -> identity)
Creates a bottleneck ResNet block.
Creates a bottleneck residual block (see [reference](https://arxiv.org/abs/1512.03385v1)).
# Arguments
Expand All @@ -58,10 +55,8 @@ Creates a bottleneck ResNet block.
convolution.
- `activation`: the activation function to use.
- `norm_layer`: the normalization layer to use.
- `drop_block`: the drop block layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `drop_path`: the drop path layer. This is usually initialised in the `_make_blocks`
function and passed in.
- `drop_block`: the drop block layer
- `drop_path`: the drop path layer
- `attn_fn`: the attention function to use. See [`squeeze_excite`](#) for an example.
"""
function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
Expand All @@ -83,7 +78,6 @@ function bottleneck(inplanes::Integer, planes::Integer; stride::Integer,
attn_fn(outplanes), drop_path]
return Chain(filter!(!=(identity), layers)...)
end
expansion_factor(::typeof(bottleneck)) = 4

# Downsample layer using convolutions.
function downsample_conv(inplanes::Integer, outplanes::Integer; stride::Integer = 1,
Expand Down Expand Up @@ -124,7 +118,7 @@ const shortcut_dict = Dict(:A => (downsample_identity, downsample_identity),
:D => (downsample_pool, downsample_identity))

# Stride for each block in the ResNet model
function get_stride(block_idx::Integer, stage_idx::Integer)
function resnet_stride(stage_idx::Integer, block_idx::Integer)
return (stage_idx == 1 || block_idx != 1) ? 1 : 2
end

Expand Down Expand Up @@ -159,8 +153,7 @@ on how to use this function.
shows peformance improvements over the `:deep` stem in some cases.
- `inchannels`: The number of channels in the input.
- `replace_pool`: Whether to replace the default 3x3 max pooling layer with a
3x3 convolution with stride 2 and a normalisation layer.
- `replace_pool`: Set to true to replace the max pooling layers with a 3x3 convolution + normalization with a stride of two.
- `norm_layer`: The normalisation layer used in the stem.
- `activation`: The activation function used in the stem.
"""
Expand Down Expand Up @@ -270,7 +263,7 @@ end
function resnet_stages(get_layers, block_repeats::Vector{<:Integer}, connection)
# Construct each stage
stages = []
for (stage_idx, (num_blocks)) in enumerate(block_repeats)
for (stage_idx, num_blocks) in enumerate(block_repeats)
# Construct the blocks for each stage
blocks = [Parallel(connection, get_layers(stage_idx, block_idx)...)
for block_idx in range(1, num_blocks)]
Expand Down Expand Up @@ -307,6 +300,7 @@ function resnet(block_type::Symbol, block_repeats::Vector{<:Integer};
stride_fn = get_stride, planes_fn = resnet_planes,
downsample_tuple = downsample_opt)
else
# TODO: write better message when we have link to dev docs for resnet
throw(ArgumentError("Unknown block type $block_type"))
end
classifier_fn = nfeatures -> create_classifier(nfeatures, nclasses; dropout_rate,
Expand All @@ -318,7 +312,7 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
return resnet(block_fn, block_repeats, shortcut_dict[downsample_opt]; kwargs...)
end

function resnet(img_dims, stem, connection, get_layers, block_repeats::Vector{<:Integer},
function resnet(img_dims, stem, get_layers, block_repeats::Vector{<:Integer}, connection,
classifier_fn)
# Build stages of the ResNet
stage_blocks = resnet_stages(get_layers, block_repeats, connection)
Expand Down

0 comments on commit ccb54da

Please sign in to comment.