More declarative interface for ResNet
Fewer keywords for the user to worry about
theabhirath committed Jun 28, 2022
1 parent 4fa28d4 commit 0ef7496
Showing 3 changed files with 136 additions and 280 deletions.
111 changes: 57 additions & 54 deletions src/convnets/efficientnet.jl
@@ -6,19 +6,21 @@ Create an EfficientNet model ([reference](https://arxiv.org/abs/1905.11946v5)).
# Arguments
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
- `n`: number of block repetitions (will be scaled by global depth scaling)
- `k`: kernel size
- `s`: kernel stride
- `e`: expansion ratio
- `i`: block input channels (will be scaled by global width scaling)
- `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
+ `n`: number of block repetitions (will be scaled by global depth scaling)
+ `k`: kernel size
+ `s`: kernel stride
+ `e`: expansion ratio
+ `i`: block input channels (will be scaled by global width scaling)
+ `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
"""
function efficientnet(scalings, block_config;
inchannels = 3, nclasses = 1000, max_width = 1280)
@@ -64,34 +66,33 @@ end
# i: block input channels
# o: block output channels
const efficientnet_block_configs = [
# (n, k, s, e, i, o)
(1, 3, 1, 1, 32, 16),
(2, 3, 2, 6, 16, 24),
(2, 5, 2, 6, 24, 40),
(3, 3, 2, 6, 40, 80),
(3, 5, 1, 6, 80, 112),
# (n, k, s, e, i, o)
(1, 3, 1, 1, 32, 16),
(2, 3, 2, 6, 16, 24),
(2, 5, 2, 6, 24, 40),
(3, 3, 2, 6, 40, 80),
(3, 5, 1, 6, 80, 112),
(4, 5, 2, 6, 112, 192),
(1, 3, 1, 6, 192, 320)
(1, 3, 1, 6, 192, 320),
]
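A minimal usage sketch (not part of this diff) of how this base configuration might be passed to `efficientnet` as documented above; the `(1.0, 1.0)` scalings, which leave widths and depths unchanged, are an illustrative assumption corresponding to the B0 row of the table below:

```julia
# Sketch: build an unscaled (B0-like) EfficientNet from the base block
# configuration above. Keywords shown are the documented defaults.
model = efficientnet((1.0, 1.0), efficientnet_block_configs;
                     inchannels = 3, nclasses = 1000)
```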

# w: width scaling
# d: depth scaling
# r: image resolution
const efficientnet_global_configs = Dict(
# ( r, ( w, d))
:b0 => (224, (1.0, 1.0)),
:b1 => (240, (1.0, 1.1)),
:b2 => (260, (1.1, 1.2)),
:b3 => (300, (1.2, 1.4)),
:b4 => (380, (1.4, 1.8)),
:b5 => (456, (1.6, 2.2)),
:b6 => (528, (1.8, 2.6)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6))
)
# (r, (w, d))
:b0 => (224, (1.0, 1.0)),
:b1 => (240, (1.0, 1.1)),
:b2 => (260, (1.1, 1.2)),
:b3 => (300, (1.2, 1.4)),
:b4 => (380, (1.4, 1.8)),
:b5 => (456, (1.6, 2.2)),
:b6 => (528, (1.8, 2.6)),
:b7 => (600, (2.0, 3.1)),
:b8 => (672, (2.2, 3.6)))
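As a usage sketch (assumed, not shown in this diff), an entry of this table unpacks into the expected input resolution and the `(width, depth)` scalings that `efficientnet` consumes:

```julia
# Sketch: look up the B4 entry and combine its scalings with the base block
# configuration. The resolution is informational (expected input size).
resolution, scalings = efficientnet_global_configs[:b4]  # (380, (1.4, 1.8))
model = efficientnet(scalings, efficientnet_block_configs)
```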

struct EfficientNet
layers::Any
layers::Any
end

"""
@@ -103,27 +104,29 @@ See also [`efficientnet`](#).
# Arguments
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
- `n`: number of block repetitions (will be scaled by global depth scaling)
- `k`: kernel size
- `s`: kernel stride
- `e`: expansion ratio
- `i`: block input channels (will be scaled by global width scaling)
- `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
- `scalings`: global width and depth scaling (given as a tuple)
- `block_config`: configuration for each inverted residual block,
given as a vector of tuples with elements:
+ `n`: number of block repetitions (will be scaled by global depth scaling)
+ `k`: kernel size
+ `s`: kernel stride
+ `e`: expansion ratio
+ `i`: block input channels (will be scaled by global width scaling)
+ `o`: block output channels (will be scaled by global width scaling)
- `inchannels`: number of input channels
- `nclasses`: number of output classes
- `max_width`: maximum number of output channels before the fully connected
classification blocks
"""
function EfficientNet(scalings, block_config;
inchannels = 3, nclasses = 1000, max_width = 1280)
layers = efficientnet(scalings, block_config;
inchannels = inchannels,
nclasses = nclasses,
max_width = max_width)
return EfficientNet(layers)
layers = efficientnet(scalings, block_config;
inchannels = inchannels,
nclasses = nclasses,
max_width = max_width)
return EfficientNet(layers)
end
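A hedged sketch of the wrapper constructor with non-default keywords; the single-channel, 10-class setting is an illustrative assumption, not a library default:

```julia
# Sketch: wrap the functional builder in the EfficientNet struct, here for
# grayscale inputs and a 10-class head (illustrative values only).
m = EfficientNet((1.0, 1.0), efficientnet_block_configs;
                 inchannels = 1, nclasses = 10)
```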

@functor EfficientNet
@@ -141,13 +144,13 @@ See also [`efficientnet`](#).
# Arguments
- `name`: name of default configuration
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `name`: name of default configuration
(can be `:b0`, `:b1`, `:b2`, `:b3`, `:b4`, `:b5`, `:b6`, `:b7`, `:b8`)
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
"""
function EfficientNet(name::Symbol; pretrain = false)
@assert name in keys(efficientnet_global_configs)
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"
"`name` must be one of $(sort(collect(keys(efficientnet_global_configs))))"

model = EfficientNet(efficientnet_global_configs[name][2], efficientnet_block_configs)
pretrain && loadpretrain!(model, string("efficientnet-", name))
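For completeness, a sketch of the named constructor defined here; whether pre-trained weights are actually published for a given configuration is not established by this diff:

```julia
# Sketch: construct a B0 model from the named configuration table, and
# (assuming ImageNet weights are available) a pre-trained variant.
m = EfficientNet(:b0)
m_pretrained = EfficientNet(:b0; pretrain = true)
```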
2 changes: 1 addition & 1 deletion src/convnets/mobilenet.jl
@@ -28,7 +28,7 @@ Create a MobileNetv1 model ([reference](https://arxiv.org/abs/1704.04861v1)).
function mobilenetv1(width_mult, config;
activation = relu,
inchannels = 3,
fcsize = 1024,
fcsize = 1024,
nclasses = 1000)
layers = []
for (dw, outch, stride, nrepeats) in config
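A hedged sketch of how `mobilenetv1` might be called given the loop above; the config rows and the reading of `dw` as a depthwise-separable flag are assumptions inferred from the destructuring `(dw, outch, stride, nrepeats)`, not spelled out in this hunk:

```julia
# Sketch: each row is (dw, outch, stride, nrepeats); `dw` is assumed to select
# a depthwise-separable block. Values below are illustrative, not the
# library's default MobileNetv1 configuration.
config = [
    (false, 64, 1, 1),
    (true, 128, 2, 2),
]
model = mobilenetv1(1.0, config; inchannels = 3, fcsize = 1024, nclasses = 1000)
```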
