From da5321dbf8a1b25f6acdb1036faf989c93d9637c Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Fri, 29 Jul 2022 22:34:02 +0530 Subject: [PATCH] Make all config dicts `const` and capitalise Also misc. formatting --- src/convnets/convmixer.jl | 28 +++++++++++++-------------- src/convnets/convnext.jl | 14 +++++++------- src/convnets/densenet.jl | 6 +++--- src/convnets/efficientnet.jl | 8 ++++---- src/convnets/inception/xception.jl | 2 +- src/convnets/mobilenet/mobilenetv1.jl | 6 +++--- src/convnets/mobilenet/mobilenetv2.jl | 4 ++-- src/convnets/mobilenet/mobilenetv3.jl | 4 ++-- src/convnets/resnets/core.jl | 2 +- src/convnets/resnets/resnet.jl | 4 ++-- src/convnets/resnets/resnext.jl | 4 ++-- src/convnets/resnets/seresnet.jl | 8 ++++---- src/convnets/vgg.jl | 8 ++++---- src/layers/conv.jl | 6 ++++-- src/mixers/core.jl | 8 ++++---- src/mixers/gmlp.jl | 6 +++--- src/mixers/mlpmixer.jl | 6 +++--- src/mixers/resmlp.jl | 6 +++--- src/vit-based/vit.jl | 22 ++++++++++----------- test/convnets.jl | 2 +- 20 files changed, 78 insertions(+), 76 deletions(-) diff --git a/src/convnets/convmixer.jl b/src/convnets/convmixer.jl index 6547ba4fb..aa3d144d2 100644 --- a/src/convnets/convmixer.jl +++ b/src/convnets/convmixer.jl @@ -28,15 +28,15 @@ function convmixer(planes, depth; inchannels = 3, kernel_size = (9, 9), return Chain(Chain(stem..., Chain(blocks)), head) end -convmixer_configs = Dict(:base => Dict(:planes => 1536, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7)), - :small => Dict(:planes => 768, :depth => 32, - :kernel_size => (7, 7), - :patch_size => (7, 7)), - :large => Dict(:planes => 1024, :depth => 20, - :kernel_size => (9, 9), - :patch_size => (7, 7))) +const CONVMIXER_CONFIGS = Dict(:base => Dict(:planes => 1536, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7)), + :small => Dict(:planes => 768, :depth => 32, + :kernel_size => (7, 7), + :patch_size => (7, 7)), + :large => Dict(:planes => 1024, :depth => 20, + :kernel_size => (9, 9), + :patch_size => (7, 7))) """ ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) @@ -57,11 +57,11 @@ end @functor ConvMixer function ConvMixer(mode::Symbol = :base; inchannels = 3, activation = gelu, nclasses = 1000) - _checkconfig(mode, keys(convmixer_configs)) - planes = convmixer_configs[mode][:planes] - depth = convmixer_configs[mode][:depth] - kernel_size = convmixer_configs[mode][:kernel_size] - patch_size = convmixer_configs[mode][:patch_size] + _checkconfig(mode, keys(CONVMIXER_CONFIGS)) + planes = CONVMIXER_CONFIGS[mode][:planes] + depth = CONVMIXER_CONFIGS[mode][:depth] + kernel_size = CONVMIXER_CONFIGS[mode][:kernel_size] + patch_size = CONVMIXER_CONFIGS[mode][:patch_size] layers = convmixer(planes, depth; inchannels, kernel_size, patch_size, activation, nclasses) return ConvMixer(layers) diff --git a/src/convnets/convnext.jl b/src/convnets/convnext.jl index 052192fec..e6ccee16a 100644 --- a/src/convnets/convnext.jl +++ b/src/convnets/convnext.jl @@ -66,11 +66,11 @@ function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0 end # Configurations for ConvNeXt models -convnext_configs = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), - :small => ([3, 3, 27, 3], [96, 192, 384, 768]), - :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), - :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), - :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) +const CONVNEXT_CONFIGS = Dict(:tiny => ([3, 3, 9, 3], [96, 192, 384, 768]), + :small => ([3, 3, 27, 3], [96, 192, 384, 768]), + :base => ([3, 3, 27, 3], [128, 256, 512, 1024]), + :large => ([3, 3, 27, 3], [192, 384, 768, 1536]), + :xlarge => ([3, 3, 27, 3], [256, 512, 1024, 2048])) struct ConvNeXt layers::Any @@ -94,8 +94,8 @@ See also [`Metalhead.convnext`](#). """ function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6, nclasses = 1000) - _checkconfig(mode, keys(convnext_configs)) - layers = convnext(convnext_configs[mode]...; inchannels, drop_path_rate, λ, nclasses) + _checkconfig(mode, keys(CONVNEXT_CONFIGS)) + layers = convnext(CONVNEXT_CONFIGS[mode]...; inchannels, drop_path_rate, λ, nclasses) return ConvNeXt(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index c41b4028b..332b5551f 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -140,7 +140,7 @@ end backbone(m::DenseNet) = m.layers[1] classifier(m::DenseNet) = m.layers[2] -const densenet_configs = Dict(121 => (6, 12, 24, 16), +const DENSENET_CONFIGS = Dict(121 => (6, 12, 24, 16), 161 => (6, 12, 36, 24), 169 => (6, 12, 32, 32), 201 => (6, 12, 48, 32)) @@ -160,8 +160,8 @@ Set `pretrain = true` to load the model with pre-trained weights for ImageNet. See also [`Metalhead.densenet`](#). """ function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) - _checkconfig(config, keys(densenet_configs)) - model = DenseNet(densenet_configs[config]; nclasses = nclasses) + _checkconfig(config, keys(DENSENET_CONFIGS)) + model = DenseNet(DENSENET_CONFIGS[config]; nclasses = nclasses) if pretrain loadpretrain!(model, string("DenseNet", config)) end diff --git a/src/convnets/efficientnet.jl b/src/convnets/efficientnet.jl index 122fd512a..4321e9443 100644 --- a/src/convnets/efficientnet.jl +++ b/src/convnets/efficientnet.jl @@ -59,7 +59,7 @@ end # e: expantion ratio # i: block input channels # o: block output channels -const efficientnet_block_configs = [ +const EFFICIENTNET_BLOCK_CONFIGS = [ # (n, k, s, e, i, o) (1, 3, 1, 1, 32, 16), (2, 3, 2, 6, 16, 24), @@ -73,7 +73,7 @@ const efficientnet_block_configs = [ # w: width scaling # d: depth scaling # r: image resolution -const efficientnet_global_configs = Dict(:b0 => (224, (1.0, 1.0)), +const EFFICIENTNET_GLOBAL_CONFIGS = Dict(:b0 => (224, (1.0, 1.0)), :b1 => (240, (1.0, 1.1)), :b2 => (260, (1.1, 1.2)), :b3 => (300, (1.2, 1.4)), @@ -137,8 +137,8 @@ See also [`efficientnet`](#). - `pretrain`: set to `true` to load the pre-trained weights for ImageNet """ function EfficientNet(name::Symbol; pretrain = false) - _checkconfig(name, keys(efficientnet_global_configs)) - model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs) + _checkconfig(name, keys(EFFICIENTNET_GLOBAL_CONFIGS)) + model = EfficientNet(EFFICIENTNET_GLOBAL_CONFIGS[name][2], EFFICIENTNET_BLOCK_CONFIGS) pretrain && loadpretrain!(model, string("efficientnet-", name)) return model end diff --git a/src/convnets/inception/xception.jl b/src/convnets/inception/xception.jl index a585aadd4..fe04ef2db 100644 --- a/src/convnets/inception/xception.jl +++ b/src/convnets/inception/xception.jl @@ -35,7 +35,7 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1, push!(layers, relu) append!(layers, depthwise_sep_conv_norm((3, 3), inc, outc; pad = 1, bias = false, - use_bn = (false, false))) + use_bn = (false, false))) push!(layers, BatchNorm(outc)) end layers = start_with_relu ? layers : layers[2:end] diff --git a/src/convnets/mobilenet/mobilenetv1.jl b/src/convnets/mobilenet/mobilenetv1.jl index 22beaf86f..fe075d5ef 100644 --- a/src/convnets/mobilenet/mobilenetv1.jl +++ b/src/convnets/mobilenet/mobilenetv1.jl @@ -31,7 +31,7 @@ function mobilenetv1(width_mult, config; for _ in 1:nrepeats layer = dw ? depthwise_sep_conv_norm((3, 3), inchannels, outch, activation; - stride = stride, pad = 1, bias = false) : + stride = stride, pad = 1, bias = false) : conv_norm((3, 3), inchannels, outch, activation; stride, pad = 1, bias = false) append!(layers, layer) @@ -45,7 +45,7 @@ function mobilenetv1(width_mult, config; Dense(inchannels, nclasses))) end -const mobilenetv1_configs = [ +const MOBILENETV1_CONFIGS = [ # dw, c, s, r (false, 32, 2, 1), (true, 64, 1, 1), @@ -84,7 +84,7 @@ end function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) + layers = mobilenetv1(width_mult, MOBILENETV1_CONFIGS; inchannels, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv1")) end diff --git a/src/convnets/mobilenet/mobilenetv2.jl b/src/convnets/mobilenet/mobilenetv2.jl index 21c017b42..dd9eda012 100644 --- a/src/convnets/mobilenet/mobilenetv2.jl +++ b/src/convnets/mobilenet/mobilenetv2.jl @@ -46,7 +46,7 @@ function mobilenetv2(width_mult, configs; inchannels = 3, max_width = 1280, ncla end # Layer configurations for MobileNetv2 -const mobilenetv2_configs = [ +const MOBILENETV2_CONFIGS = [ # t, c, n, s, a (1, 16, 1, 1, relu6), (6, 24, 2, 2, relu6), @@ -83,7 +83,7 @@ See also [`Metalhead.mobilenetv2`](#). """ function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) - layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) + layers = mobilenetv2(width_mult, MOBILENETV2_CONFIGS; inchannels, nclasses) pretrain && loadpretrain!(layers, string("MobileNetv2")) if pretrain loadpretrain!(layers, string("MobileNetv2")) diff --git a/src/convnets/mobilenet/mobilenetv3.jl b/src/convnets/mobilenet/mobilenetv3.jl index 6bc444407..5a06f6be5 100644 --- a/src/convnets/mobilenet/mobilenetv3.jl +++ b/src/convnets/mobilenet/mobilenetv3.jl @@ -53,7 +53,7 @@ function mobilenetv3(width_mult, configs; inchannels = 3, max_width = 1024, ncla end # Configurations for small and large mode for MobileNetv3 -mobilenetv3_configs = Dict(:small => [ +MOBILENETV3_CONFIGS = Dict(:small => [ # k, t, c, SE, a, s (3, 1, 16, 4, relu, 2), (3, 4.5, 24, nothing, relu, 2), @@ -115,7 +115,7 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = pretrain = false, nclasses = 1000) @assert mode in [:large, :small] "`mode` has to be either :large or :small" max_width = (mode == :large) ? 1280 : 1024 - layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, + layers = mobilenetv3(width_mult, MOBILENETV3_CONFIGS[mode]; inchannels, max_width, nclasses) if pretrain loadpretrain!(layers, string("MobileNetv3", mode)) diff --git a/src/convnets/resnets/core.jl b/src/convnets/resnets/core.jl index ac56e4146..482545e3b 100644 --- a/src/convnets/resnets/core.jl +++ b/src/convnets/resnets/core.jl @@ -324,7 +324,7 @@ function resnet(block_fn, block_repeats, downsample_opt::Symbol = :B; kwargs...) end # block-layer configurations for ResNet-like models -const resnet_configs = Dict(18 => (:basicblock, [2, 2, 2, 2]), +const RESNET_CONFIGS = Dict(18 => (:basicblock, [2, 2, 2, 2]), 34 => (:basicblock, [3, 4, 6, 3]), 50 => (:bottleneck, [3, 4, 6, 3]), 101 => (:bottleneck, [3, 4, 23, 3]), diff --git a/src/convnets/resnets/resnet.jl b/src/convnets/resnets/resnet.jl index 7bebb0873..46c0826c2 100644 --- a/src/convnets/resnets/resnet.jl +++ b/src/convnets/resnets/resnet.jl @@ -23,8 +23,8 @@ end @functor ResNet function ResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - _checkconfig(depth, keys(resnet_configs)) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses) + _checkconfig(depth, keys(RESNET_CONFIGS)) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses) if pretrain loadpretrain!(layers, string("ResNet", depth)) end diff --git a/src/convnets/resnets/resnext.jl b/src/convnets/resnets/resnext.jl index 47e81d44d..8032df5ab 100644 --- a/src/convnets/resnets/resnext.jl +++ b/src/convnets/resnets/resnext.jl @@ -29,8 +29,8 @@ end function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width) + _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end]) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width) if pretrain loadpretrain!(layers, string("ResNeXt", depth, "_", cardinality, "x", base_width)) end diff --git a/src/convnets/resnets/seresnet.jl b/src/convnets/resnets/seresnet.jl index 824f2bbe9..05d842173 100644 --- a/src/convnets/resnets/seresnet.jl +++ b/src/convnets/resnets/seresnet.jl @@ -25,8 +25,8 @@ end (m::SEResNet)(x) = m.layers(x) function SEResNet(depth::Integer; pretrain = false, inchannels = 3, nclasses = 1000) - _checkconfig(depth, keys(resnet_configs)) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses, + _checkconfig(depth, keys(RESNET_CONFIGS)) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNet", depth)) @@ -68,8 +68,8 @@ end function SEResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, inchannels = 3, nclasses = 1000) - _checkconfig(depth, sort(collect(keys(resnet_configs)))[3:end]) - layers = resnet(resnet_configs[depth]...; inchannels, nclasses, cardinality, base_width, + _checkconfig(depth, sort(collect(keys(RESNET_CONFIGS)))[3:end]) + layers = resnet(RESNET_CONFIGS[depth]...; inchannels, nclasses, cardinality, base_width, attn_fn = squeeze_excite) if pretrain loadpretrain!(layers, string("SEResNeXt", depth, "_", cardinality, "x", base_width)) diff --git a/src/convnets/vgg.jl b/src/convnets/vgg.jl index 3a1a8ac10..ccfdd2cff 100644 --- a/src/convnets/vgg.jl +++ b/src/convnets/vgg.jl @@ -99,12 +99,12 @@ function vgg(imsize; config, inchannels, batchnorm = false, nclasses, fcsize, dr return Chain(Chain(conv), class) end -const vgg_conv_configs = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], +const VGG_CONV_CONFIGS = Dict(:A => [(64, 1), (128, 1), (256, 2), (512, 2), (512, 2)], :B => [(64, 2), (128, 2), (256, 2), (512, 2), (512, 2)], :D => [(64, 2), (128, 2), (256, 3), (512, 3), (512, 3)], :E => [(64, 2), (128, 2), (256, 4), (512, 4), (512, 4)]) -const vgg_configs = Dict(11 => :A, +const VGG_CONFIGS = Dict(11 => :A, 13 => :B, 16 => :D, 19 => :E) @@ -153,8 +153,8 @@ See also [`VGG`](#). - `pretrain`: set to `true` to load pre-trained model weights for ImageNet """ function VGG(depth::Integer = 16; pretrain = false, batchnorm = false, nclasses = 1000) - _checkconfig(depth, keys(vgg_configs)) - model = VGG((224, 224); config = vgg_conv_configs[vgg_configs[depth]], + _checkconfig(depth, keys(VGG_CONFIGS)) + model = VGG((224, 224); config = VGG_CONV_CONFIGS[VGG_CONFIGS[depth]], inchannels = 3, batchnorm = batchnorm, nclasses = nclasses, diff --git a/src/layers/conv.jl b/src/layers/conv.jl index 8a195158e..75da01929 100644 --- a/src/layers/conv.jl +++ b/src/layers/conv.jl @@ -87,10 +87,12 @@ See Fig. 3 in [reference](https://arxiv.org/abs/1704.04861v1). - `bias`, `weight`, `init`: initialization for the convolution kernel (see [`Flux.Conv`](#)) """ function depthwise_sep_conv_norm(kernel_size, inplanes, outplanes, activation = relu; - norm_layer = BatchNorm, revnorm = false, use_norm = (true, true), + norm_layer = BatchNorm, revnorm = false, + use_norm = (true, true), stride = 1, kwargs...) return vcat(conv_norm(kernel_size, inplanes, inplanes, activation; - norm_layerm, revnorm, use_bn = use_bn[1], stride, groups = inplanes, + norm_layerm, revnorm, use_bn = use_bn[1], stride, + groups = inplanes, kwargs...), conv_norm((1, 1), inplanes, outplanes, activation; norm_layer, revnorm, use_bn = use_bn[2])) diff --git a/src/mixers/core.jl b/src/mixers/core.jl index 6a55f048e..9f9d3b305 100644 --- a/src/mixers/core.jl +++ b/src/mixers/core.jl @@ -37,7 +37,7 @@ function mlpmixer(block, imsize::Dims{2} = (224, 224); inchannels = 3, end # Configurations for MLPMixer models -mixer_configs = Dict(:small => Dict(:depth => 8, :planes => 512), - :base => Dict(:depth => 12, :planes => 768), - :large => Dict(:depth => 24, :planes => 1024), - :huge => Dict(:depth => 32, :planes => 1280)) +const MIXER_CONFIGS = Dict(:small => Dict(:depth => 8, :planes => 512), + :base => Dict(:depth => 12, :planes => 768), + :large => Dict(:depth => 24, :planes => 1024), + :huge => Dict(:depth => 32, :planes => 1280)) diff --git a/src/mixers/gmlp.jl b/src/mixers/gmlp.jl index 4e681e9b4..9ebd2dce3 100644 --- a/src/mixers/gmlp.jl +++ b/src/mixers/gmlp.jl @@ -96,9 +96,9 @@ See also [`Metalhead.mlpmixer`](#). """ function gMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - _checkconfig(size, keys(mixer_configs)) - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] layers = mlpmixer(spatial_gating_block, imsize; mlp_layer = gated_mlp_block, patch_size, embedplanes, drop_path_rate, depth, nclasses) return gMLP(layers) diff --git a/src/mixers/mlpmixer.jl b/src/mixers/mlpmixer.jl index e3da17a23..7b6d4aa09 100644 --- a/src/mixers/mlpmixer.jl +++ b/src/mixers/mlpmixer.jl @@ -55,9 +55,9 @@ See also [`Metalhead.mlpmixer`](#). """ function MLPMixer(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - _checkconfig(size, keys(mixer_configs)) - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] layers = mlpmixer(mixerblock, imsize; patch_size, embedplanes, depth, drop_path_rate, nclasses) return MLPMixer(layers) diff --git a/src/mixers/resmlp.jl b/src/mixers/resmlp.jl index 38163702c..17e340310 100644 --- a/src/mixers/resmlp.jl +++ b/src/mixers/resmlp.jl @@ -58,9 +58,9 @@ See also [`Metalhead.mlpmixer`](#). """ function ResMLP(size::Symbol = :base; patch_size::Dims{2} = (16, 16), imsize::Dims{2} = (224, 224), drop_path_rate = 0.0, nclasses = 1000) - _checkconfig(size, keys(mixer_configs)) - depth = mixer_configs[size][:depth] - embedplanes = mixer_configs[size][:planes] + _checkconfig(size, keys(MIXER_CONFIGS)) + depth = MIXER_CONFIGS[size][:depth] + embedplanes = MIXER_CONFIGS[size][:planes] layers = mlpmixer(resmixerblock, imsize; mlp_ratio = 4.0, patch_size, embedplanes, drop_path_rate, depth, nclasses) return ResMLP(layers) diff --git a/src/vit-based/vit.jl b/src/vit-based/vit.jl index 93eba09ee..bcc5d43ba 100644 --- a/src/vit-based/vit.jl +++ b/src/vit-based/vit.jl @@ -62,15 +62,15 @@ function vit(imsize::Dims{2} = (256, 256); inchannels = 3, patch_size::Dims{2} = Chain(LayerNorm(embedplanes), Dense(embedplanes, nclasses, tanh_fast))) end -vit_configs = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), - :small => (depth = 12, embedplanes = 384, nheads = 6), - :base => (depth = 12, embedplanes = 768, nheads = 12), - :large => (depth = 24, embedplanes = 1024, nheads = 16), - :huge => (depth = 32, embedplanes = 1280, nheads = 16), - :giant => (depth = 40, embedplanes = 1408, nheads = 16, - mlp_ratio = 48 // 11), - :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, - mlp_ratio = 64 // 13)) +const VIT_CONFIGS = Dict(:tiny => (depth = 12, embedplanes = 192, nheads = 3), + :small => (depth = 12, embedplanes = 384, nheads = 6), + :base => (depth = 12, embedplanes = 768, nheads = 12), + :large => (depth = 24, embedplanes = 1024, nheads = 16), + :huge => (depth = 32, embedplanes = 1280, nheads = 16), + :giant => (depth = 40, embedplanes = 1408, nheads = 16, + mlp_ratio = 48 // 11), + :gigantic => (depth = 48, embedplanes = 1664, nheads = 16, + mlp_ratio = 64 // 13)) """ ViT(mode::Symbol = base; imsize::Dims{2} = (256, 256), inchannels = 3, @@ -98,8 +98,8 @@ end function ViT(mode::Symbol = :base; imsize::Dims{2} = (256, 256), inchannels = 3, patch_size::Dims{2} = (16, 16), pool = :class, nclasses = 1000) - _checkconfig(mode, keys(vit_configs)) - kwargs = vit_configs[mode] + _checkconfig(mode, keys(VIT_CONFIGS)) + kwargs = VIT_CONFIGS[mode] layers = vit(imsize; inchannels, patch_size, nclasses, pool, kwargs...) return ViT(layers) end diff --git a/test/convnets.jl b/test/convnets.jl index 258e037b6..5740ed5c6 100644 --- a/test/convnets.jl +++ b/test/convnets.jl @@ -111,7 +111,7 @@ end @testset "EfficientNet" begin @testset "EfficientNet($name)" for name in [:b0, :b1, :b2, :b3, :b4, :b5] #:b6, :b7, :b8] # preferred image resolution scaling - r = Metalhead.efficientnet_global_configs[name][1] + r = Metalhead.EFFICIENTNET_GLOBAL_CONFIGS[name][1] x = rand(Float32, r, r, 3, 1) m = EfficientNet(name) @test size(m(x)) == (1000, 1)