Cleanup - docs and code
Co-Authored-By: Kyle Daruwalla <[email protected]>
theabhirath and darsnack committed Jul 29, 2022
1 parent b143b95 commit 2aa3459
Showing 9 changed files with 87 additions and 90 deletions.
6 changes: 3 additions & 3 deletions src/convnets/densenet.jl
@@ -13,9 +13,9 @@ Create a Densenet bottleneck layer
function dense_bottleneck(inplanes, outplanes)
inner_channels = 4 * outplanes
return SkipConnection(Chain(conv_norm((1, 1), inplanes, inner_channels; bias = false,
- prenorm = true)...,
+ revnorm = true)...,
conv_norm((3, 3), inner_channels, outplanes; pad = 1,
- bias = false, prenorm = true)...),
+ bias = false, revnorm = true)...),
cat_channels)
end

@@ -31,7 +31,7 @@ Create a DenseNet transition sequence
- `outplanes`: number of output feature maps
"""
function transition(inplanes, outplanes)
- return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, prenorm = true)...,
+ return Chain(conv_norm((1, 1), inplanes, outplanes; bias = false, revnorm = true)...,
MeanPool((2, 2)))
end

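For reference, a minimal usage sketch of the bottleneck above (illustrative only, not part of this commit; it assumes the internal helper is reachable as `Metalhead.dense_bottleneck`). Because the `SkipConnection` fuses with `cat_channels`, the block concatenates its input with the new feature maps, growing the channel count from `inplanes` to `inplanes + outplanes`:

```julia
# Illustrative sketch: a DenseNet bottleneck grows channels by concatenation.
using Flux, Metalhead

layer = Metalhead.dense_bottleneck(64, 32)   # internal helper, assumed reachable as shown
x = rand(Float32, 28, 28, 64, 1)             # WHCN input with 64 channels
size(layer(x))                               # expected (28, 28, 96, 1): 64 input + 32 new channels
```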
6 changes: 3 additions & 3 deletions src/convnets/inception/xception.jl
@@ -34,7 +34,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
end
push!(layers, relu)
append!(layers,
- depthwise_sep_conv_bn((3, 3), inc, outc; pad = 1, bias = false,
+ depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false,
use_bn = (false, false)))
push!(layers, BatchNorm(outc))
end
@@ -63,8 +63,8 @@ function xception(; inchannels = 3, dropout_rate = 0.0, nclasses = 1000)
xception_block(256, 728, 2; stride = 2),
[xception_block(728, 728, 3) for _ in 1:8]...,
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
- depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
- depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
+ depthwise_sep_conv_norm((3, 3), 1024, 1536; pad = 1)...,
+ depthwise_sep_conv_norm((3, 3), 1536, 2048; pad = 1)...)
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout_rate),
Dense(2048, nclasses))
return Chain(body, head)
2 changes: 1 addition & 1 deletion src/convnets/mobilenet/mobilenetv1.jl
@@ -30,7 +30,7 @@ function mobilenetv1(width_mult, config;
outch = Int(outch * width_mult)
for _ in 1:nrepeats
layer = dw ?
- depthwise_sep_conv_bn((3, 3), inchannels, outch, activation;
+ depthwise_sep_conv_norm((3, 3), inchannels, outch, activation;
stride = stride, pad = 1, bias = false) :
conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1,
bias = false)
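As a side note on the snippet above, `Int(outch * width_mult)` assumes the scaled channel count is a whole number. A quick sketch with assumed values (not from this commit):

```julia
# Illustrative sketch: scaling MobileNetv1 channel widths with a multiplier.
width_mult = 0.75
for outch in (64, 128, 256, 512, 1024)   # assumed stock output widths
    println(outch, " => ", Int(outch * width_mult))
end
# 64 => 48, 128 => 96, 256 => 192, 512 => 384, 1024 => 768
# Int(outch * width_mult) raises an InexactError for non-integral products, e.g. Int(64 * 0.3).
```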
112 changes: 53 additions & 59 deletions src/convnets/resnets/core.jl

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions src/convnets/resnets/seresnet.jl
@@ -27,7 +27,7 @@ end
function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
_checkconfig(depth, keys(resnet_configs))
layers = resnet(resnet_configs[depth]...; inchannels, nclasses,
- attn_fn = planes -> squeeze_excite(planes))
+ attn_fn = squeeze_excite)
if pretrain
loadpretrain!(layers, string("SEResNet", depth))
end
@@ -70,7 +70,7 @@ function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_widt
inchannels = 3, nclasses = 1000)
_checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width,
- attn_fn = planes -> squeeze_excite(planes))
+ attn_fn = squeeze_excite)
if pretrain
loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))
end
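The `attn_fn` change in both constructors is a pure simplification: `planes -> squeeze_excite(planes)` and `squeeze_excite` are interchangeable as callables. A small illustration with a stand-in function (not Metalhead's actual layer):

```julia
# Illustrative only: wrapping a one-argument function in a lambda changes nothing.
se_stub(planes) = "SE block on $planes planes"   # stand-in for squeeze_excite

old_style = planes -> se_stub(planes)   # anonymous wrapper, as before this commit
new_style = se_stub                     # pass the function directly, as after

old_style(256) == new_style(256)   # true
```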
2 changes: 1 addition & 1 deletion src/layers/Layers.jl
@@ -26,7 +26,7 @@ include("normalise.jl")
export prenorm, ChannelLayerNorm

include("conv.jl")
- export conv_norm, depthwise_sep_conv_bn, invertedresidual
+ export conv_norm, depthwise_sep_conv_norm, invertedresidual

include("drop.jl")
export DropBlock, DropPath, droppath_rates
36 changes: 18 additions & 18 deletions src/layers/conv.jl
@@ -1,6 +1,6 @@
"""
conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu;
- norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true,
+ norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true,
stride = 1, pad = 0, dilation = 1, groups = 1, [bias, weight, init])
Create a convolution + batch normalization pair with activation.
@@ -12,44 +12,44 @@ Create a convolution + batch normalization pair with activation.
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- `norm_layer`: the normalization layer used
- - `prenorm`: set to `true` to place the batch norm before the convolution
+ - `revnorm`: set to `true` to place the batch norm before the convolution
- `preact`: set to `true` to place the activation function before the batch norm
- (only compatible with `prenorm = false`)
+ (only compatible with `revnorm = false`)
- `use_bn`: set to `false` to disable batch normalization
- (only compatible with `prenorm = false` and `preact = false`)
+ (only compatible with `revnorm = false` and `preact = false`)
- `stride`: stride of the convolution kernel
- `pad`: padding of the convolution kernel
- `dilation`: dilation of the convolution kernel
- `groups`: groups for the convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
"""
function conv_norm(kernel_size, inplanes::Int, outplanes::Int, activation = relu;
- norm_layer = BatchNorm, prenorm = false, preact = false, use_bn = true,
+ norm_layer = BatchNorm, revnorm = false, preact = false, use_bn = true,
kwargs...)
if !use_bn
- if (preact || prenorm)
+ if (preact || revnorm)
throw(ArgumentError("`preact` only supported with `use_bn = true`"))
else
return [Conv(kernel_size, inplanes => outplanes, activation; kwargs...)]
end
end
- if prenorm
+ if revnorm
activations = (conv = activation, bn = identity)
bnplanes = inplanes
else
activations = (conv = identity, bn = activation)
bnplanes = outplanes
end
if preact
- if prenorm
- throw(ArgumentError("`preact` and `prenorm` cannot be set at the same time"))
+ if revnorm
+ throw(ArgumentError("`preact` and `revnorm` cannot be set at the same time"))
else
activations = (conv = activation, bn = identity)
end
end
layers = [Conv(kernel_size, inplanes => outplanes, activations.conv; kwargs...),
norm_layer(bnplanes, activations.bn)]
- return prenorm ? reverse(layers) : layers
+ return revnorm ? reverse(layers) : layers
end

function conv_norm(kernel_size, ch::Pair{<:Integer, <:Integer}, outplanes,
@@ -60,7 +60,7 @@ end

"""
depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu;
- prenorm = false, use_bn = (true, true),
+ revnorm = false, use_bn = (true, true),
stride = 1, pad = 0, dilation = 1, [bias, weight, init])
Create a depthwise separable convolution chain as used in MobileNetv1.
@@ -79,20 +79,20 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
- `inplanes`: number of input feature maps
- `outplanes`: number of output feature maps
- `activation`: the activation function for the final layer
- - `prenorm`: set to `true` to place the batch norm before the convolution
- - `use_bn`: a tuple of two booleans to specify whether to use batch normalization for the first and second convolution
+ - `revnorm`: set to `true` to place the batch norm before the convolution
+ - `use_bn`: a tuple of two booleans to specify whether to use normalization for the first and second convolution
- `stride`: stride of the first convolution kernel
- `pad`: padding of the first convolution kernel
- `dilation`: dilation of the first convolution kernel
- `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
"""
- function depthwise_sep_conv_bn(kernel_size, inplanes, outplanes, activation = relu;
- prenorm = false, use_bn = (true, true),
- stride = 1, kwargs...)
+ function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
+ norm_layer = BatchNorm, revnorm = false, use_norm = (true, true),
+ stride = 1, kwargs...)
return vcat(conv_norm(kernel_size, inplanes, inplanes, activation;
- prenorm, use_bn = use_bn[1], stride, groups = inplanes,
+ norm_layerm, revnorm, use_bn = use_bn[1], stride, groups = inplanes,
kwargs...),
- conv_norm((1, 1), inplanes, outplanes, activation; prenorm,
+ conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm,
use_bn = use_bn[2]))
end

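Below is a hedged usage sketch of the two renamed helpers (illustrative sizes; it assumes the keyword plumbing behaves as the docstrings describe — the added lines above still reference `norm_layerm` and `use_bn`, which would have to resolve to `norm_layer` and `use_norm` at runtime):

```julia
# Illustrative sketch of conv_norm / depthwise_sep_conv_norm as documented above.
using Flux
using Metalhead.Layers: conv_norm, depthwise_sep_conv_norm   # assumes the Layers submodule layout

# revnorm = true places the normalization layer before the convolution:
rev = Chain(conv_norm((3, 3), 16, 32, relu; pad = 1, bias = false, revnorm = true)...)
# rev[1] is the BatchNorm (over the 16 input planes), rev[2] is the Conv

# Depthwise separable convolution: a grouped 3x3 conv followed by a 1x1 pointwise conv.
dwsep = Chain(depthwise_sep_conv_norm((3, 3), 16, 32, relu; pad = 1, bias = false)...)
size(dwsep(rand(Float32, 32, 32, 16, 1)))   # expected (32, 32, 32, 1)
```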
3 changes: 2 additions & 1 deletion src/layers/pool.jl
@@ -10,6 +10,7 @@ produce a single output. Note that this is equivalent to
- `output_size`: The size of the output after pooling.
- `connection`: The connection type to use.
"""
- function AdaptiveMeanMaxPool(output_size = (1, 1); connection = +)
+ function AdaptiveMeanMaxPool(connection, output_size = (1, 1))
return Parallel(connection, AdaptiveMeanPool(output_size), AdaptiveMaxPool(output_size))
end
+ AdaptiveMeanMaxPool(output_size::Tuple = (1, 1)) = AdaptiveMeanMaxPool(+, output_size)
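A short sketch of the reworked constructor (illustrative sizes, assuming the layer is reachable as shown): the connection is now the first positional argument, and the tuple-only convenience method keeps `+` as the default fusion.

```julia
# Illustrative sketch of the updated AdaptiveMeanMaxPool API.
using Flux
using Metalhead.Layers: AdaptiveMeanMaxPool   # assumed location of the layer

x = rand(Float32, 7, 7, 512, 1)

fused = AdaptiveMeanMaxPool((1, 1))                                   # defaults to connection = +
size(fused(x))                                                        # expected (1, 1, 512, 1)

concat = AdaptiveMeanMaxPool((a, b) -> cat(a, b; dims = 3), (1, 1))   # any two-argument connection
size(concat(x))                                                       # expected (1, 1, 1024, 1)
```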
6 changes: 4 additions & 2 deletions test/convnets.jl
@@ -23,13 +23,15 @@ end
@testset "ResNet" begin
# Tests for pretrained ResNets
## TODO: find a way to port pretrained models to the new ResNet API
- # @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
+ @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
+ m = ResNet(sz)
+ @test size(m(x_224)) == (1000, 1)
# if (ResNet, sz) in PRETRAINED_MODELS
# @test acctest(ResNet(sz, pretrain = true))
# else
# @test_throws ArgumentError ResNet(sz, pretrain = true)
# end
- # end
+ end

@testset "resnet" begin
@testset for block_fn in [:basicblock, :bottleneck]
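For context, `x_224` in the re-enabled testset is presumably a dummy 224×224 RGB batch defined elsewhere in the test suite, along the lines of this assumed sketch (not shown in the diff):

```julia
# Assumed helper for the testset above: one 224x224 RGB image in WHCN layout.
using Flux, Metalhead

x_224 = rand(Float32, 224, 224, 3, 1)
m = ResNet(18)
size(m(x_224)) == (1000, 1)   # the property the re-enabled test checks
```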
