
Commit

Make all config dicts const and capitalise
Also misc. formatting
theabhirath committed Jul 29, 2022
1 parent 2aa3459 commit da5321d
Showing 20 changed files with 78 additions and 76 deletions.
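A note on the motivation, since the diff itself only shows renames: in Julia, a module-level global that is not `const` is type-unstable, so every function that reads it pays a dynamic lookup. Marking the config tables `const` (and capitalising them, per the usual convention for constants) lets the compiler assume their concrete type. A minimal sketch of the pattern, with illustrative names that are not from Metalhead:

```julia
module Configs

# Non-const global: the binding could later be reassigned to a value of
# any type, so code reading it cannot be fully type-inferred.
configs = Dict(:small => 8, :base => 12)

# `const` + SCREAMING_SNAKE_CASE, as this commit adopts: the compiler
# may rely on the Dict's concrete type at every use site.
const CONFIGS = Dict(:small => 8, :base => 12)

depth(mode::Symbol) = CONFIGS[mode]

end # module

@assert Configs.depth(:base) == 12
```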
28 changes: 14 additions & 14 deletions src/convnets/convmixer.jl
@@ -28,15 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9),
     return Chain(Chain(stem..., Chain(blocks)), head)
 end
 
-convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20,
-                                       :kernel_size => (9, 9),
-                                       :patch_size => (7, 7)),
-                         :small => Dict(:planes => 768, :depth => 32,
-                                        :kernel_size => (7, 7),
-                                        :patch_size => (7, 7)),
-                         :large => Dict(:planes => 1024, :depth => 20,
-                                        :kernel_size => (9, 9),
-                                        :patch_size => (7, 7)))
+const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20,
+                                             :kernel_size => (9, 9),
+                                             :patch_size => (7, 7)),
+                               :small => Dict(:planes => 768, :depth => 32,
+                                              :kernel_size => (7, 7),
+                                              :patch_size => (7, 7)),
+                               :large => Dict(:planes => 1024, :depth => 20,
+                                              :kernel_size => (9, 9),
+                                              :patch_size => (7, 7)))
 
 """
     ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
@@ -57,11 +57,11 @@ end
 @functor ConvMixer
 
 function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000)
-    _checkconfig(mode, keys(convmixer_configs))
-    planes = convmixer_configs[mode][:planes]
-    depth = convmixer_configs[mode][:depth]
-    kernel_size = convmixer_configs[mode][:kernel_size]
-    patch_size = convmixer_configs[mode][:patch_size]
+    _checkconfig(mode, keys(CONVMIXER_CONFIGS))
+    planes = CONVMIXER_CONFIGS[mode][:planes]
+    depth = CONVMIXER_CONFIGS[mode][:depth]
+    kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size]
+    patch_size = CONVMIXER_CONFIGS[mode][:patch_size]
     layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation,
                        nclasses)
     return ConvMixer(layers)
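A usage sketch of the lookup pattern above, assuming Metalhead is loaded: `_checkconfig` validates the mode symbol against `keys(CONVMIXER_CONFIGS)` before any indexing happens.

```julia
using Metalhead

# Valid mode: planes, depth, kernel_size and patch_size are read from
# CONVMIXER_CONFIGS[:small].
m = ConvMixer(:small; inchannels = 3, nclasses = 10)

# Unknown mode: _checkconfig throws an informative error (exact wording
# not shown in this diff) rather than a bare KeyError.
# ConvMixer(:tiny)
```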
14 changes: 7 additions & 7 deletions src/convnets/convnext.jl
@@ -66,11 +66,11 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0
 end
 
 # Configurations for ConvNeXt models
-convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
-                        :small => ([3, 3, 27, 3], [96, 192, 384, 768]),
-                        :base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
-                        :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
-                        :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
+const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]),
+                              :small => ([3, 3, 27, 3], [96, 192, 384, 768]),
+                              :base => ([3, 3, 27, 3], [128, 256, 512, 1024]),
+                              :large => ([3, 3, 27, 3], [192, 384, 768, 1536]),
+                              :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048]))
 
 struct ConvNeXt
     layers::Any
@@ -94,8 +94,8 @@ See also [`Metalhead.convnext`](#).
"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
_checkconfig(mode, keys(convnext_configs))
layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses)
_checkconfig(mode, keys(CONVNEXT_CONFIGS))
layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses)
return ConvNeXt(layers)
end

6 changes: 3 additions & 3 deletions src/convnets/densenet.jl
@@ -140,7 +140,7 @@ end
 backbone(m::DenseNet) = m.layers[1]
 classifier(m::DenseNet) = m.layers[2]
 
-const densenet_configs = Dict(121 => (6, 12, 24, 16),
+const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16),
                               161 => (6, 12, 36, 24),
                               169 => (6, 12, 32, 32),
                               201 => (6, 12, 48, 32))
@@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet.
 See also [`Metalhead.densenet`](#).
 """
 function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
-    _checkconfig(config, keys(densenet_configs))
-    model = DenseNet(densenet_configs[config]; nclasses = nclasses)
+    _checkconfig(config, keys(DENSENET_CONFIGS))
+    model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses)
     if pretrain
         loadpretrain!(model, string("DenseNet", config))
     end
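A quick check on why the key is `121`: each tuple entry is the number of layers in one dense block, every dense layer contributes two convolutions (a 1×1 bottleneck plus a 3×3), and the stem convolution, three transition convolutions and the classifier add five more. This decomposition is standard DenseNet accounting, not something stated in the diff:

```julia
# 58 dense layers × 2 convs, + stem conv + 3 transition convs + classifier
sum((6, 12, 24, 16)) * 2 + 5  # == 121
```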
8 changes: 4 additions & 4 deletions src/convnets/efficientnet.jl
@@ -59,7 +59,7 @@ end
 # e: expantion ratio
 # i: block input channels
 # o: block output channels
-const efficientnet_block_configs = [
+const EFFICIENTNET_BLOCK_CONFIGS = [
     # (n, k, s, e, i, o)
     (1, 3, 1, 1, 32, 16),
     (2, 3, 2, 6, 16, 24),
@@ -73,7 +73,7 @@ const efficientnet_block_configs = [
 # w: width scaling
 # d: depth scaling
 # r: image resolution
-const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)),
+const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)),
                                          :b1 => (240, (1.0, 1.1)),
                                          :b2 => (260, (1.1, 1.2)),
                                          :b3 => (300, (1.2, 1.4)),
@@ -137,8 +137,8 @@ See also [`efficientnet`](#).
   - `pretrain`: set to `true` to load the pre-trained weights for ImageNet
 """
 function EfficientNet(name::Symbol; pretrain = false)
-    _checkconfig(name, keys(efficientnet_global_configs))
-    model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
+    _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS))
+    model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS)
     pretrain && loadpretrain!(model, string("efficientnet-", name))
     return model
 end
2 changes: 1 addition & 1 deletion src/convnets/inception/xception.jl
@@ -35,7 +35,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
         push!(layers, relu)
         append!(layers,
                 depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false,
-                                      use_bn = (false, false)))
+                                        use_bn = (false, false)))
         push!(layers, BatchNorm(outc))
     end
     layers = start_with_relu ? layers : layers[2:end]
6 changes: 3 additions & 3 deletions src/convnets/mobilenet/mobilenetv1.jl
@@ -31,7 +31,7 @@ function mobilenetv1(width_mult, config;
         for _ in 1:nrepeats
             layer = dw ?
                     depthwise_sep_conv_norm((3, 3), inchannels, outch, activation;
-                                          stride = stride, pad = 1, bias = false) :
+                                            stride = stride, pad = 1, bias = false) :
                     conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1,
                               bias = false)
             append!(layers, layer)
@@ -45,7 +45,7 @@
                       Dense(inchannels, nclasses)))
 end
 
-const mobilenetv1_configs = [
+const MOBILENETV1_CONFIGS = [
     # dw, c, s, r
     (false, 32, 2, 1),
     (true, 64, 1, 1),
@@ -84,7 +84,7 @@ end

 function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false,
                      nclasses = 1000)
-    layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses)
+    layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("MobileNetv1"))
     end
4 changes: 2 additions & 2 deletions src/convnets/mobilenet/mobilenetv2.jl
@@ -46,7 +46,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla
 end
 
 # Layer configurations for MobileNetv2
-const mobilenetv2_configs = [
+const MOBILENETV2_CONFIGS = [
     # t, c, n, s, a
     (1, 16, 1, 1, relu6),
     (6, 24, 2, 2, relu6),
@@ -83,7 +83,7 @@ See also [`Metalhead.mobilenetv2`](#).
"""
function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false,
nclasses = 1000)
layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses)
layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses)
pretrain && loadpretrain!(layers, string("MobileNetv2"))
if pretrain
loadpretrain!(layers, string("MobileNetv2"))
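Worth noting in the context above: the pretrained weights are loaded twice, once via `pretrain && loadpretrain!(...)` and again inside the `if pretrain` block. That duplication predates this commit; a deduplicated sketch of the presumable intent:

```julia
# Assumed intent (not part of this commit): load the weights once.
layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses)
pretrain && loadpretrain!(layers, string("MobileNetv2"))
```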
4 changes: 2 additions & 2 deletions src/convnets/mobilenet/mobilenetv3.jl
@@ -53,7 +53,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla
 end
 
 # Configurations for small and large mode for MobileNetv3
-mobilenetv3_configs = Dict(:small => [
+MOBILENETV3_CONFIGS = Dict(:small => [
                                # k, t, c, SE, a, s
                                (3, 1, 16, 4, relu, 2),
                                (3, 4.5, 24, nothing, relu, 2),
@@ -115,7 +115,7 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels =
                      pretrain = false, nclasses = 1000)
     @assert mode in [:large, :small] "`mode` has to be either :large or :small"
     max_width = (mode == :large) ? 1280 : 1024
-    layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width,
+    layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width,
                          nclasses)
     if pretrain
         loadpretrain!(layers, string("MobileNetv3", mode))
2 changes: 1 addition & 1 deletion src/convnets/resnets/core.jl
@@ -324,7 +324,7 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...)
 end
 
 # block-layer configurations for ResNet-like models
-const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]),
+const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]),
                             34 => (:basicblock, [3, 4, 6, 3]),
                             50 => (:bottleneck, [3, 4, 6, 3]),
                             101 => (:bottleneck, [3, 4, 23, 3]),
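The values of `RESNET_CONFIGS` are `(block, repeats)` tuples, so the `resnet(RESNET_CONFIGS[depth]...)` calls in the files below splat one entry into the first two positional arguments:

```julia
block, repeats = Metalhead.RESNET_CONFIGS[50]
block    # :bottleneck
repeats  # [3, 4, 6, 3]
# resnet(RESNET_CONFIGS[50]...) is thus resnet(:bottleneck, [3, 4, 6, 3]).
```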
4 changes: 2 additions & 2 deletions src/convnets/resnets/resnet.jl
@@ -23,8 +23,8 @@ end
 @functor ResNet
 
 function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, keys(resnet_configs))
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses)
+    _checkconfig(depth, keys(RESNET_CONFIGS))
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses)
     if pretrain
         loadpretrain!(layers, string("ResNet", depth))
     end
4 changes: 2 additions & 2 deletions src/convnets/resnets/resnext.jl
@@ -29,8 +29,8 @@ end

 function ResNeXt(depth::Integer; pretrain = false, cardinality = 32,
                  base_width = 4, inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width)
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width)
     if pretrain
         loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width))
     end
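The `sort(collect(keys(RESNET_CONFIGS)))[3:end]` guard restricts ResNeXt (and `SEResNeXt` below) to the bottleneck depths by dropping the two smallest keys. Assuming the usual depth set of 18/34/50/101/152 (152 is cut off in the hunk above):

```julia
depths = sort(collect(keys(Metalhead.RESNET_CONFIGS)))
depths[3:end]  # [50, 101, 152], assuming keys 18, 34, 50, 101, 152
```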
8 changes: 4 additions & 4 deletions src/convnets/resnets/seresnet.jl
@@ -25,8 +25,8 @@ end
 (m::SEResNet)(x) = m.layers(x)
 
 function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, keys(resnet_configs))
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses,
+    _checkconfig(depth, keys(RESNET_CONFIGS))
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses,
                     attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("SEResNet", depth))
@@ -68,8 +68,8 @@ end

 function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4,
                    inchannels = 3, nclasses = 1000)
-    _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end])
-    layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width,
+    _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end])
+    layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width,
                     attn_fn = squeeze_excite)
     if pretrain
         loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width))
8 changes: 4 additions & 4 deletions src/convnets/vgg.jl
@@ -99,12 +99,12 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr
     return Chain(Chain(conv), class)
 end
 
-const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
+const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)],
                               :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)],
                               :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)],
                               :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)])
 
-const vgg_configs = Dict(11 => :A,
+const VGG_CONFIGS = Dict(11 => :A,
                          13 => :B,
                          16 => :D,
                          19 => :E)
@@ -153,8 +153,8 @@ See also [`VGG`](#).
   - `pretrain`: set to `true` to load pre-trained model weights for ImageNet
 """
 function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000)
-    _checkconfig(depth, keys(vgg_configs))
-    model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]],
+    _checkconfig(depth, keys(VGG_CONFIGS))
+    model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]],
                 inchannels = 3,
                 batchnorm = batchnorm,
                 nclasses = nclasses,
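`VGG` resolves a depth through two tables: `VGG_CONFIGS` maps the integer depth to the classic configuration letter, and `VGG_CONV_CONFIGS` maps that letter to `(channels, nrepeats)` pairs per stage:

```julia
Metalhead.VGG_CONFIGS[16]                              # :D
Metalhead.VGG_CONV_CONFIGS[Metalhead.VGG_CONFIGS[16]]
# [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)]
```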
6 changes: 4 additions & 2 deletions src/layers/conv.jl
@@ -87,10 +87,12 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1).
   - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#))
 """
 function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
-                                 norm_layer = BatchNorm, revnorm = false, use_norm = (true, true),
+                                 norm_layer = BatchNorm, revnorm = false,
+                                 use_norm = (true, true),
                                  stride = 1, kwargs...)
     return vcat(conv_norm(kernel_size, inplanes, inplanes, activation;
-                          norm_layerm, revnorm, use_bn = use_bn[1], stride, groups = inplanes,
+                          norm_layerm, revnorm, use_bn = use_bn[1], stride,
+                          groups = inplanes,
                           kwargs...),
                 conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm,
                           use_bn = use_bn[2]))
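As committed, the first `conv_norm` call above still references `norm_layerm` (apparently a typo for `norm_layer`) and `use_bn`, which is no longer a binding in scope now that the keyword is named `use_norm`. A corrected sketch of the function, assuming `conv_norm` accepts a `use_norm` keyword:

```julia
# Hedged sketch of the presumable intent; not part of this commit.
function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu;
                                 norm_layer = BatchNorm, revnorm = false,
                                 use_norm = (true, true),
                                 stride = 1, kwargs...)
    return vcat(conv_norm(kernel_size, inplanes, inplanes, activation;
                          norm_layer, revnorm, use_norm = use_norm[1], stride,
                          groups = inplanes, kwargs...),
                conv_norm((1, 1), inplanes, outplanes, activation;
                          norm_layer, revnorm, use_norm = use_norm[2]))
end
```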
8 changes: 4 additions & 4 deletions src/mixers/core.jl
@@ -37,7 +37,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3,
 end
 
 # Configurations for MLPMixer models
-mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512),
-                     :base => Dict(:depth => 12, :planes => 768),
-                     :large => Dict(:depth => 24, :planes => 1024),
-                     :huge => Dict(:depth => 32, :planes => 1280))
+const MIXER_CONFIGS = Dict(:small => Dict(:depth => 8, :planes => 512),
+                           :base => Dict(:depth => 12, :planes => 768),
+                           :large => Dict(:depth => 24, :planes => 1024),
+                           :huge => Dict(:depth => 32, :planes => 1280))
6 changes: 3 additions & 3 deletions src/mixers/gmlp.jl
@@ -96,9 +96,9 @@ See also [`Metalhead.mlpmixer`](#).
"""
function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16),
imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000)
_checkconfig(size, keys(mixer_configs))
depth = mixer_configs[size][:depth]
embedplanes = mixer_configs[size][:planes]
_checkconfig(size, keys(MIXER_CONFIGS))
depth = MIXER_CONFIGS[size][:depth]
embedplanes = MIXER_CONFIGS[size][:planes]
layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block,
patch_size, embedplanes, drop_path_rate, depth, nclasses)
return gMLP(layers)
6 changes: 3 additions & 3 deletions src/mixers/mlpmixer.jl
@@ -55,9 +55,9 @@ See also [`Metalhead.mlpmixer`](#).
"""
function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16),
imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000)
_checkconfig(size, keys(mixer_configs))
depth = mixer_configs[size][:depth]
embedplanes = mixer_configs[size][:planes]
_checkconfig(size, keys(MIXER_CONFIGS))
depth = MIXER_CONFIGS[size][:depth]
embedplanes = MIXER_CONFIGS[size][:planes]
layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate,
nclasses)
return MLPMixer(layers)
6 changes: 3 additions & 3 deletions src/mixers/resmlp.jl
@@ -58,9 +58,9 @@ See also [`Metalhead.mlpmixer`](#).
"""
function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16),
imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000)
_checkconfig(size, keys(mixer_configs))
depth = mixer_configs[size][:depth]
embedplanes = mixer_configs[size][:planes]
_checkconfig(size, keys(MIXER_CONFIGS))
depth = MIXER_CONFIGS[size][:depth]
embedplanes = MIXER_CONFIGS[size][:planes]
layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes,
drop_path_rate, depth, nclasses)
return ResMLP(layers)
22 changes: 11 additions & 11 deletions src/vit-based/vit.jl
@@ -62,15 +62,15 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} =
                  Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast)))
 end
 
-vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
-                   :small => (depth = 12, embedplanes = 384, nheads = 6),
-                   :base => (depth = 12, embedplanes = 768, nheads = 12),
-                   :large => (depth = 24, embedplanes = 1024, nheads = 16),
-                   :huge => (depth = 32, embedplanes = 1280, nheads = 16),
-                   :giant => (depth = 40, embedplanes = 1408, nheads = 16,
-                              mlp_ratio = 48 // 11),
-                   :gigantic => (depth = 48, embedplanes = 1664, nheads = 16,
-                                 mlp_ratio = 64 // 13))
+const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3),
+                         :small => (depth = 12, embedplanes = 384, nheads = 6),
+                         :base => (depth = 12, embedplanes = 768, nheads = 12),
+                         :large => (depth = 24, embedplanes = 1024, nheads = 16),
+                         :huge => (depth = 32, embedplanes = 1280, nheads = 16),
+                         :giant => (depth = 40, embedplanes = 1408, nheads = 16,
+                                    mlp_ratio = 48 // 11),
+                         :gigantic => (depth = 48, embedplanes = 1664, nheads = 16,
+                                       mlp_ratio = 64 // 13))
 
 """
     ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3,
@@ -98,8 +98,8 @@ end

 function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3,
              patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000)
-    _checkconfig(mode, keys(vit_configs))
-    kwargs = vit_configs[mode]
+    _checkconfig(mode, keys(VIT_CONFIGS))
+    kwargs = VIT_CONFIGS[mode]
     layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...)
     return ViT(layers)
 end
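`VIT_CONFIGS` stores NamedTuples, which is why `kwargs = VIT_CONFIGS[mode]` can be splatted directly into the keyword list of `vit`:

```julia
kwargs = Metalhead.VIT_CONFIGS[:tiny]
kwargs.depth, kwargs.embedplanes, kwargs.nheads  # (12, 192, 3)
# vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) then
# receives depth, embedplanes and nheads as keyword arguments.
```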
2 changes: 1 addition & 1 deletion test/convnets.jl
@@ -111,7 +111,7 @@ end
@testset "EfficientNet" begin
@testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8]
# preferred image resolution scaling
r = Metalhead.efficientnet_global_configs[name][1]
r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1]
x = rand(Float32, r, r, 3, 1)
m = EfficientNet(name)
@test size(m(x)) == (1000, 1)
