Overhaul of ResNet API #174

Merged · 67 commits · Aug 2, 2022
Changes shown below are from the first 16 commits.

Commits
cd0edef  Add `DropBlock` (theabhirath, Jun 16, 2022)
271b430  Initial commit for new ResNet API (theabhirath, Jun 21, 2022)
866dbcc  Cleanup (theabhirath, Jun 22, 2022)
a038ff8  Get some stuff to work (theabhirath, Jun 23, 2022)
de079bc  Tweaks - I (theabhirath, Jun 23, 2022)
4fa28d4  Make pretrain condition explicit (theabhirath, Jun 25, 2022)
7846f8b  More declarative interface for ResNet (theabhirath, Jun 28, 2022)
a1d5ddc  Make `DropBlock` really work (theabhirath, Jun 28, 2022)
3be1d81  Construct the stem outside and pass it into `resnet` (theabhirath, Jun 29, 2022)
16cbcd0  Add ResNeXt back (theabhirath, Jun 29, 2022)
e5294ec  Enable CI for Windows (theabhirath, Jun 30, 2022)
a439bdf  Add more general implementation of SE layer (theabhirath, Jun 29, 2022)
441ade8  Tweaks III + Some more docs (theabhirath, Jul 1, 2022)
5d059f5  Fix `DropBlock` on the GPU (theabhirath, Jul 3, 2022)
226e96a  Add `SEResNet` and `SEResNeXt` (theabhirath, Jul 3, 2022)
3a4ffbf  More docs, more tweaks (theabhirath, Jul 4, 2022)
2f755cf  More aggressive GC (theabhirath, Jul 8, 2022)
5ba4b84  Tweaks don't stop (theabhirath, Jul 9, 2022)
aaf2abb  Reorganisation and formatting (theabhirath, Jul 9, 2022)
326f36c  Refactor shortcut connections (theabhirath, Jul 9, 2022)
4e01443  Generalise `resnet` further (theabhirath, Jul 10, 2022)
e8d3488  Documentation (theabhirath, Jul 10, 2022)
92ed4fa  Add classifier and backbone methods (theabhirath, Jul 12, 2022)
96a7d31  Refactor of resnet core (theabhirath, Jul 17, 2022)
9540299  Add `DropBlock` (theabhirath, Jun 16, 2022)
588d703  Initial commit for new ResNet API (theabhirath, Jun 21, 2022)
2a5d0cc  Cleanup (theabhirath, Jun 22, 2022)
07c1e95  Get some stuff to work (theabhirath, Jun 23, 2022)
2e88201  Tweaks - I (theabhirath, Jun 23, 2022)
01eaa8b  Make pretrain condition explicit (theabhirath, Jun 25, 2022)
546b131  More declarative interface for ResNet (theabhirath, Jun 28, 2022)
3f45f27  Make `DropBlock` really work (theabhirath, Jun 28, 2022)
f373f45  Construct the stem outside and pass it into `resnet` (theabhirath, Jun 29, 2022)
51d0757  Add ResNeXt back (theabhirath, Jun 29, 2022)
106f260  Add more general implementation of SE layer (theabhirath, Jun 29, 2022)
7147309  Tweaks III + Some more docs (theabhirath, Jul 1, 2022)
7ed20d4  Fix `DropBlock` on the GPU (theabhirath, Jul 3, 2022)
f0051b7  Add `SEResNet` and `SEResNeXt` (theabhirath, Jul 3, 2022)
e5d2295  More docs, more tweaks (theabhirath, Jul 4, 2022)
4a91fc4  More aggressive GC (theabhirath, Jul 8, 2022)
cf538bb  Tweaks don't stop (theabhirath, Jul 9, 2022)
5be45ef  Reorganisation and formatting (theabhirath, Jul 9, 2022)
1e509df  Refactor shortcut connections (theabhirath, Jul 9, 2022)
e4930f1  Generalise `resnet` further (theabhirath, Jul 10, 2022)
80bdcde  Documentation (theabhirath, Jul 10, 2022)
ab37901  Add classifier and backbone methods (theabhirath, Jul 12, 2022)
68abbb7  Refactor of resnet core (theabhirath, Jul 17, 2022)
7ad362b  Refactor of resnet core II (theabhirath, Jul 22, 2022)
93fb500  Merge branch 'resnet-plus' of https://github.com/theabhirath/Metalhea… (theabhirath, Jul 22, 2022)
13ed5ac  Allow `prenorm` (theabhirath, Jul 22, 2022)
6c005d3  Cleanup (theabhirath, Jul 23, 2022)
bd443f1  Reorganisation (theabhirath, Jul 23, 2022)
ce1da45  Reorganisation (theabhirath, Jul 23, 2022)
ed57c8f  Merge branch 'resnet-plus' of https://github.com/theabhirath/Metalhea… (theabhirath, Jul 27, 2022)
8c9f73f  Remove templating for now (theabhirath, Jul 27, 2022)
ca53acb  Fix tests, hopefully (theabhirath, Jul 28, 2022)
54ea529  Revert "Remove templating for now" (theabhirath, Jul 28, 2022)
541fabd  Merge branch 'master' into resnet-plus (theabhirath, Jul 29, 2022)
cff07cb  MobileNet tweaks (theabhirath, Jul 29, 2022)
674b27e  Make templating work again (theabhirath, Jul 29, 2022)
aa2a9ef  Tests just don't fix themselves (theabhirath, Jul 29, 2022)
b143b95  Fifth refactor is a charm (theabhirath, Jul 29, 2022)
fc74aa1  Cleanup - docs and code (theabhirath, Jul 29, 2022)
99eb25a  Make all config dicts `const` and capitalise (theabhirath, Jul 29, 2022)
73131bf  Formatting, and some tweaks (theabhirath, Jul 30, 2022)
73df024  Add WideResNet (theabhirath, Jul 30, 2022)
72cd4a9  Don't use globals (theabhirath, Aug 2, 2022)
2 changes: 1 addition & 1 deletion .github/workflows/CI.yml
@@ -22,7 +22,7 @@ jobs:
os:
- ubuntu-latest
- macOS-latest
# - windows-latest
- windows-latest
arch:
- x64
steps:
3 changes: 3 additions & 0 deletions Project.toml
@@ -5,11 +5,14 @@ version = "0.7.3"
[deps]
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

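CUDA and NNlibCUDA join the dependency list, presumably in support of the `DropBlock` GPU fix from this PR. A minimal smoke-test sketch, assuming the depth-based `ResNet(18)` constructor from the new API and Flux's `gpu` helper (a no-op on CPU-only machines):

using Metalhead, Flux, CUDA

# Build a small model and move it to the GPU if one is available.
model = ResNet(18) |> gpu
x = rand(Float32, 224, 224, 3, 1) |> gpu   # one 224×224 RGB image in WHCN layout
y = model(x)                               # class scores for the 1000 default classes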
4 changes: 2 additions & 2 deletions docs/make.jl
@@ -1,6 +1,6 @@
using Pkg

Pkg.develop(path = "..")
Pkg.develop(; path = "..")

using Publish
using Artifacts, LazyArtifacts
@@ -13,5 +13,5 @@ p = Publish.Project(Metalhead)

function build_and_deploy(label)
rm(label; recursive = true, force = true)
deploy(Metalhead; root = "/Metalhead.jl", label = label)
return deploy(Metalhead; root = "/Metalhead.jl", label = label)
end
2 changes: 1 addition & 1 deletion docs/serve.jl
@@ -1,6 +1,6 @@
using Pkg

Pkg.develop(path = "..")
Pkg.develop(; path = "..")

using Revise
using Publish
8 changes: 4 additions & 4 deletions src/Metalhead.jl
@@ -22,8 +22,7 @@ include("convnets/alexnet.jl")
include("convnets/vgg.jl")
include("convnets/inception.jl")
include("convnets/googlenet.jl")
include("convnets/resnet.jl")
include("convnets/resnext.jl")
include("convnets/resnets.jl")
include("convnets/densenet.jl")
include("convnets/squeezenet.jl")
include("convnets/mobilenet.jl")
@@ -44,14 +43,15 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
SEResNet, SEResNeXt,
MLPMixer, ResMLP, gMLP,
ViT,
ConvMixer, ConvNeXt

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet,
for T in (:AlexNet, :VGG, :ResNet, :ResNeXt, :DenseNet, :SEResNet, :SEResNeXt,
:GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3, :EfficientNet,
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
@eval Base.show(io::IO, ::MIME"text/plain", model::$T) = _maybe_big_show(io, model)
end
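The newly exported `SEResNet` and `SEResNeXt` plug into the same pretty-printing hook as the other models. A usage sketch, assuming the constructors take a depth argument like the other ResNet variants in this PR:

using Metalhead

# Depth-based constructor, assumed to mirror ResNet's API.
model = SEResNet(50; nclasses = 1000)
# `show(model)` now routes through Flux._big_show via the @eval loop above,
# printing the model as a structured layer tree.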
4 changes: 3 additions & 1 deletion src/convnets/alexnet.jl
@@ -49,7 +49,9 @@ end

function AlexNet(; pretrain = false, nclasses = 1000)
layers = alexnet(; nclasses = nclasses)
pretrain && loadpretrain!(layers, "AlexNet")
if pretrain
loadpretrain!(layers, "AlexNet")
end
return AlexNet(layers)
end

4 changes: 2 additions & 2 deletions src/convnets/convmixer.jl
@@ -9,7 +9,7 @@ Creates a ConvMixer model.

- `planes`: number of planes in the output of each block
- `depth`: number of layers
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `kernel_size`: kernel size of the convolutional layers
- `patch_size`: size of the patches
- `activation`: activation function used after the convolutional layers
@@ -45,7 +45,7 @@ Creates a ConvMixer model.
# Arguments

- `mode`: the mode of the model, either `:base`, `:small` or `:large`
- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `activation`: activation function used after the convolutional layers
- `nclasses`: number of classes in the output
"""
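A usage sketch for the documented `mode` argument; the mode-based constructor is assumed from the docstring above:

using Metalhead

# :base, :small and :large select preset ConvMixer configurations.
model = ConvMixer(:small; inchannels = 3, nclasses = 1000)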
24 changes: 11 additions & 13 deletions src/convnets/convnext.jl
@@ -4,7 +4,7 @@
Creates a single block of ConvNeXt.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
# Arguments

- `planes`: number of input channels.
- `drop_path_rate`: Stochastic depth rate.
@@ -27,7 +27,7 @@ end
Creates the layers for a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
# Arguments

- `inchannels`: number of input channels.
- `depths`: list with configuration for depth of each block
@@ -39,32 +39,29 @@ Creates the layers for a ConvNeXt model.
"""
function convnext(depths, planes; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
@assert length(depths)==length(planes) "`planes` should have exactly one value for each block"

@assert length(depths) == length(planes)
"`planes` should have exactly one value for each block"
downsample_layers = []
stem = Chain(Conv((4, 4), inchannels => planes[1]; stride = 4),
ChannelLayerNorm(planes[1]; ϵ = 1.0f-6))
ChannelLayerNorm(planes[1]))
push!(downsample_layers, stem)
for m in 1:(length(depths) - 1)
downsample_layer = Chain(ChannelLayerNorm(planes[m]; ϵ = 1.0f-6),
downsample_layer = Chain(ChannelLayerNorm(planes[m]),
Conv((2, 2), planes[m] => planes[m + 1]; stride = 2))
push!(downsample_layers, downsample_layer)
end

stages = []
dp_rates = LinRange{Float32}(0.0, drop_path_rate, sum(depths))
cur = 0
for i in 1:length(depths)
for i in eachindex(depths)
push!(stages, [convnextblock(planes[i], dp_rates[cur + j], λ) for j in 1:depths[i]])
cur += depths[i]
end

backbone = collect(Iterators.flatten(Iterators.flatten(zip(downsample_layers, stages))))
head = Chain(GlobalMeanPool(),
MLUtils.flatten,
LayerNorm(planes[end]),
Dense(planes[end], nclasses))

return Chain(Chain(backbone), head)
end

@@ -90,9 +87,9 @@ end
Creates a ConvNeXt model.
([reference](https://arxiv.org/abs/2201.03545))

# Arguments:
# Arguments

- `inchannels`: The number of channels in the input. The default value is 3.
- `inchannels`: The number of channels in the input.
- `drop_path_rate`: Stochastic depth rate.
- `λ`: Init value for [LayerScale](https://arxiv.org/abs/2103.17239)
- `nclasses`: number of output classes
@@ -101,7 +98,8 @@ See also [`Metalhead.convnext`](#).
"""
function ConvNeXt(mode::Symbol = :base; inchannels = 3, drop_path_rate = 0.0, λ = 1.0f-6,
nclasses = 1000)
@assert mode in keys(convnext_configs) "`size` must be one of $(collect(keys(convnext_configs)))"
@assert mode in keys(convnext_configs)
"`size` must be one of $(collect(keys(convnext_configs)))"
depths = convnext_configs[mode][:depths]
planes = convnext_configs[mode][:planes]
layers = convnext(depths, planes; inchannels, drop_path_rate, λ, nclasses)
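A usage sketch of the mode-based constructor shown above; `:base` is the default, and the assert rejects any symbol missing from `convnext_configs`:

using Metalhead

model = ConvNeXt(:base; inchannels = 3, nclasses = 1000)
# Passing a symbol outside convnext_configs trips the assert.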
7 changes: 5 additions & 2 deletions src/convnets/densenet.jl
@@ -100,7 +100,8 @@ Create a DenseNet model
- `reduction`: the factor by which the number of feature maps is scaled across each transition
- `nclasses`: the number of output classes
"""
function densenet(nblocks; growth_rate = 32, reduction = 0.5, nclasses = 1000)
function densenet(nblocks::NTuple{N, <:Integer}; growth_rate = 32, reduction = 0.5,
nclasses = 1000) where {N}
return densenet(2 * growth_rate, [fill(growth_rate, n) for n in nblocks];
reduction = reduction, nclasses = nclasses)
end
@@ -161,6 +162,8 @@ See also [`Metalhead.densenet`](#).
function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
@assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))."
model = DenseNet(densenet_config[config]; nclasses = nclasses)
pretrain && loadpretrain!(model, string("DenseNet", config))
if pretrain
loadpretrain!(model, string("DenseNet", config))
end
return model
end
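The lower-level `densenet` builder now requires the block counts as a tuple of integers. A sketch using the DenseNet-121 configuration for illustration:

using Metalhead

# (6, 12, 24, 16) are the dense-block repetitions of DenseNet-121.
layers = Metalhead.densenet((6, 12, 24, 16); growth_rate = 32, reduction = 0.5,
                            nclasses = 1000)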
111 changes: 57 additions & 54 deletions src/convnets/efficientnet.jl
@@ -6,19 +6,21 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
# Arguments
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
- `n`: number of block repetitions (will be scaled by global depth scaling)
- `k`: kernel size
- `s`: kernel stride
- `e`: expansion ratio
- `i`: block input channels (will be scaled by global width scaling)
- `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
+ `n`: number of block repetitions (will be scaled by global depth scaling)
+ `k`: kernel size
+ `s`: kernel stride
+ `e`: expansion ratio
+ `i`: block input channels (will be scaled by global width scaling)
+ `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
"""
function efficientnet(scalings, block_config;
inchannels = 3, nclasses = 1000, max_width = 1280)
@@ -64,34 +66,33 @@ end
# i: block input channels
# o: block output channels
const efficientnet_block_configs = [
# (n, k, s, e, i, o)
(1, 3, 1, 1, 32, 16),
(2, 3, 2, 6, 16, 24),
(2, 5, 2, 6, 24, 40),
(3, 3, 2, 6, 40, 80),
(3, 5, 1, 6, 80, 112),
# (n, k, s, e, i, o)
(1, 3, 1, 1, 32, 16),
(2, 3, 2, 6, 16, 24),
(2, 5, 2, 6, 24, 40),
(3, 3, 2, 6, 40, 80),
(3, 5, 1, 6, 80, 112),
(4, 5, 2, 6, 112, 192),
(1, 3, 1, 6, 192, 320)
(1, 3, 1, 6, 192, 320),
]

# w: width scaling
# d: depth scaling
# r: image resolution
const efficientnet_global_configs = Dict(
# ( r, ( w, d))
:b0 => (224, (1.0, 1.0)),
:b1 => (240, (1.0, 1.1)),
:b2 => (260, (1.1, 1.2)),
:b3 => (300, (1.2, 1.4)),
:b4 => (380, (1.4, 1.8)),
:b5 => (456, (1.6, 2.2)),
:b6 => (528, (1.8, 2.6)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6))
)
# (r, (w, d))
:b0 => (224, (1.0, 1.0)),
:b1 => (240, (1.0, 1.1)),
:b2 => (260, (1.1, 1.2)),
:b3 => (300, (1.2, 1.4)),
:b4 => (380, (1.4, 1.8)),
:b5 => (456, (1.6, 2.2)),
:b6 => (528, (1.8, 2.6)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6)))

struct EfficientNet
layers::Any
layers::Any
end

"""
@@ -103,27 +104,29 @@ See also [`efficientnet`](#).
# Arguments
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
- `n`: number of block repetitions (will be scaled by global depth scaling)
- `k`: kernel size
- `s`: kernel stride
- `e`: expansion ratio
- `i`: block input channels (will be scaled by global width scaling)
- `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
+ `n`: number of block repetitions (will be scaled by global depth scaling)
+ `k`: kernel size
+ `s`: kernel stride
+ `e`: expansion ratio
+ `i`: block input channels (will be scaled by global width scaling)
+ `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
"""
function EfficientNet(scalings, block_config;
inchannels = 3, nclasses = 1000, max_width = 1280)
layers = efficientnet(scalings, block_config;
inchannels = inchannels,
nclasses = nclasses,
max_width = max_width)
return EfficientNet(layers)
layers = efficientnet(scalings, block_config;
inchannels = inchannels,
nclasses = nclasses,
max_width = max_width)
return EfficientNet(layers)
end

@functor EfficientNet
@@ -141,13 +144,13 @@ See also [`efficientnet`](#).
# Arguments
- `name`: name of default configuration
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `name`: name of default configuration
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
"""
function EfficientNet(name::Symbol; pretrain = false)
@assert name in keys(efficientnet_global_configs)
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"

model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
pretrain && loadpretrain!(model, string("efficientnet-", name))
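A usage sketch of the two constructors shown above; `:b0` selects the `(224, (1.0, 1.0))` entry of `efficientnet_global_configs`, i.e. no width or depth scaling:

using Metalhead

model = EfficientNet(:b0)   # unscaled baseline configuration
# Equivalent lower-level call with explicit scalings and the shared block table:
model = EfficientNet((1.0, 1.0), Metalhead.efficientnet_block_configs)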
4 changes: 3 additions & 1 deletion src/convnets/googlenet.jl
@@ -86,7 +86,9 @@ end

function GoogLeNet(; pretrain = false, nclasses = 1000)
layers = googlenet(; nclasses = nclasses)
pretrain && loadpretrain!(layers, "GoogLeNet")
if pretrain
loadpretrain!(layers, "GoogLeNet")
end
return GoogLeNet(layers)
end
