Skip to content

Commit

Permalink
Make pretrain condition explicit
Browse files Browse the repository at this point in the history
  • Loading branch information
theabhirath committed Jun 27, 2022
1 parent de079bc commit 4fa28d4
Show file tree
Hide file tree
Showing 6 changed files with 33 additions and 10 deletions.
4 changes: 3 additions & 1 deletion src/convnets/alexnet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@ end

function AlexNet(; pretrain = false, nclasses = 1000)
layers = alexnet(; nclasses = nclasses)
pretrain && loadpretrain!(layers, "AlexNet")
if pretrain
loadpretrain!(layers, "AlexNet")
end
return AlexNet(layers)
end

Expand Down
4 changes: 3 additions & 1 deletion src/convnets/densenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ See also [`Metalhead.densenet`](#).
function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000)
@assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))."
model = DenseNet(densenet_config[config]; nclasses = nclasses)
pretrain && loadpretrain!(model, string("DenseNet", config))
if pretrain
loadpretrain!(model, string("DenseNet", config))
end
return model
end
4 changes: 3 additions & 1 deletion src/convnets/googlenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,9 @@ end

function GoogLeNet(; pretrain = false, nclasses = 1000)
layers = googlenet(; nclasses = nclasses)
pretrain && loadpretrain!(layers, "GoogLeNet")
if pretrain
loadpretrain!(layers, "GoogLeNet")
end
return GoogLeNet(layers)
end

Expand Down
16 changes: 12 additions & 4 deletions src/convnets/inception.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,9 @@ end

function Inceptionv3(; pretrain = false, nclasses = 1000)
layers = inceptionv3(; nclasses = nclasses)
pretrain && loadpretrain!(layers, "Inceptionv3")
if pretrain
loadpretrain!(layers, "Inceptionv3")
end
return Inceptionv3(layers)
end

Expand Down Expand Up @@ -341,7 +343,9 @@ end

function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
layers = inceptionv4(; inchannels, drop_rate, nclasses)
pretrain && loadpretrain!(layers, "Inceptionv4")
if pretrain
loadpretrain!(layers, "Inceptionv4")
end
return Inceptionv4(layers)
end

Expand Down Expand Up @@ -476,7 +480,9 @@ end
function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0,
nclasses = 1000)
layers = inceptionresnetv2(; inchannels, drop_rate, nclasses)
pretrain && loadpretrain!(layers, "InceptionResNetv2")
if pretrain
loadpretrain!(layers, "InceptionResNetv2")
end
return InceptionResNetv2(layers)
end

Expand Down Expand Up @@ -584,7 +590,9 @@ Creates an Xception model.
"""
function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000)
layers = xception(; inchannels, drop_rate, nclasses)
pretrain && loadpretrain!(layers, "xception")
if pretrain
loadpretrain!(layers, "xception")
end
return Xception(layers)
end

Expand Down
11 changes: 9 additions & 2 deletions src/convnets/mobilenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ end
function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false,
nclasses = 1000)
layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses)
pretrain && loadpretrain!(layers, string("MobileNetv1"))
if pretrain
loadpretrain!(layers, string("MobileNetv1"))
end
return MobileNetv1(layers)
end

Expand Down Expand Up @@ -189,6 +191,9 @@ function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false,
nclasses = 1000)
layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses)
pretrain && loadpretrain!(layers, string("MobileNetv2"))
if pretrain
loadpretrain!(layers, string("MobileNetv2"))
end
return MobileNetv2(layers)
end

Expand Down Expand Up @@ -319,7 +324,9 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels =
max_width = (mode == :large) ? 1280 : 1024
layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width,
nclasses)
pretrain && loadpretrain!(layers, string("MobileNetv3", mode))
if pretrain
loadpretrain!(layers, string("MobileNetv3", mode))
end
return MobileNetv3(layers)
end

Expand Down
4 changes: 3 additions & 1 deletion src/convnets/squeezenet.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ end

function SqueezeNet(; pretrain = false)
layers = squeezenet()
pretrain && loadpretrain!(layers, "SqueezeNet")
if pretrain
loadpretrain!(layers, "SqueezeNet")
end
return SqueezeNet(layers)
end

Expand Down

0 comments on commit 4fa28d4

Please sign in to comment.