Skip to content

Commit

Permalink
More declarative interface for ResNet
Browse files Browse the repository at this point in the history
1. Less keywords for the user to worry about
2. Delete `ResNeXt` just for now
  • Loading branch information
theabhirath committed Jun 28, 2022
1 parent 4fa28d4 commit 7846f8b
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 294 deletions.
111 changes: 57 additions & 54 deletions src/convnets/efficientnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,21 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
# Arguments
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
- `n`: number of block repetitions (will be scaled by global depth scaling)
- `k`: kernel size
- `s`: kernel stride
- `e`: expansion ratio
- `i`: block input channels (will be scaled by global width scaling)
- `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
+ `n`: number of block repetitions (will be scaled by global depth scaling)
+ `k`: kernel size
+ `s`: kernel stride
+ `e`: expansion ratio
+ `i`: block input channels (will be scaled by global width scaling)
+ `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
"""
function efficientnet(scalings, block_config;
inchannels = 3, nclasses = 1000, max_width = 1280)
Expand Down Expand Up @@ -64,34 +66,33 @@ end
# i: block input channels
# o: block output channels
const efficientnet_block_configs = [
# (n, k, s, e, i, o)
(1, 3, 1, 1, 32, 16),
(2, 3, 2, 6, 16, 24),
(2, 5, 2, 6, 24, 40),
(3, 3, 2, 6, 40, 80),
(3, 5, 1, 6, 80, 112),
# (n, k, s, e, i, o)
(1, 3, 1, 1, 32, 16),
(2, 3, 2, 6, 16, 24),
(2, 5, 2, 6, 24, 40),
(3, 3, 2, 6, 40, 80),
(3, 5, 1, 6, 80, 112),
(4, 5, 2, 6, 112, 192),
(1, 3, 1, 6, 192, 320)
(1, 3, 1, 6, 192, 320),
]

# w: width scaling
# d: depth scaling
# r: image resolution
const efficientnet_global_configs = Dict(
# ( r, ( w, d))
:b0 => (224, (1.0, 1.0)),
:b1 => (240, (1.0, 1.1)),
:b2 => (260, (1.1, 1.2)),
:b3 => (300, (1.2, 1.4)),
:b4 => (380, (1.4, 1.8)),
:b5 => (456, (1.6, 2.2)),
:b6 => (528, (1.8, 2.6)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6))
)
# (r, (w, d))
:b0 => (224, (1.0, 1.0)),
:b1 => (240, (1.0, 1.1)),
:b2 => (260, (1.1, 1.2)),
:b3 => (300, (1.2, 1.4)),
:b4 => (380, (1.4, 1.8)),
:b5 => (456, (1.6, 2.2)),
:b6 => (528, (1.8, 2.6)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6)))

struct EfficientNet
layers::Any
layers::Any
end

"""
Expand All @@ -103,27 +104,29 @@ See also [`efficientnet`](#).
# Arguments
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
- `n`: number of block repetitions (will be scaled by global depth scaling)
- `k`: kernel size
- `s`: kernel stride
- `e`: expansion ratio
- `i`: block input channels (will be scaled by global width scaling)
- `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
+ `n`: number of block repetitions (will be scaled by global depth scaling)
+ `k`: kernel size
+ `s`: kernel stride
+ `e`: expansion ratio
+ `i`: block input channels (will be scaled by global width scaling)
+ `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
"""
function EfficientNet(scalings, block_config;
inchannels = 3, nclasses = 1000, max_width = 1280)
layers = efficientnet(scalings, block_config;
inchannels = inchannels,
nclasses = nclasses,
max_width = max_width)
return EfficientNet(layers)
layers = efficientnet(scalings, block_config;
inchannels = inchannels,
nclasses = nclasses,
max_width = max_width)
return EfficientNet(layers)
end

@functor EfficientNet
Expand All @@ -141,13 +144,13 @@ See also [`efficientnet`](#).
# Arguments
- `name`: name of default configuration
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `name`: name of default configuration
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
"""
function EfficientNet(name::Symbol; pretrain = false)
@assert name in keys(efficientnet_global_configs)
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"

model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
pretrain && loadpretrain!(model, string("efficientnet-", name))
Expand Down
2 changes: 1 addition & 1 deletion src/convnets/mobilenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
function mobilenetv1(width_mult, config;
activation = relu,
inchannels = 3,
fcsize = 1024,
fcsize = 1024,
nclasses = 1000)
layers = []
for (dw, outch, stride, nrepeats) in config
Expand Down
Loading

0 comments on commit 7846f8b

Please sign in to comment.