Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Jun 22, 2022
1 parent afd6f10 commit 71cba4d
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 166 deletions.
49 changes: 26 additions & 23 deletions src/convnets/inception.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,18 +279,18 @@ function inceptionv4_c()
end

"""
inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
Create an Inceptionv4 model.
([reference](https://arxiv.org/abs/1602.07261))
# Arguments
- `inchannels`: number of input channels.
- `dropout`: rate of dropout in classifier head.
- `drop_rate`: rate of dropout in classifier head.
- `nclasses`: the number of output classes.
"""
function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
function inceptionv4(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
conv_bn((3, 3), 32, 32)...,
conv_bn((3, 3), 32, 64; pad = 1)...,
Expand All @@ -313,12 +313,13 @@ function inceptionv4(; inchannels = 3, dropout = 0.0, nclasses = 1000)
inceptionv4_c(),
inceptionv4_c(),
inceptionv4_c())
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses))
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
Dense(1536, nclasses))
return Chain(body, head)
end

"""
Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
Creates an Inceptionv4 model.
([reference](https://arxiv.org/abs/1602.07261))
Expand All @@ -327,7 +328,7 @@ Creates an Inceptionv4 model.
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `inchannels`: number of input channels.
- `dropout`: rate of dropout in classifier head.
- `drop_rate`: rate of dropout in classifier head.
- `nclasses`: the number of output classes.
!!! warning
Expand All @@ -338,7 +339,7 @@ struct Inceptionv4
layers::Any
end

function Inceptionv4(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
layers = inceptionv4(; inchannels, dropout, nclasses)
pretrain && loadpretrain!(layers, "Inceptionv4")
return Inceptionv4(layers)
Expand Down Expand Up @@ -419,18 +420,18 @@ function block8(scale = 1.0f0; activation = identity)
end

"""
inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
inceptionresnetv2(; inchannels = 3, drop_rate =0.0, nclasses = 1000)
Creates an InceptionResNetv2 model.
([reference](https://arxiv.org/abs/1602.07261))
# Arguments
- `inchannels`: number of input channels.
- `dropout`: rate of dropout in classifier head.
- `drop_rate`: rate of dropout in classifier head.
- `nclasses`: the number of output classes.
"""
function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
function inceptionresnetv2(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2)...,
conv_bn((3, 3), 32, 32)...,
conv_bn((3, 3), 32, 64; pad = 1)...,
Expand All @@ -446,12 +447,13 @@ function inceptionresnetv2(; inchannels = 3, dropout = 0.0, nclasses = 1000)
[block8(0.20f0) for _ in 1:9]...,
block8(; activation = relu),
conv_bn((1, 1), 2080, 1536)...)
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(1536, nclasses))
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
Dense(1536, nclasses))
return Chain(body, head)
end

"""
InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000)
Creates an InceptionResNetv2 model.
([reference](https://arxiv.org/abs/1602.07261))
Expand All @@ -460,7 +462,7 @@ Creates an InceptionResNetv2 model.
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet
- `inchannels`: number of input channels.
- `dropout`: rate of dropout in classifier head.
- `drop_rate`: rate of dropout in classifier head.
- `nclasses`: the number of output classes.
!!! warning
Expand All @@ -471,9 +473,9 @@ struct InceptionResNetv2
layers::Any
end

function InceptionResNetv2(; pretrain = false, inchannels = 3, dropout = 0.0,
function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0,
nclasses = 1000)
layers = inceptionresnetv2(; inchannels, dropout, nclasses)
layers = inceptionresnetv2(; inchannels, drop_rate, nclasses)
pretrain && loadpretrain!(layers, "InceptionResNetv2")
return InceptionResNetv2(layers)
end
Expand Down Expand Up @@ -533,18 +535,18 @@ function xception_block(inchannels, outchannels, nrepeats; stride = 1,
end

"""
xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
xception(; inchannels = 3, drop_rate =0.0, nclasses = 1000)
Creates an Xception model.
([reference](https://arxiv.org/abs/1610.02357))
# Arguments
- `inchannels`: number of input channels.
- `dropout`: rate of dropout in classifier head.
- `drop_rate`: rate of dropout in classifier head.
- `nclasses`: the number of output classes.
"""
function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
function xception(; inchannels = 3, drop_rate = 0.0, nclasses = 1000)
body = Chain(conv_bn((3, 3), inchannels, 32; stride = 2, bias = false)...,
conv_bn((3, 3), 32, 64; bias = false)...,
xception_block(64, 128, 2; stride = 2, start_with_relu = false),
Expand All @@ -554,7 +556,8 @@ function xception(; inchannels = 3, dropout = 0.0, nclasses = 1000)
xception_block(728, 1024, 2; stride = 2, grow_at_start = false),
depthwise_sep_conv_bn((3, 3), 1024, 1536; pad = 1)...,
depthwise_sep_conv_bn((3, 3), 1536, 2048; pad = 1)...)
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(dropout), Dense(2048, nclasses))
head = Chain(GlobalMeanPool(), MLUtils.flatten, Dropout(drop_rate),
Dense(2048, nclasses))
return Chain(body, head)
end

Expand All @@ -563,7 +566,7 @@ struct Xception
end

"""
Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
Xception(; pretrain = false, inchannels = 3, drop_rate =0.0, nclasses = 1000)
Creates an Xception model.
([reference](https://arxiv.org/abs/1610.02357))
Expand All @@ -572,15 +575,15 @@ Creates an Xception model.
- `pretrain`: set to `true` to load the pre-trained weights for ImageNet.
- `inchannels`: number of input channels.
- `dropout`: rate of dropout in classifier head.
- `drop_rate`: rate of dropout in classifier head.
- `nclasses`: the number of output classes.
!!! warning
`Xception` does not currently support pretrained weights.
"""
function Xception(; pretrain = false, inchannels = 3, dropout = 0.0, nclasses = 1000)
layers = xception(; inchannels, dropout, nclasses)
function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
layers = xception(; inchannels, drop_rate, nclasses)
pretrain && loadpretrain!(layers, "xception")
return Xception(layers)
end
Expand Down
Loading

0 comments on commit 71cba4d

Please sign in to comment.