Add ResNeXt back
Also add tests. A lot of tests
theabhirath committed Jun 29, 2022
1 parent 3be1d81 commit c06d963
Showing 6 changed files with 69 additions and 46 deletions.
src/Metalhead.jl (5 changes: 3 additions & 2 deletions)
@@ -7,6 +7,7 @@ using BSON
using Artifacts, LazyArtifacts
using Statistics
using MLUtils
using Random

import Functors

@@ -38,7 +39,7 @@ include("vit-based/vit.jl")
include("pretrain.jl")

export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, # ResNeXt,
ResNet, ResNet18, ResNet34, ResNet50, ResNet101, ResNet152, ResNeXt,
DenseNet, DenseNet121, DenseNet161, DenseNet169, DenseNet201,
GoogLeNet, Inception3, Inceptionv3, Inceptionv4, InceptionResNetv2, Xception,
SqueezeNet, MobileNetv1, MobileNetv2, MobileNetv3, EfficientNet,
@@ -47,7 +48,7 @@ export AlexNet, VGG, VGG11, VGG13, VGG16, VGG19,
ConvMixer, ConvNeXt

# use Flux._big_show to pretty print large models
for T in (:AlexNet, :VGG, :DenseNet, :ResNet, # :ResNeXt,
for T in (:AlexNet, :VGG, :DenseNet, :ResNet, :ResNeXt,
:GoogLeNet, :Inceptionv3, :Inceptionv4, :InceptionResNetv2, :Xception,
:SqueezeNet, :MobileNetv1, :MobileNetv2, :MobileNetv3,
:MLPMixer, :ResMLP, :gMLP, :ViT, :ConvMixer, :ConvNeXt)
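
The body of the @eval loop is outside this hunk. As a rough sketch of what adding :ResNeXt to the tuple generates (assuming the body simply hooks each model type into Flux's internal _big_show pretty-printer, as the comment above the loop suggests), the emitted methods would look like:

```julia
using Flux, Metalhead

# Sketch only: the real loop body is not shown in this diff.
for T in (:ResNet, :ResNeXt)  # abbreviated list for illustration
    @eval Base.show(io::IO, ::MIME"text/plain", model::Metalhead.$T) = Flux._big_show(io, model)
end
```
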
src/convnets/resne(x)t.jl (28 changes: 22 additions & 6 deletions)
@@ -173,18 +173,18 @@ function _drop_blocks(drop_block_prob = 0.0)
]
end

function resnet(block, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
function resnet(block_fn, layers; nclasses = 1000, inchannels = 3, output_stride = 32,
stem = first(resnet_stem(; inchannels)), inplanes = 64,
downsample_fn = downsample_conv,
drop_rates::NamedTuple = (drop_rate = 0.0, drop_path_rate = 0.0,
drop_block_rate = 0.0),
block_args::NamedTuple = NamedTuple())
# Feature Blocks
channels = [64, 128, 256, 512]
stage_blocks = _make_blocks(block, channels, layers, inplanes;
stage_blocks = _make_blocks(block_fn, channels, layers, inplanes;
output_stride, downsample_fn, drop_rates, block_args)
# Head (Pooling and Classifier)
expansion = expansion_factor(block)
expansion = expansion_factor(block_fn)
num_features = 512 * expansion
classifier = Chain(GlobalMeanPool(), Dropout(drop_rates.drop_rate), MLUtils.flatten,
Dense(num_features, nclasses))
@@ -201,11 +201,27 @@ struct ResNet
end
@functor ResNet

function ResNet(depth::Integer; pretrain = false, nclasses = 1000, kwargs...)
@assert depth in [18, 34, 50, 101, 152] "Invalid depth. Must be one of [18, 34, 50, 101, 152]"
model = resnet(resnet_config[depth]...; nclasses, kwargs...)
function ResNet(depth::Integer; pretrain = false, nclasses = 1000)
@assert depth in [18, 34, 50, 101, 152]
"Invalid depth. Must be one of [18, 34, 50, 101, 152]"
model = resnet(resnet_config[depth]...; nclasses)
if pretrain
loadpretrain!(model, string("resnet", depth))
end
return model
end

struct ResNeXt
layers::Any
end
@functor ResNeXt

function ResNeXt(depth::Integer; pretrain = false, cardinality = 32, base_width = 4, nclasses = 1000)
@assert depth in [50, 101, 152]
"Invalid depth. Must be one of [50, 101, 152]"
model = resnet(bottleneck, [3, 4, 6, 3]; nclasses, block_args = (; cardinality, base_width))
if pretrain
loadpretrain!(model, string("resnext", depth, "_", cardinality, "x", base_width))
end
return model
end
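
A quick usage sketch based on the constructors in this diff (not part of the commit; the input shape mirrors the x_224 array used in the tests below):

```julia
using Metalhead

m_resnet = ResNet(50)                                      # depth must be one of [18, 34, 50, 101, 152]
m_resnext = ResNeXt(50; cardinality = 32, base_width = 4)  # depth must be one of [50, 101, 152]

# The lower-level builder that both wrappers call:
m = Metalhead.resnet(Metalhead.bottleneck, [3, 4, 6, 3];
                     nclasses = 1000,
                     block_args = (; cardinality = 32, base_width = 4))

x = rand(Float32, 224, 224, 3, 1)
size(m_resnext(x))  # (1000, 1), as asserted in test/convnets.jl below
```
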
src/layers/Layers.jl (3 changes: 1 addition & 2 deletions)
@@ -2,8 +2,7 @@ module Layers

using Flux
using CUDA
using NNlib
using NNlibCUDA
using NNlib, NNlibCUDA
using Functors
using ChainRulesCore
using Statistics
src/layers/drop.jl (4 changes: 2 additions & 2 deletions)
@@ -13,8 +13,8 @@ function dropblock(rng::AbstractRNG, x::AbstractArray{T, 4}, drop_block_prob, bl
normalize_scale = convert(T, (length(block_mask) / sum(block_mask) .+ 1e-6))
return x .* block_mask .* normalize_scale
end
dropout_mask(rng::CUDA.RNG, x::CuArray, p; kwargs...) = _dropout_mask(rng, x, p; kwargs...)
function dropblock(rng, x::CuArray, p; kwargs...)
dropoutblock(rng::CUDA.RNG, x::CuArray, p, args...) = dropblock(rng, x, p, args...)
function dropblock(rng, x::CuArray, p, args...)
throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)). dropblock only support CUDA.RNG for CuArrays."))
end

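
The two CuArray methods above follow a common RNG-guard dispatch pattern: the CUDA.RNG method forwards to the real implementation, while any other RNG paired with a CuArray gets a descriptive error. A minimal standalone sketch of the same pattern (hypothetical blockdrop, not Metalhead's actual dropblock; assumes CUDA.jl and a functional GPU):

```julia
using CUDA, Random

blockdrop(rng::CUDA.RNG, x::CuArray, p) = x .* (CUDA.rand(size(x)...) .> p)  # GPU path (uses CUDA's default RNG for brevity)
function blockdrop(rng, x::CuArray, p)                                       # CuArray with a non-CUDA RNG: refuse loudly
    throw(ArgumentError("x isa CuArray, but rng isa $(typeof(rng)); only CUDA.RNG is supported."))
end
blockdrop(rng, x, p) = x .* (rand(rng, size(x)...) .> p)                     # CPU fallback

blockdrop(Random.default_rng(), rand(Float32, 4, 4), 0.5)  # exercises the CPU path
```
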
src/utilities.jl (18 changes: 0 additions & 18 deletions)
@@ -9,24 +9,6 @@ function _round_channels(channels, divisor, min_value = divisor)
return (new_channels < 0.9 * channels) ? new_channels + divisor : new_channels
end

"""
addrelu(x, y)
Convenience function for `(x, y) -> @. relu(x + y)`.
Useful as the `connection` argument for [`resnet`](#).
See also [`reluadd`](#).
"""
addrelu(x, y) = @. relu(x + y)

"""
reluadd(x, y)
Convenience function for `(x, y) -> @. relu(x) + relu(y)`.
Useful as the `connection` argument for [`resnet`](#).
See also [`addrelu`](#).
"""
reluadd(x, y) = @. relu(x) + relu(y)

"""
cat_channels(x, y, zs...)
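
Downstream code that still passes these helpers as a connection can inline them; equivalents of the deleted definitions (not part of this commit):

```julia
using Flux: relu

addrelu(x, y) = relu.(x .+ y)          # was `@. relu(x + y)`
reluadd(x, y) = relu.(x) .+ relu.(y)   # was `@. relu(x) + relu(y)`
```
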
test/convnets.jl (57 changes: 41 additions & 16 deletions)
@@ -27,18 +27,39 @@ GC.safepoint()
GC.gc()

@testset "ResNet" begin
@testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
m = ResNet(sz)
@test size(m(x_256)) == (1000, 1)
## TODO: find a way to port pretrained models to the new ResNet API
# Tests for pretrained ResNets
## TODO: find a way to port pretrained models to the new ResNet API
# @testset "ResNet($sz)" for sz in [18, 34, 50, 101, 152]
# if (ResNet, sz) in PRETRAINED_MODELS
# @test acctest(ResNet(sz, pretrain = true))
# else
# @test_throws ArgumentError ResNet(sz, pretrain = true)
# end
@test gradtest(m, x_256)
GC.safepoint()
GC.gc()
# end

@testset "resnet" begin
@testset for block_fn in [Metalhead.basicblock, Metalhead.bottleneck]
layer_list = [
[2, 2, 2, 2],
[3, 4, 6, 3],
[3, 4, 23, 3],
[3, 8, 36, 3]
]
@testset for layers in layer_list
drop_list = [
(drop_rate = 0.1, drop_path_rate = 0.1, drop_block_rate = 0.1),
(drop_rate = 0.5, drop_path_rate = 0.5, drop_block_rate = 0.5),
(drop_rate = 0.9, drop_path_rate = 0.9, drop_block_rate = 0.9),
]
@testset for drop_rates in drop_list
m = Metalhead.resnet(block_fn, layers; drop_rates)
@test size(m(x_224)) == (1000, 1)
@test gradtest(m, x_224)
GC.safepoint()
GC.gc()
end
end
end
end
end


@testset "ResNeXt" begin
@testset for depth in [50, 101, 152]
m = ResNeXt(depth)
@test size(m(x_224)) == (1000, 1)
if ResNeXt in PRETRAINED_MODELS
@test acctest(ResNeXt(depth, pretrain = true))
else
@test_throws ArgumentError ResNeXt(depth, pretrain = true)
@testset for cardinality in [32, 64]
@testset for base_width in [4, 8]
m = ResNeXt(depth; cardinality, base_width)
@test size(m(x_224)) == (1000, 1)
if string("resnext", depth, "_", cardinality, "x", base_width) in PRETRAINED_MODELS
@test acctest(ResNeXt(depth, pretrain = true))
else
@test_throws ArgumentError ResNeXt(depth, pretrain = true)
end
@test gradtest(m, x_224)
GC.safepoint()
GC.gc()
end
end
@test gradtest(m, x_224)
GC.safepoint()
GC.gc()
end
end

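
To run the new test sets locally (standard Pkg workflow; x_224, gradtest, acctest, and PRETRAINED_MODELS are helpers defined in the test harness, not in this diff):

```julia
using Pkg
Pkg.activate(".")  # from the root of a Metalhead.jl checkout
Pkg.test()         # runs test/runtests.jl, which includes these convnet tests
```
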
