diff --git a/Datasets/CIFAR100/CIFAR100.swift b/Datasets/CIFAR100/CIFAR100.swift
new file mode 100644
index 00000000000..3daf2301112
--- /dev/null
+++ b/Datasets/CIFAR100/CIFAR100.swift
@@ -0,0 +1,173 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "The CIFAR-100 dataset"
+// Alex Krizhevsky, Vinod Nair, and Geoffrey Hinton.
+// https://www.cs.toronto.edu/~kriz/cifar.html
+
+import Foundation
+import ModelSupport
+import TensorFlow
+
+public struct CIFAR100<Entropy: RandomNumberGenerator> {
+  /// Type of the collection of non-collated batches.
+  public typealias Batches = Slices<Sampling<[(data: [UInt8], label: Int32)], ArraySlice<Int>>>
+  /// The type of the training data, represented as a sequence of epochs, which
+  /// are collections of batches.
+  public typealias Training = LazyMapSequence<
+    TrainingEpochs<[(data: [UInt8], label: Int32)], Entropy>,
+    LazyMapSequence<Batches, LabeledImage>
+  >
+  /// The type of the validation data, represented as a collection of batches.
+  public typealias Validation = LazyMapSequence<Slices<[(data: [UInt8], label: Int32)]>, LabeledImage>
+  /// The training epochs.
+  public let training: Training
+  /// The validation batches.
+  public let validation: Validation
+
+  /// Creates an instance with `batchSize`.
+  ///
+  /// - Parameter entropy: a source of randomness used to shuffle sample
+  ///   ordering. It will be stored in `self`, so if it is only pseudorandom
+  ///   and has value semantics, the sequence of epochs is deterministic and not
+  ///   dependent on other operations.
+  public init(batchSize: Int, entropy: Entropy) {
+    self.init(
+      batchSize: batchSize,
+      entropy: entropy,
+      device: Device.default,
+      remoteBinaryArchiveLocation: URL(
+        string: "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz")!,
+      normalizing: true)
+  }
+
+  /// Creates an instance with `batchSize` on `device` using `remoteBinaryArchiveLocation`.
+  ///
+  /// - Parameters:
+  ///   - entropy: a source of randomness used to shuffle sample ordering. It
+  ///     will be stored in `self`, so if it is only pseudorandom and has value
+  ///     semantics, the sequence of epochs is deterministic and not dependent
+  ///     on other operations.
+  ///   - normalizing: normalizes the batches with the mean and standard deviation
+  ///     of the dataset iff `true`. Default value is `true`.
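+  ///
+  /// A minimal usage sketch (illustrative only), going through the convenience
+  /// initializer declared in the extension below:
+  ///
+  ///     let dataset = CIFAR100(batchSize: 128, on: Device.default)
+  ///     for epochBatches in dataset.training.prefix(1) {
+  ///       for batch in epochBatches {
+  ///         // batch.data has shape [batchSize, 32, 32, 3]; batch.label has shape [batchSize].
+  ///       }
+  ///     }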
+  public init(
+    batchSize: Int,
+    entropy: Entropy,
+    device: Device,
+    remoteBinaryArchiveLocation: URL,
+    localStorageDirectory: URL = DatasetUtilities.defaultDirectory
+      .appendingPathComponent("CIFAR100", isDirectory: true),
+    normalizing: Bool
+  ) {
+    downloadCIFAR100IfNotPresent(from: remoteBinaryArchiveLocation, to: localStorageDirectory)
+
+    // Training data
+    let trainingSamples = loadCIFAR100TrainingFiles(in: localStorageDirectory)
+    training = TrainingEpochs(samples: trainingSamples, batchSize: batchSize, entropy: entropy)
+      .lazy.map { (batches: Batches) -> LazyMapSequence<Batches, LabeledImage> in
+        return batches.lazy.map { makeBatch(samples: $0, normalizing: normalizing, device: device) }
+      }
+
+    // Validation data
+    let validationSamples = loadCIFAR100TestFiles(in: localStorageDirectory)
+    validation = validationSamples.inBatches(of: batchSize).lazy.map {
+      makeBatch(samples: $0, normalizing: normalizing, device: device)
+    }
+  }
+}
+
+extension CIFAR100: ImageClassificationData where Entropy == SystemRandomNumberGenerator {
+  /// Creates an instance with `batchSize`.
+  public init(batchSize: Int, on: Device) {
+    self.init(batchSize: batchSize, entropy: SystemRandomNumberGenerator())
+  }
+}
+
+func downloadCIFAR100IfNotPresent(from location: URL, to directory: URL) {
+  let downloadPath = directory.path
+  let directoryExists = FileManager.default.fileExists(atPath: downloadPath)
+  let contentsOfDir = try? FileManager.default.contentsOfDirectory(atPath: downloadPath)
+  let directoryEmpty = (contentsOfDir == nil) || (contentsOfDir!.isEmpty)
+
+  guard !directoryExists || directoryEmpty else { return }
+
+  let _ = DatasetUtilities.downloadResource(
+    filename: "cifar-100-binary", fileExtension: "tar.gz",
+    remoteRoot: location.deletingLastPathComponent(), localStorageDirectory: directory)
+}
+
+func loadCIFAR100File(named name: String, in directory: URL) -> [(data: [UInt8], label: Int32)] {
+  let path = directory.appendingPathComponent("cifar-100-binary/\(name)").path
+
+  var imageCount = 50000
+  guard let fileContents = try? Data(contentsOf: URL(fileURLWithPath: path)) else {
+    printError("Could not read dataset file: \(name)")
+    exit(-1)
+  }
+  if name.contains("test") {
+    guard fileContents.count == 30_740_000 else {
+      printError(
+        "Dataset file \(name) should have 30_740_000 bytes, instead had \(fileContents.count)")
+      exit(-1)
+    }
+    imageCount = 10000
+  } else {
+    guard fileContents.count == 153_700_000 else {
+      printError(
+        "Dataset file \(name) should have 153_700_000 bytes, instead had \(fileContents.count)")
+      exit(-1)
+    }
+  }
+
+  var labeledImages: [(data: [UInt8], label: Int32)] = []
+
+  let imageByteSize = 3074
+  for imageIndex in 0..<imageCount {
+    // Each 3074-byte record is [coarse label, fine label, 3072 image bytes];
+    // the fine label (100 classes) is the one used here.
+    let baseAddress = imageIndex * imageByteSize
+    let label = Int32(fileContents[baseAddress + 1])
+    let data = [UInt8](fileContents[(baseAddress + 2)..<(baseAddress + imageByteSize)])
+    labeledImages.append((data: data, label: label))
+  }
+
+  return labeledImages
+}
+
+func loadCIFAR100TrainingFiles(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] {
+  return loadCIFAR100File(named: "train.bin", in: localStorageDirectory)
+}
+
+func loadCIFAR100TestFiles(in localStorageDirectory: URL) -> [(data: [UInt8], label: Int32)] {
+  return loadCIFAR100File(named: "test.bin", in: localStorageDirectory)
+}
+
+fileprivate func makeBatch<BatchSamples: Collection>(
+  samples: BatchSamples, normalizing: Bool, device: Device
+) -> LabeledImage where BatchSamples.Element == (data: [UInt8], label: Int32) {
+  let bytes = samples.lazy.map(\.data).reduce(into: [], +=)
+  let images = Tensor<UInt8>(shape: [samples.count, 3, 32, 32], scalars: bytes, on: device)
+
+  var imageTensor = Tensor<Float>(images.transposed(permutation: [0, 2, 3, 1]))
+  imageTensor /= 255.0
+  if normalizing {
+    let mean = Tensor<Float>([0.5071, 0.4867, 0.4408], on: device)
+    let std = Tensor<Float>([0.2675, 0.2565, 0.2761], on: device)
+    imageTensor = (imageTensor - mean) / std
+  }
+
+  let labels = Tensor<Int32>(samples.map(\.label))
+  return LabeledImage(data: imageTensor, label: labels)
+}
\ No newline at end of file
diff --git a/Datasets/CMakeLists.txt b/Datasets/CMakeLists.txt
index 87670dc1325..1fa30328ab3 100644
--- a/Datasets/CMakeLists.txt
+++ b/Datasets/CMakeLists.txt
@@ -1,5 +1,6 @@
 add_library(Datasets
   CIFAR10/CIFAR10.swift
+  CIFAR100/CIFAR100.swift
   DatasetUtilities.swift
   COCO/COCO.swift
   COCO/COCODataset.swift
diff --git a/Examples/BigTransfer-CIFAR100/CMakeLists.txt b/Examples/BigTransfer-CIFAR100/CMakeLists.txt
new file mode 100644
index 00000000000..49ac8b53ea0
--- /dev/null
+++ b/Examples/BigTransfer-CIFAR100/CMakeLists.txt
@@ -0,0 +1,9 @@
+add_executable(BigTransfer-CIFAR100
+  main.swift)
+target_link_libraries(BigTransfer-CIFAR100 PRIVATE
+  Datasets
+  ImageClassificationModels)
+
+
+install(TARGETS BigTransfer-CIFAR100
+  DESTINATION bin)
diff --git a/Examples/BigTransfer-CIFAR100/README.md b/Examples/BigTransfer-CIFAR100/README.md
new file mode 100644
index 00000000000..3568b2696d2
--- /dev/null
+++ b/Examples/BigTransfer-CIFAR100/README.md
@@ -0,0 +1,18 @@
+# Big Transfer with CIFAR-100
+
+This example demonstrates how to train Big Transfer (https://arxiv.org/abs/1912.11370) against the [CIFAR-100 image classification dataset](https://www.cs.toronto.edu/~kriz/cifar.html).
+
+It uses a pre-defined rule based on dataset size (the BiT-HyperRule) to choose the fine-tuning hyperparameters for a modified ResNet-v2 transfer learning model.
+
+## Setup
+
+To begin, you'll need the [latest version of Swift for
+TensorFlow](https://github.com/tensorflow/swift/blob/main/Installation.md)
+installed. Make sure you've added the correct version of `swift` to your path.
+
+To train the model, run:
+
+```sh
+cd swift-models
+swift run BigTransfer-CIFAR100
+```
diff --git a/Examples/BigTransfer-CIFAR100/main.swift b/Examples/BigTransfer-CIFAR100/main.swift
new file mode 100644
index 00000000000..89fad5813f9
--- /dev/null
+++ b/Examples/BigTransfer-CIFAR100/main.swift
@@ -0,0 +1,260 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Original source:
+// "Big Transfer (BiT): General Visual Representation Learning"
+// Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby.
+// https://arxiv.org/abs/1912.11370
+
+import Datasets
+import ImageClassificationModels
+import TensorFlow
+import Foundation
+import PythonKit
+
+// let tf = Python.import("tensorflow")
+let np = Python.import("numpy")
+
+// Optional: use an XLA device to enable GPU-accelerated training.
+// let _ = _ExecutionContext.global
+// let device = Device.defaultXLA
+let device = Device.default
+let modelName = "BiT-M-R50x1"
+var knownModels = [String: String]()
+let knownDatasetSizes: [String: (Int, Int)] = [
+  "cifar10": (32, 32),
+  "cifar100": (32, 32),
+  "oxford_iiit_pet": (224, 224),
+  "oxford_flowers102": (224, 224),
+  "imagenet2012": (224, 224),
+]
+
+let cifar100TrainingSize = 50000
+let batchSize = 128
+
+/// This error indicates that the BiT-HyperRule helpers cannot find the name of the dataset in the
+/// knownDatasetSizes dictionary.
+enum DatasetNotFoundError: Error {
+  case invalidInput(String)
+}
+
+/// Returns the relevant ResNet enumerated type based on the loaded weights.
+///
+/// - Parameters:
+///   - modelName: the name of the model pulled from the Big Transfer repository
+///     for which to look up the enumerated type
+/// - Returns: ResNet enumerated type for the BigTransfer model
+func getModelUnits(modelName: String) -> BigTransfer.Depth {
+  if modelName.contains("R50") {
+    return .resNet50
+  } else if modelName.contains("R101") {
+    return .resNet101
+  } else {
+    return .resNet152
+  }
+}
+
+/// Gets the updated image resolution based on the specifications in the BiT-HyperRule.
+///
+/// - Parameters:
+///   - originalResolution: the source resolution for the current image dataset
+/// - Returns: new resolution for images based on the BiT-HyperRule
+func getResolution(originalResolution: (Int, Int)) -> (Int, Int) {
+  let area = originalResolution.0 * originalResolution.1
+  return area < 96*96 ? (160, 128) : (512, 480)
+}
+
+/// Gets the source resolution for the current image dataset from the knownDatasetSizes dictionary.
+///
+/// - Parameters:
+///   - datasetName: name of the current dataset you are using
+/// - Returns: new resolution for the specified dataset
+/// - Throws:
+///   - DatasetNotFoundError: if the dataset cannot be found in the knownDatasetSizes dictionary
+func getResolutionFromDataset(datasetName: String) throws -> (Int, Int) {
+  if let resolution = knownDatasetSizes[datasetName] {
+    return getResolution(originalResolution: resolution)
+  }
+  print("Unsupported dataset " + datasetName + ". Add your own here :)")
+  throw DatasetNotFoundError.invalidInput(datasetName)
+}
+
+/// Gets the training mixup parameter based on the BiT-HyperRule specification for dataset sizes.
+///
+/// - Parameters:
+///   - datasetSize: number of images in the current dataset
+/// - Returns: mixup alpha based on the number of images
+func getMixUp(datasetSize: Int) -> Double {
+  return datasetSize < 20000 ? 0.0 : 0.1
+}
+
+/// Gets the learning rate schedule based on the dataset size.
+///
+/// - Parameters:
+///   - datasetSize: number of images in the current dataset
+/// - Returns: learning rate schedule based on the current dataset
+func getSchedule(datasetSize: Int) -> Array<Int> {
+  if datasetSize < 20000 {
+    return [100, 200, 300, 400, 500]
+  } else if datasetSize < 500000 {
+    return [500, 3000, 6000, 9000, 10000]
+  } else {
+    return [500, 6000, 12000, 18000, 20000]
+  }
+}
+
+/// Gets the learning rate at the current step given the dataset size and base learning rate.
+///
+/// - Parameters:
+///   - step: current training step
+///   - datasetSize: number of images in the dataset
+///   - baseLearningRate: starting learning rate to modify
+/// - Returns: learning rate at the current step in training
+func getLearningRate(step: Int, datasetSize: Int, baseLearningRate: Float = 0.003) -> Float? {
+  let supports = getSchedule(datasetSize: datasetSize)
+  // Linear warmup
+  if step < supports[0] {
+    return baseLearningRate * Float(step) / Float(supports[0])
+  }
+  // End of training
+  else if step >= supports.last! {
+    return nil
+  }
+  // Staircase decays by factor of 10
+  else {
+    var baseLearningRate = baseLearningRate
+    for s in supports[1...] {
+      if s < step {
+        baseLearningRate = baseLearningRate / 10.0
+      }
+    }
+    return baseLearningRate
+  }
+}
+
+/// Stores the training statistics for the BigTransfer training process, which differ from the usual
+/// ones because the mixed-up labels must be accounted for when computing training statistics.
+struct BigTransferTrainingStatistics {
+  var correctGuessCount = Tensor<Int32>(0, on: Device.default)
+  var totalGuessCount = Tensor<Int32>(0, on: Device.default)
+  var totalLoss = Tensor<Float>(0, on: Device.default)
+  var batches: Int = 0
+  var accuracy: Float {
+    Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100
+  }
+  var averageLoss: Float { totalLoss.scalarized() / Float(batches) }
+
+  init(on device: Device = Device.default) {
+    correctGuessCount = Tensor<Int32>(0, on: device)
+    totalGuessCount = Tensor<Int32>(0, on: device)
+    totalLoss = Tensor<Float>(0, on: device)
+  }
+
+  mutating func update(logits: Tensor<Float>, labels: Tensor<Float>, loss: Tensor<Float>) {
+    let correct = logits.argmax(squeezingAxis: 1) .== labels.argmax(squeezingAxis: 1)
+    correctGuessCount += Tensor<Int32>(correct).sum()
+    totalGuessCount += Int32(labels.shape[0])
+    totalLoss += loss
+    batches += 1
+  }
+}
+
+let classCount = 100
+var bitModel = BigTransfer(classCount: classCount, depth: getModelUnits(modelName: modelName), modelName: modelName)
+let dataset = CIFAR100(batchSize: batchSize, on: Device.default)
+var optimizer = SGD(for: bitModel, learningRate: 0.003, momentum: 0.9)
+optimizer = SGD(copying: optimizer, to: device)
+
+print("Beginning training...")
+var currStep: Int = 1
+let lrSupports = getSchedule(datasetSize: cifar100TrainingSize)
+let scheduleLength = lrSupports.last!
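+// For CIFAR-100's 50,000 training images, getSchedule(datasetSize:) returns
+// [500, 3000, 6000, 9000, 10000]: the learning rate warms up linearly over the
+// first 500 steps, is divided by 10 after steps 3000, 6000, and 9000, and
+// getLearningRate(step:datasetSize:baseLearningRate:) returns nil once step
+// 10000 (the end of the schedule) is reached.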
+let stepsPerEpoch = cifar100TrainingSize / batchSize +let epochCount = scheduleLength / stepsPerEpoch +let resizeSize = getResolution(originalResolution: (32, 32)) +let mixupAlpha = getMixUp(datasetSize: cifar100TrainingSize) +let beta = np.random.beta(mixupAlpha, mixupAlpha) +for (epoch, batches) in dataset.training.prefix(epochCount).enumerated() { + let start = Date() + var trainStats = BigTransferTrainingStatistics(on: device) + var testStats = BigTransferTrainingStatistics(on: device) + + Context.local.learningPhase = .training + for batch in batches { + if let newLearningRate = getLearningRate(step: currStep, datasetSize: cifar100TrainingSize, baseLearningRate: 0.003) { + optimizer.learningRate = newLearningRate + currStep = currStep + 1 + } + else { + continue + } + var (eagerImages, eagerLabels) = (batch.data, batch.label) + let resized = resize(images: eagerImages, size: (resizeSize.0, resizeSize.1)) + // Future work to change these calls from Python TensorFlow to Swift for Tensorflow + // let cropped = tf.image.random_crop(resized, [batchSize, resizeSize.1, resizeSize.1, 3]) + // let flipped = tf.image.random_flip_left_right(cropped) + // var mixedUp = flipped + var newLabels = Tensor(Tensor(oneHotAtIndices: eagerLabels, depth: classCount)) + //if mixupAlpha > 0.0 { + // var npLabels = newLabels.makeNumpyArray() + // mixedUp = beta * mixedUp + (1 - beta) * tf.reverse(mixedUp, axis: [0]) + // npLabels = beta * npLabels + (1 - beta) * tf.reverse(npLabels, axis: [0]) + // newLabels = Tensor(numpy: npLabels.numpy())! + //} + // eagerImages = Tensor(numpy: mixedUp.numpy())! + // let images = Tensor(copying: eagerImages, to: device) + let images = Tensor(copying: resized, to: device) + let labels = Tensor(copying: newLabels, to: device) + + let 𝛁model = TensorFlow.gradient(at: bitModel) { bitModel -> Tensor in + let ŷ = bitModel(images) + let loss = softmaxCrossEntropy(logits: ŷ, probabilities: labels) + trainStats.update(logits: ŷ, labels: labels, loss: loss) + return loss + } + + optimizer.update(&bitModel, along: 𝛁model) + LazyTensorBarrier() + } + + Context.local.learningPhase = .inference + for batch in dataset.validation { + var (eagerImages, eagerLabels) = (batch.data, batch.label) + let resized = resize(images: eagerImages, size: (resizeSize.0, resizeSize.1)) + let newLabels = Tensor(Tensor(oneHotAtIndices: eagerLabels, depth: classCount)) + let images = Tensor(copying: resized, to: device) + let labels = Tensor(copying: newLabels, to: device) + let ŷ = bitModel(images) + let loss = softmaxCrossEntropy(logits: ŷ, probabilities: labels) + LazyTensorBarrier() + testStats.update(logits: ŷ, labels: labels, loss: loss) + } + + print( + """ + [Epoch \(epoch)] \ + Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \ + Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \ + (\(String(format: "%.1f", trainStats.accuracy))%), \ + Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \ + Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \ + (\(String(format: "%.1f", testStats.accuracy))%) \ + seconds per epoch: \(String(format: "%.1f", Date().timeIntervalSince(start))) + """) +} \ No newline at end of file diff --git a/Examples/CMakeLists.txt b/Examples/CMakeLists.txt index f64966798e9..b14fa4b2070 100644 --- a/Examples/CMakeLists.txt +++ b/Examples/CMakeLists.txt @@ -1,3 +1,4 @@ +add_subdirectory(BigTransfer-CIFAR100) add_subdirectory(Custom-CIFAR10) add_subdirectory(ResNet-CIFAR10) add_subdirectory(LeNet-MNIST) 
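The mixup step in the training loop above is still expressed as commented-out Python TensorFlow calls. As a rough sketch of what the "future work" note points at in pure Swift for TensorFlow (assuming `beta` has already been drawn from a Beta(mixupAlpha, mixupAlpha) distribution, for example via numpy as in `main.swift`, and that `labels` are already one-hot encoded), the same interpolation against the batch in reverse order might look like:

```swift
import TensorFlow

/// Mixes each sample with its counterpart in the reversed batch, interpolating
/// both images and one-hot labels by `beta` (mirrors the commented-out
/// `tf.reverse`-based mixup above).
func mixup(
  images: Tensor<Float>, labels: Tensor<Float>, beta: Float
) -> (images: Tensor<Float>, labels: Tensor<Float>) {
  let mixedImages = beta * images + (1 - beta) * images.reversed(inAxes: [0])
  let mixedLabels = beta * labels + (1 - beta) * labels.reversed(inAxes: [0])
  return (images: mixedImages, labels: mixedLabels)
}
```

This is only an illustration of the intended replacement, not code that is part of this diff.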
diff --git a/Models/ImageClassification/BigTransfer.swift b/Models/ImageClassification/BigTransfer.swift new file mode 100644 index 00000000000..16b48bc10e0 --- /dev/null +++ b/Models/ImageClassification/BigTransfer.swift @@ -0,0 +1,370 @@ +// Copyright 2020 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Original source: +// "Big Transfer (BiT): General Visual Representation Learning" +// Alexander Kolesnikov, Lucas Beyer, Xiaohua Zhai, Joan Puigcerver, Jessica Yung, Sylvain Gelly, Neil Houlsby. +// https://arxiv.org/abs/1912.11370 + +import Foundation +import TensorFlow +import PythonKit + +let subprocess = Python.import("subprocess") +let np = Python.import("numpy") + + +/// Convenient layer wrapper used to load all of the trained layers from the .npz file downloaded from the +/// BigTransfer weights repository +struct BigTransferNamedLayer { + let name: String + let layer: Tensor +} + +/// Get the necessary padding to maintain the network size specified in the BigTransfer architecture +/// +/// - Parameters: +/// - kernelSize: size n which represents the height and width of the nxn kernel +/// - Returns: the left / top padding and the right / bottom padding necessary to maintain correct output sizes +/// after convolution +func paddingFromKernelSize(kernelSize: Int) -> [(before: Int, after: Int)] { + let padTotal = kernelSize - 1 + let padBeginning = Int(padTotal / 2) + let padEnd = padTotal - padBeginning + let padding = [ + (before: 0, after: 0), + (before: padBeginning, after: padEnd), + (before: padBeginning, after: padEnd), + (before: 0, after: 0)] + return padding +} + +/// Get all of the pre-trained layers from the .npz file into a Swift array to load into the BigTransfer model +/// +/// - Parameters: +/// - modelName: model name that represents the weights to load from the BigTransfer weights repository +/// ("BiT-M-R50x1" for example) +/// - Returns: an array of layers and their associated name in the .npz file downloaded from the weights repository +func getPretrainedWeightsDict(modelName: String) -> Array { + let validTypes = ["BiT-S", "BiT-M"] + let validSizes = [(50, 1), (50, 3), (101, 1), (101, 3), (152, 4)] + let bitURL = "https://storage.googleapis.com/bit_models/" + var knownModels = [String: String]() + + for types in validTypes { + for sizes in validSizes { + let modelString = types + "-R" + String(sizes.0) + "x" + String(sizes.1) + knownModels[modelString] = bitURL + modelString + ".npz" + } + } + + if let modelPath = knownModels[modelName] { + subprocess.call("wget " + modelPath + " .", shell: true) + } + + let weights = np.load("./" + modelName + ".npz") + + var weightsArray = Array() + for param in weights { + weightsArray.append(BigTransferNamedLayer(name: String(param)!, layer: Tensor(numpy: weights[param])!)) + } + return weightsArray +} + +/// A 2D Convolution layer that standardizes the weights before the forward pass. 
This has been implemented in +/// accordance with the implementation in https://github.com/google-research/big_transfer/blob/49afe42338b62af9fbe18f0258197a33ee578a6b/bit_pytorch/models.py#L25 +public struct StandardizedConv2D: Layer { + public var conv: Conv2D + + public init( + filterShape: (Int, Int, Int, Int), + strides: (Int, Int) = (1, 1), + padding: Padding = .valid, + useBias: Bool = true + ) + { + self.conv = Conv2D( + filterShape: filterShape, + strides: strides, + padding: padding, + useBias: useBias) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + let axes: Array = [0, 1, 2] + var standardizedConv = conv + standardizedConv.filter = (standardizedConv.filter - standardizedConv.filter.mean(squeezingAxes: axes)) / sqrt((standardizedConv.filter.variance(squeezingAxes: axes) + 1e-16)) + return standardizedConv(input) + } + +} + +/// A standardized convolution and group norm layer as specified in the BigTransfer architecture +public struct ConvGNV2BiT: Layer { + public var conv: StandardizedConv2D + public var norm: GroupNorm + @noDerivative public var isSecond: Bool + + public init( + inFilters: Int, + outFilters: Int, + kernelSize: Int = 1, + stride: Int = 1, + padding: Padding = .valid, + isSecond: Bool = false + ) { + self.conv = StandardizedConv2D( + filterShape: (kernelSize, kernelSize, inFilters, outFilters), + strides: (stride, stride), + padding: padding, + useBias: false) + self.norm = GroupNorm( + offset: Tensor(zeros: [inFilters]), + scale: Tensor(zeros: [inFilters]), + groupCount: 2, + axis: -1, + epsilon: 0.001) + self.isSecond = isSecond + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + var normResult = norm(input) + if self.isSecond { + normResult = normResult.padded(forSizes: paddingFromKernelSize(kernelSize: 3)) + } + let reluResult = relu(normResult) + let convResult = conv(reluResult) + return convResult + } +} + +/// The shortcut in a residual block with standardized convolution and group normalization +public struct ShortcutBiT: Layer { + public var projection: StandardizedConv2D + public var norm: GroupNorm + @noDerivative public let needsProjection: Bool + + public init(inFilters: Int, outFilters: Int, stride: Int) { + needsProjection = (stride > 1 || inFilters != outFilters) + norm = GroupNorm( + offset: Tensor(zeros: [needsProjection ? inFilters : 1]), + scale: Tensor(zeros: [needsProjection ? inFilters : 1]), + groupCount: needsProjection ? 2 : 1, + axis: -1, + epsilon: 0.001) + + projection = StandardizedConv2D( + filterShape: (1, 1, needsProjection ? inFilters : 1, needsProjection ? 
outFilters : 1), + strides: (stride, stride), + padding: .valid, + useBias: false) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + var res = input + if needsProjection { + res = norm(res) + res = relu(res) + res = projection(res) + } + return res + } +} + +/// Residual block for BigTransfer with standardized convolution and group normalization layers +public struct ResidualBlockBiT: Layer { + public var shortcut: ShortcutBiT + public var convs: [ConvGNV2BiT] + + public init(inFilters: Int, outFilters: Int, stride: Int, expansion: Int){ + if expansion == 1 { + convs = [ + ConvGNV2BiT(inFilters: inFilters, outFilters: outFilters, kernelSize: 3, stride: stride), + ConvGNV2BiT(inFilters: outFilters, outFilters: outFilters, kernelSize: 3, isSecond: true) + ] + } else { + convs = [ + ConvGNV2BiT(inFilters: inFilters, outFilters: outFilters/4), + ConvGNV2BiT(inFilters: outFilters/4, outFilters: outFilters/4, kernelSize: 3, stride: stride, isSecond: true), + ConvGNV2BiT(inFilters: outFilters/4, outFilters: outFilters) + ] + } + shortcut = ShortcutBiT(inFilters: inFilters, outFilters: outFilters, stride: stride) + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + let convResult = convs.differentiableReduce(input) { $1($0) } + return convResult + shortcut(input) + } +} + +/// An implementation of the BigTransfer architecture with variable sizes +public struct BigTransfer: Layer { + public var inputStem: StandardizedConv2D + public var maxPool: MaxPool2D + public var residualBlocks: [ResidualBlockBiT] = [] + public var groupNorm : GroupNorm + public var flatten = Flatten() + public var classifier: Dense + public var avgPool = GlobalAvgPool2D() + @noDerivative public var finalOutFilter : Int = 0 + + /// Initialize the BigTransfer Model + /// + /// - Parameters: + /// - classCount: the number of output classes + /// - depth: the specified depht of the network based on the various ResNet architectures + /// - inputChannels: the number of input channels for the dataset + /// - stemFilters: the number of filters in the first three convolutions + public init( + classCount: Int, + depth: Depth, + inputChannels: Int = 3, + modelName: String = "BiT-M-R50x1", + loadWeights: Bool = true + ) { + + self.inputStem = StandardizedConv2D(filterShape: (7, 7, 3, 64), strides: (2, 2), padding: .valid, useBias: false) + self.maxPool = MaxPool2D(poolSize: (3, 3), strides: (2, 2), padding: .valid) + let sizes = [64 / depth.expansion, 64, 128, 256, 512] + for (iBlock, nBlocks) in depth.layerBlockSizes.enumerated() { + let (nIn, nOut) = (sizes[iBlock] * depth.expansion, sizes[iBlock+1] * depth.expansion) + for j in 0..( + offset: Tensor(zeros: [self.finalOutFilter]), + scale: Tensor(zeros: [self.finalOutFilter]), + groupCount: 2, + axis: -1, + epsilon: 0.001) + self.classifier = Dense(inputSize: 512 * depth.expansion, outputSize: classCount) + + if loadWeights { + let weightsArray = getPretrainedWeightsDict(modelName: modelName) + + // Load weights from model .npz file into the BigTransfer model + let convs = weightsArray.filter {key in return key.name.contains("/block") && key.name.contains("standardized_conv2d/kernel") && !(key.name.contains("proj"))} + + var k = 0 + for (idx, i) in self.residualBlocks.enumerated() { + for (jdx, _) in i.convs.enumerated() { + assert(self.residualBlocks[idx].convs[jdx].conv.conv.filter.shape == convs[k].layer.shape) + self.residualBlocks[idx].convs[jdx].conv.conv.filter = convs[k].layer + k = k + 1 + } + } + + let 
projectiveConvs = weightsArray.filter {key in return key.name.contains("/block") && key.name.contains("standardized_conv2d/kernel") && (key.name.contains("proj"))} + var normScale = weightsArray.filter {key in return key.name.contains("unit01/a/group_norm/gamma")} + var normOffset = weightsArray.filter {key in return key.name.contains("unit01/a/group_norm/beta")} + + k = 0 + for (idx, i) in self.residualBlocks.enumerated() { + if (i.shortcut.projection.conv.filter.shape != [1, 1, 1, 1]) + { + assert(self.residualBlocks[idx].shortcut.projection.conv.filter.shape == projectiveConvs[k].layer.shape) + self.residualBlocks[idx].shortcut.projection.conv.filter = projectiveConvs[k].layer + + assert(self.residualBlocks[idx].shortcut.norm.scale.shape == normScale[k].layer.shape) + self.residualBlocks[idx].shortcut.norm.scale = normScale[k].layer + + assert(self.residualBlocks[idx].shortcut.norm.offset.shape == normOffset[k].layer.shape) + self.residualBlocks[idx].shortcut.norm.offset = normOffset[k].layer + k = k + 1 + } + } + + normScale = weightsArray.filter {key in return key.name.contains("gamma")} + + k = 0 + for (idx, i) in self.residualBlocks.enumerated() { + for (jdx, _) in i.convs.enumerated() { + assert(normScale[k].layer.shape == self.residualBlocks[idx].convs[jdx].norm.scale.shape) + self.residualBlocks[idx].convs[jdx].norm.scale = normScale[k].layer + k = k + 1 + } + } + + normOffset = weightsArray.filter {key in return key.name.contains("beta")} + + var l = 0 + for (idx, i) in self.residualBlocks.enumerated() { + for (jdx, _) in i.convs.enumerated() { + assert(normOffset[l].layer.shape == self.residualBlocks[idx].convs[jdx].norm.offset.shape) + self.residualBlocks[idx].convs[jdx].norm.offset = normOffset[l].layer + l = l + 1 + } + } + + assert(self.groupNorm.scale.shape == normScale[k].layer.shape) + self.groupNorm.scale = normScale[k].layer + assert(self.groupNorm.offset.shape == normOffset[l].layer.shape) + self.groupNorm.offset = normOffset[l].layer + + let rootConvs = weightsArray.filter {key in return key.name.contains("root_block")} + assert(self.inputStem.conv.filter.shape == rootConvs[0].layer.shape) + self.inputStem.conv.filter = rootConvs[0].layer + } + } + + @differentiable + public func callAsFunction(_ input: Tensor) -> Tensor { + var paddedInput = input.padded(forSizes: paddingFromKernelSize(kernelSize: 7)) + paddedInput = inputStem(paddedInput).padded(forSizes: paddingFromKernelSize(kernelSize: 3)) + let inputLayer = maxPool(paddedInput) + let blocksReduced = residualBlocks.differentiableReduce(inputLayer) { $1($0) } + let normalized = relu(groupNorm(blocksReduced)) + return normalized.sequenced(through: avgPool, flatten, classifier) + } +} + +extension BigTransfer { + public enum Depth { + case resNet18 + case resNet34 + case resNet50 + case resNet101 + case resNet152 + + var expansion: Int { + switch self { + case .resNet18, .resNet34: return 1 + default: return 4 + } + } + + var layerBlockSizes: [Int] { + switch self { + case .resNet18: return [2, 2, 2, 2] + case .resNet34: return [3, 4, 6, 3] + case .resNet50: return [3, 4, 6, 3] + case .resNet101: return [3, 4, 23, 3] + case .resNet152: return [3, 8, 36, 3] + } + } + } +} diff --git a/Models/ImageClassification/CMakeLists.txt b/Models/ImageClassification/CMakeLists.txt old mode 100755 new mode 100644 index a64f58af49a..95eb61e9bea --- a/Models/ImageClassification/CMakeLists.txt +++ b/Models/ImageClassification/CMakeLists.txt @@ -1,4 +1,5 @@ add_library(ImageClassificationModels + BigTransfer.swift DenseNet121.swift 
EfficientNet.swift LeNet-5.swift diff --git a/Package.swift b/Package.swift index 11172e935ab..59ee9bcc64e 100644 --- a/Package.swift +++ b/Package.swift @@ -77,6 +77,9 @@ let package = Package( name: "ResNet-CIFAR10", dependencies: ["Datasets", "ImageClassificationModels", "TrainingLoop"], path: "Examples/ResNet-CIFAR10"), + .target(name: "BigTransfer-CIFAR100", + dependencies: ["Datasets", "ImageClassificationModels"], + path: "Examples/BigTransfer-CIFAR100"), .target( name: "Shallow-Water-PDE", dependencies: ["ArgumentParser", "Benchmark", "ModelSupport"], diff --git a/Tests/DatasetsTests/CIFAR100/CIFAR100Tests.swift b/Tests/DatasetsTests/CIFAR100/CIFAR100Tests.swift new file mode 100644 index 00000000000..fe10de77d7d --- /dev/null +++ b/Tests/DatasetsTests/CIFAR100/CIFAR100Tests.swift @@ -0,0 +1,66 @@ +import Datasets +import Foundation +import TensorFlow +import XCTest + +final class CIFAR100Tests: XCTestCase { + func testCreateCIFAR100() { + let dataset = CIFAR100( + batchSize: 1, + entropy: SystemRandomNumberGenerator(), + device: Device.default, + remoteBinaryArchiveLocation: + URL( + string: + "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz" + )!, normalizing: true + ) + verify(dataset) + } + + func verify(_ dataset: CIFAR100) { + var totalCount = 0 + for epochBatches in dataset.training.prefix(1){ + for batch in epochBatches { + XCTAssertTrue((0..<100).contains(batch.label[0].scalar!)) + XCTAssertEqual(batch.data.shape, [1, 32, 32, 3]) + totalCount += 1 + } + } + XCTAssertEqual(totalCount, 50000) + } + + func testNormalizeCIFAR100() { + let dataset = CIFAR100( + batchSize: 50000, + entropy: SystemRandomNumberGenerator(), + device: Device.default, + remoteBinaryArchiveLocation: + URL( + string: + "https://www.cs.toronto.edu/~kriz/cifar-100-binary.tar.gz" + )!, normalizing: true + ) + + let targetMean = Tensor([0, 0, 0]) + let targetStd = Tensor([1, 1, 1]) + for epochBatches in dataset.training.prefix(1){ + for batch in epochBatches { + let images = Tensor(batch.data) + let mean = images.mean(squeezingAxes: [0, 1, 2]) + let std = images.standardDeviation(squeezingAxes: [0, 1, 2]) + XCTAssertTrue(targetMean.isAlmostEqual(to: mean, + tolerance: 1e-3)) + XCTAssertTrue(targetStd.isAlmostEqual(to: std, + tolerance: 1e-3)) + } + } + } +} + +extension CIFAR100Tests { + static var allTests = [ + ("testCreateCIFAR100", testCreateCIFAR100), + ("testNormalizeCIFAR100", testNormalizeCIFAR100), + ] +} \ No newline at end of file diff --git a/Tests/DatasetsTests/CMakeLists.txt b/Tests/DatasetsTests/CMakeLists.txt index 1684b46eed8..86774b8ff70 100644 --- a/Tests/DatasetsTests/CMakeLists.txt +++ b/Tests/DatasetsTests/CMakeLists.txt @@ -1,6 +1,7 @@ add_library(DatasetsTests BostonHousing/BostonHousingTests.swift CIFAR10/CIFAR10Tests.swift + CIFAR100/CIFAR100Tests.swift COCO/COCODatasetTests.swift COCO/COCOVariantTests.swift Imagenette/ImagenetteTests.swift diff --git a/Tests/DatasetsTests/XCTestManifests.swift b/Tests/DatasetsTests/XCTestManifests.swift index 1a171f2f663..20a4a19fb5a 100644 --- a/Tests/DatasetsTests/XCTestManifests.swift +++ b/Tests/DatasetsTests/XCTestManifests.swift @@ -18,6 +18,7 @@ import XCTest public func allTests() -> [XCTestCaseEntry] { return [ testCase(CIFAR10Tests.allTests), + testCase(CIFAR100Tests.allTests), testCase(COCOVariantTests.allTests), testCase(COCODatasetTests.allTests), testCase(MNISTTests.allTests), diff --git a/Tests/ImageClassificationTests/Inference.swift b/Tests/ImageClassificationTests/Inference.swift index a4e23f534f4..bd315a9a659 100644 
--- a/Tests/ImageClassificationTests/Inference.swift
+++ b/Tests/ImageClassificationTests/Inference.swift
@@ -22,6 +22,15 @@ final class ImageClassificationInferenceTests: XCTestCase {
         Context.local.learningPhase = .inference
     }
+    func testBigTransfer() {
+        let input = Tensor<Float>(
+            randomNormal: [1, 32, 32, 3], mean: Tensor<Float>(0.5),
+            standardDeviation: Tensor<Float>(0.1), seed: (0xffeffe, 0xfffe))
+        let bigTransfer = BigTransfer(classCount: 1000, depth: .resNet50, loadWeights: false)
+        let bigTransferResult = bigTransfer(input)
+        XCTAssertEqual(bigTransferResult.shape, [1, 1000])
+    }
+
     func testDenseNet121() {
         let input = Tensor<Float>(
             randomNormal: [1, 224, 224, 3], mean: Tensor<Float>(0.5),
             standardDeviation: Tensor<Float>(0.1),
@@ -385,6 +394,7 @@ final class ImageClassificationInferenceTests: XCTestCase {
 extension ImageClassificationInferenceTests {
     static var allTests = [
+        ("testBigTransfer", testBigTransfer),
         ("testDenseNet121", testDenseNet121),
         ("testEfficientNet", testEfficientNet),
         ("testLeNet", testLeNet),