From 4fa28d42ca6345c69d46cd029a579f715b3070b2 Mon Sep 17 00:00:00 2001 From: Abhirath Anand <74202102+theabhirath@users.noreply.github.com> Date: Sat, 25 Jun 2022 15:42:22 +0530 Subject: [PATCH] Make pretrain condition explicit --- src/convnets/alexnet.jl | 4 +++- src/convnets/densenet.jl | 4 +++- src/convnets/googlenet.jl | 4 +++- src/convnets/inception.jl | 16 ++++++++++++---- src/convnets/mobilenet.jl | 11 +++++++++-- src/convnets/squeezenet.jl | 4 +++- 6 files changed, 33 insertions(+), 10 deletions(-) diff --git a/src/convnets/alexnet.jl b/src/convnets/alexnet.jl index 405272dd2..87f2c288e 100644 --- a/src/convnets/alexnet.jl +++ b/src/convnets/alexnet.jl @@ -49,7 +49,9 @@ end function AlexNet(; pretrain = false, nclasses = 1000) layers = alexnet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "AlexNet") + if pretrain + loadpretrain!(layers, "AlexNet") + end return AlexNet(layers) end diff --git a/src/convnets/densenet.jl b/src/convnets/densenet.jl index 374909bb1..9da4e08b2 100644 --- a/src/convnets/densenet.jl +++ b/src/convnets/densenet.jl @@ -162,6 +162,8 @@ See also [`Metalhead.densenet`](#). function DenseNet(config::Integer = 121; pretrain = false, nclasses = 1000) @assert config in keys(densenet_config) "`config` must be one out of $(sort(collect(keys(densenet_config))))." model = DenseNet(densenet_config[config]; nclasses = nclasses) - pretrain && loadpretrain!(model, string("DenseNet", config)) + if pretrain + loadpretrain!(model, string("DenseNet", config)) + end return model end diff --git a/src/convnets/googlenet.jl b/src/convnets/googlenet.jl index 318463494..946d0d7f7 100644 --- a/src/convnets/googlenet.jl +++ b/src/convnets/googlenet.jl @@ -86,7 +86,9 @@ end function GoogLeNet(; pretrain = false, nclasses = 1000) layers = googlenet(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "GoogLeNet") + if pretrain + loadpretrain!(layers, "GoogLeNet") + end return GoogLeNet(layers) end diff --git a/src/convnets/inception.jl b/src/convnets/inception.jl index ead229551..ba30fa86f 100644 --- a/src/convnets/inception.jl +++ b/src/convnets/inception.jl @@ -182,7 +182,9 @@ end function Inceptionv3(; pretrain = false, nclasses = 1000) layers = inceptionv3(; nclasses = nclasses) - pretrain && loadpretrain!(layers, "Inceptionv3") + if pretrain + loadpretrain!(layers, "Inceptionv3") + end return Inceptionv3(layers) end @@ -341,7 +343,9 @@ end function Inceptionv4(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionv4(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "Inceptionv4") + if pretrain + loadpretrain!(layers, "Inceptionv4") + end return Inceptionv4(layers) end @@ -476,7 +480,9 @@ end function InceptionResNetv2(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = inceptionresnetv2(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "InceptionResNetv2") + if pretrain + loadpretrain!(layers, "InceptionResNetv2") + end return InceptionResNetv2(layers) end @@ -584,7 +590,9 @@ Creates an Xception model. """ function Xception(; pretrain = false, inchannels = 3, drop_rate = 0.0, nclasses = 1000) layers = xception(; inchannels, drop_rate, nclasses) - pretrain && loadpretrain!(layers, "xception") + if pretrain + loadpretrain!(layers, "xception") + end return Xception(layers) end diff --git a/src/convnets/mobilenet.jl b/src/convnets/mobilenet.jl index 93eba1c06..b7dfcd6f3 100644 --- a/src/convnets/mobilenet.jl +++ b/src/convnets/mobilenet.jl @@ -90,7 +90,9 @@ end function MobileNetv1(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) layers = mobilenetv1(width_mult, mobilenetv1_configs; inchannels, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv1")) + if pretrain + loadpretrain!(layers, string("MobileNetv1")) + end return MobileNetv1(layers) end @@ -189,6 +191,9 @@ function MobileNetv2(width_mult::Number = 1; inchannels = 3, pretrain = false, nclasses = 1000) layers = mobilenetv2(width_mult, mobilenetv2_configs; inchannels, nclasses) pretrain && loadpretrain!(layers, string("MobileNetv2")) + if pretrain + loadpretrain!(layers, string("MobileNetv2")) + end return MobileNetv2(layers) end @@ -319,7 +324,9 @@ function MobileNetv3(mode::Symbol = :small, width_mult::Number = 1; inchannels = max_width = (mode == :large) ? 1280 : 1024 layers = mobilenetv3(width_mult, mobilenetv3_configs[mode]; inchannels, max_width, nclasses) - pretrain && loadpretrain!(layers, string("MobileNetv3", mode)) + if pretrain + loadpretrain!(layers, string("MobileNetv3", mode)) + end return MobileNetv3(layers) end diff --git a/src/convnets/squeezenet.jl b/src/convnets/squeezenet.jl index c4de36acc..df458f9ff 100644 --- a/src/convnets/squeezenet.jl +++ b/src/convnets/squeezenet.jl @@ -68,7 +68,9 @@ end function SqueezeNet(; pretrain = false) layers = squeezenet() - pretrain && loadpretrain!(layers, "SqueezeNet") + if pretrain + loadpretrain!(layers, "SqueezeNet") + end return SqueezeNet(layers) end