diff --git a/Examples/LeNet-MNIST/main.swift b/Examples/LeNet-MNIST/main.swift index 5e35093eff6..c944d03fd7a 100644 --- a/Examples/LeNet-MNIST/main.swift +++ b/Examples/LeNet-MNIST/main.swift @@ -31,24 +31,33 @@ let dataset = MNIST(batchSize: batchSize, on: device) // The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`. var classifier = Sequential { - Conv2D(filterShape: (5, 5, 1, 6), padding: .same, activation: relu) - AvgPool2D(poolSize: (2, 2), strides: (2, 2)) - Conv2D(filterShape: (5, 5, 6, 16), activation: relu) - AvgPool2D(poolSize: (2, 2), strides: (2, 2)) - Flatten() - Dense(inputSize: 400, outputSize: 120, activation: relu) - Dense(inputSize: 120, outputSize: 84, activation: relu) - Dense(inputSize: 84, outputSize: 10) + Conv2D(filterShape: (5, 5, 1, 6), padding: .same, activation: relu) + AvgPool2D(poolSize: (2, 2), strides: (2, 2)) + Conv2D(filterShape: (5, 5, 6, 16), activation: relu) + AvgPool2D(poolSize: (2, 2), strides: (2, 2)) + Flatten() + Dense(inputSize: 400, outputSize: 120, activation: relu) + Dense(inputSize: 120, outputSize: 84, activation: relu) + Dense(inputSize: 84, outputSize: 10) } var optimizer = SGD(for: classifier, learningRate: 0.1) -let trainingProgress = TrainingProgress() var trainingLoop = TrainingLoop( training: dataset.training, validation: dataset.validation, optimizer: optimizer, lossFunction: softmaxCrossEntropy, - callbacks: [trainingProgress.update]) + metrics: [.accuracy], + callbacks: [try! CSVLogger().log]) + +// Compute statistics only when last batch ends. +trainingLoop.statisticsRecorder!.shouldCompute = { + ( + _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int, + _ event: TrainingLoopEvent + ) -> Bool in + return event == .batchEnd && batchIndex + 1 == batchCount +} try! 
trainingLoop.fit(&classifier, epochs: epochCount, on: device) diff --git a/Examples/MobileNetV1-Imagenette/main.swift b/Examples/MobileNetV1-Imagenette/main.swift index 923d7d309dc..9a911b7e75e 100644 --- a/Examples/MobileNetV1-Imagenette/main.swift +++ b/Examples/MobileNetV1-Imagenette/main.swift @@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224, var model = MobileNetV1(classCount: 10) let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9) -let trainingProgress = TrainingProgress() var trainingLoop = TrainingLoop( training: dataset.training, validation: dataset.validation, optimizer: optimizer, lossFunction: softmaxCrossEntropy, - callbacks: [trainingProgress.update]) + metrics: [.accuracy]) try! trainingLoop.fit(&model, epochs: 10, on: device) diff --git a/Examples/MobileNetV2-Imagenette/main.swift b/Examples/MobileNetV2-Imagenette/main.swift index c2177934851..d35fa2f3198 100644 --- a/Examples/MobileNetV2-Imagenette/main.swift +++ b/Examples/MobileNetV2-Imagenette/main.swift @@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224, var model = MobileNetV2(classCount: 10) let optimizer = SGD(for: model, learningRate: 0.002, momentum: 0.9) -let trainingProgress = TrainingProgress() var trainingLoop = TrainingLoop( training: dataset.training, validation: dataset.validation, optimizer: optimizer, lossFunction: softmaxCrossEntropy, - callbacks: [trainingProgress.update]) + metrics: [.accuracy]) try! 
trainingLoop.fit(&model, epochs: 10, on: device) diff --git a/Examples/ResNet-CIFAR10/main.swift b/Examples/ResNet-CIFAR10/main.swift index 893d2cc0f5f..5253cd16a0c 100644 --- a/Examples/ResNet-CIFAR10/main.swift +++ b/Examples/ResNet-CIFAR10/main.swift @@ -29,12 +29,11 @@ let dataset = CIFAR10(batchSize: 10, on: device) var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false) var optimizer = SGD(for: model, learningRate: 0.001) -let trainingProgress = TrainingProgress() var trainingLoop = TrainingLoop( training: dataset.training, validation: dataset.validation, optimizer: optimizer, lossFunction: softmaxCrossEntropy, - callbacks: [trainingProgress.update]) + metrics: [.accuracy]) try! trainingLoop.fit(&model, epochs: 10, on: device) diff --git a/Examples/VGG-Imagewoof/main.swift b/Examples/VGG-Imagewoof/main.swift index 05327be7513..83ebc3c4761 100644 --- a/Examples/VGG-Imagewoof/main.swift +++ b/Examples/VGG-Imagewoof/main.swift @@ -39,12 +39,12 @@ public func scheduleLearningRate( } } -let trainingProgress = TrainingProgress() var trainingLoop = TrainingLoop( training: dataset.training, validation: dataset.validation, optimizer: optimizer, lossFunction: softmaxCrossEntropy, - callbacks: [trainingProgress.update, scheduleLearningRate]) + metrics: [.accuracy], + callbacks: [scheduleLearningRate]) try! 
trainingLoop.fit(&model, epochs: 90, on: device)
diff --git a/Support/FileSystem.swift b/Support/FileSystem.swift
index 3441754e737..617ba4dfdd0 100644
--- a/Support/FileSystem.swift
+++ b/Support/FileSystem.swift
@@ -39,4 +39,5 @@ public protocol File {
   func read(position: Int, count: Int) throws -> Data
   func write(_ value: Data) throws
   func write(_ value: Data, position: Int) throws
+  func append(_ value: Data) throws
 }
diff --git a/Support/FoundationFileSystem.swift b/Support/FoundationFileSystem.swift
index a3db73ae3c7..2b90fc61134 100644
--- a/Support/FoundationFileSystem.swift
+++ b/Support/FoundationFileSystem.swift
@@ -58,4 +58,14 @@ public struct FoundationFile: File {
     // TODO: Incorporate file offset.
     try value.write(to: location)
   }
+
+  /// Appends `value` to the end of the file; the handle is released on every exit path.
+  /// - Parameter value: The data to write after the current end of the file.
+  /// - Throws: Any error from opening the file for updating, seeking, or writing.
+  public func append(_ value: Data) throws {
+    let fileHandler = try FileHandle(forUpdating: location)
+    defer { try? fileHandler.close() }
+    try fileHandler.seekToEnd()
+    try fileHandler.write(contentsOf: value)
+  }
 }
diff --git a/TrainingLoop/CMakeLists.txt b/TrainingLoop/CMakeLists.txt
index 047d96056af..f533f37b463 100644
--- a/TrainingLoop/CMakeLists.txt
+++ b/TrainingLoop/CMakeLists.txt
@@ -1,8 +1,10 @@
 add_library(TrainingLoop
   LossFunctions.swift
+  Metrics.swift
   TrainingLoop.swift
-  TrainingProgress.swift
-  TrainingStatistics.swift)
+  Callbacks/StatisticsRecorder.swift
+  Callbacks/ProgressPrinter.swift
+  Callbacks/CSVLogger.swift)
 target_link_libraries(TrainingLoop PUBLIC
   ModelSupport)
 set_target_properties(TrainingLoop PROPERTIES
diff --git a/TrainingLoop/Callbacks/CSVLogger.swift b/TrainingLoop/Callbacks/CSVLogger.swift
new file mode 100644
index 00000000000..efdb9a28da5
--- /dev/null
+++ b/TrainingLoop/Callbacks/CSVLogger.swift
@@ -0,0 +1,74 @@
+import Foundation
+import ModelSupport
+
+public enum CSVLoggerError: Error {
+  case InvalidPath
+}
+
+/// A handler for logging training and validation statistics to a CSV file.
+public class CSVLogger {
+  /// The path of the file that statistics are logged to.
+  public var path: String
+
+  // True iff the header row of the CSV file has already been written.
+  fileprivate var headerWritten: Bool
+
+  /// Creates an instance that logs to the file at `path`, which starts out empty.
+  ///
+  /// - Throws: `CSVLoggerError.InvalidPath` if `path` has no `csv` extension; file system errors.
+  public init(path: String = "run/log.csv") throws {
+    self.path = path
+
+    // Validate the path.
+    let url = URL(fileURLWithPath: path)
+    if url.pathExtension != "csv" {
+      throw CSVLoggerError.InvalidPath
+    }
+    // Create the containing directory if it is missing.
+    try FoundationFileSystem().createDirectoryIfMissing(at: url.deletingLastPathComponent().path)
+    // Initialize the file with empty content so later rows are appended to a fresh log.
+    try FoundationFile(path: path).write(Data())
+
+    self.headerWritten = false
+  }
+
+  /// Appends one row with the loop's latest statistics when a `batchEnd` event happens,
+  /// writing the header row before the first data row; all other events are ignored.
+  ///
+  /// - Throws: File system errors raised while appending to the log file.
+  public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
+    switch event {
+    case .batchEnd:
+      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
+        let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
+        let stats = loop.lastStatsLog
+      else {
+        // No-Op if trainingLoop doesn't set the required values for stats logging.
+        return
+      }
+
+      if !headerWritten {
+        try writeHeader(stats: stats)
+        headerWritten = true
+      }
+
+      try writeDataRow(
+        epoch: "\(epochIndex + 1)/\(epochCount)",
+        batch: "\(batchIndex + 1)/\(batchCount)",
+        stats: stats)
+    default:
+      return
+    }
+  }
+
+  func writeHeader(stats: [(name: String, value: Float)]) throws {
+    let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n"
+    try FoundationFile(path: path).append(Data(header.utf8))
+  }
+
+  func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
+    let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ")
+      + "\n"
+    try FoundationFile(path: path).append(Data(dataRow.utf8))
+  }
+}
diff --git a/TrainingLoop/Callbacks/ProgressPrinter.swift b/TrainingLoop/Callbacks/ProgressPrinter.swift
new file mode 100644
index 00000000000..f2ff48eb9f0
--- /dev/null
+++ b/TrainingLoop/Callbacks/ProgressPrinter.swift
@@ -0,0 +1,85 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+import Foundation
+
+let progressBarLength = 30
+
+/// A handler for printing the training and validation progress.
+public class ProgressPrinter {
+  /// Prints training or validation progress in response to `event`.
+  ///
+  /// An example of the progress would be:
+  /// Epoch 1/12
+  /// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
+  /// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
+  public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
+    switch event {
+    case .epochStart:
+      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
+        // No-Op if trainingLoop doesn't set the required values for progress printing.
+        return
+      }
+
+      Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
+    case .batchEnd:
+      guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
+        // No-Op if trainingLoop doesn't set the required values for progress printing.
+        return
+      }
+
+      let progressBar = formatProgressBar(
+        progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength)
+      var stats: String = ""
+      if let lastStatsLog = loop.lastStatsLog {
+        stats = formatStats(lastStatsLog)
+      }
+
+      Swift.print(
+        "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
+        terminator: ""
+      )
+      fflush(stdout)
+    case .epochEnd:
+      Swift.print("")
+    case .validationStart:
+      Swift.print("")
+    default:
+      return
+    }
+  }
+
+  func formatProgressBar(progress: Float, length: Int) -> String {
+    let progressSteps = Int(round(Float(length) * progress))
+    let leading = String(repeating: "=", count: progressSteps)
+    let separator: String
+    let trailing: String
+    if progressSteps < length {  // Size the bar by `length`, not the global default.
+      separator = ">"
+      trailing = String(repeating: ".", count: length - progressSteps - 1)
+    } else {
+      separator = ""
+      trailing = ""
+    }
+    return "[\(leading)\(separator)\(trailing)]"
+  }
+
+  func formatStats(_ stats: [(String, Float)]) -> String {
+    var result = ""
+    for stat in stats {
+      result += " - \(stat.0): \(String(format: "%.4f", stat.1))"
+    }
+    return result
+  }
+}
diff --git a/TrainingLoop/Callbacks/StatisticsRecorder.swift b/TrainingLoop/Callbacks/StatisticsRecorder.swift
new file mode 100644
index 00000000000..6c98c6887c0
--- /dev/null
+++ b/TrainingLoop/Callbacks/StatisticsRecorder.swift
@@ -0,0 +1,123 @@
+// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+import TensorFlow
+
+/// A handler for recording training and validation statistics.
+///
+/// Data produced by this handler can be used by ProgressPrinter, CSVLogger, etc.
+public class StatisticsRecorder {
+  /// A closure that returns whether `reset` should be called on the metric measurers.
+  public var shouldReset:
+    (
+      _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+      _ event: TrainingLoopEvent
+    ) -> Bool
+
+  /// A closure that returns whether `accumulate` should be called on the metric measurers.
+  public var shouldAccumulate:
+    (
+      _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+      _ event: TrainingLoopEvent
+    ) -> Bool
+
+  /// A closure that returns whether the metric measurers should compute fresh statistics.
+  public var shouldCompute:
+    (
+      _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+      _ event: TrainingLoopEvent
+    ) -> Bool
+
+  /// The measurers backing the recorded metrics, in the order the metrics were given.
+  fileprivate var metricMeasurers: [MetricsMeasurer]
+
+  /// Creates an instance that records `metrics` during the training loop.
+  public init(metrics: [TrainingMetrics]) {
+    metricMeasurers = metrics.map { $0.measurer }
+
+    shouldReset = {
+      (
+        _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+        _ event: TrainingLoopEvent
+      ) -> Bool in
+      return event == .trainingStart || event == .validationStart
+    }
+
+    shouldAccumulate = {
+      (
+        _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+        _ event: TrainingLoopEvent
+      ) -> Bool in
+      return event == .batchEnd
+    }
+
+    shouldCompute = {
+      (
+        _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
+        _ event: TrainingLoopEvent
+      ) -> Bool in
+      return event == .batchEnd
+    }
+  }
+
+  /// Records statistics in response to `event`.
+  ///
+  /// The computed statistics are stored in the loop's `lastStatsLog`, where other
+  /// callbacks can consume them.
+  public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
+    guard let batchIndex = loop.batchIndex,
+      let batchCount = loop.batchCount,
+      let epochIndex = loop.epochIndex,
+      let epochCount = loop.epochCount,
+      let loss = loop.lastStepLoss,
+      let output = loop.lastStepOutput,
+      let target = loop.lastStepTarget
+    else {
+      // No-Op if trainingLoop doesn't set the required values for stats recording.
+      return
+    }
+
+    if shouldReset(batchIndex, batchCount, epochIndex, epochCount, event) {
+      resetMetricMeasurers()
+      loop.lastStatsLog = nil
+    }
+
+    if shouldAccumulate(batchIndex, batchCount, epochIndex, epochCount, event) {
+      accumulateMetrics(loss: loss, predictions: output, labels: target)
+    }
+
+    if shouldCompute(batchIndex, batchCount, epochIndex, epochCount, event) {
+      loop.lastStatsLog = computeMetrics()
+    }
+  }
+
+  func resetMetricMeasurers() {
+    for index in metricMeasurers.indices {
+      metricMeasurers[index].reset()
+    }
+  }
+
+  func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
+    for index in metricMeasurers.indices {
+      metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
+    }
+  }
+
+  func computeMetrics() -> [(name: String, value: Float)] {
+    var result: [(name: String, value: Float)] = []
+    for measurer in metricMeasurers {
+      result.append((name: measurer.name, value: measurer.measure()))
+    }
+    return result
+  }
+}
diff --git a/TrainingLoop/Metrics.swift b/TrainingLoop/Metrics.swift
new file mode 100644
index 00000000000..d7f35c6fc4a
--- /dev/null
+++ b/TrainingLoop/Metrics.swift
@@ -0,0 +1,100 @@
+import TensorFlow
+
+/// Metrics that can be registered into TrainingLoop.
+public enum TrainingMetrics {
+  case loss
+  case accuracy
+
+  public var name: String {
+    switch self {
+    case .loss:
+      return "loss"
+    case .accuracy:
+      return "accuracy"
+    }
+  }
+
+  public var measurer: MetricsMeasurer {
+    switch self {
+    case .loss:
+      return LossMeasurer(self.name)
+    case .accuracy:
+      return AccuracyMeasurer(self.name)
+    }
+  }
+}
+
+/// A protocol defining the functionality of a metrics measurer.
+public protocol MetricsMeasurer {
+  var name: String { get set }
+  mutating func reset()
+  mutating func accumulate<Output, Target>(
+    loss: Tensor<Float>?, predictions: Output?, labels: Target?
+  )
+  func measure() -> Float
+}
+
+/// A measurer for measuring loss.
+public struct LossMeasurer: MetricsMeasurer { + public var name: String + + private var totalBatchLoss: Float = 0 + private var batchCount: Int32 = 0 + + public init(_ name: String = "loss") { + self.name = name + } + + public mutating func reset() { + totalBatchLoss = 0 + batchCount = 0 + } + + public mutating func accumulate( + loss: Tensor?, predictions: Output?, labels: Target? + ) { + if let newBatchLoss = loss { + totalBatchLoss += newBatchLoss.scalarized() + batchCount += 1 + } + } + + public func measure() -> Float { + return totalBatchLoss / Float(batchCount) + } +} + +/// A measurer for measuring accuracy +public struct AccuracyMeasurer: MetricsMeasurer { + public var name: String + + private var correctGuessCount: Int32 = 0 + private var totalGuessCount: Int32 = 0 + + public init(_ name: String = "accuracy") { + self.name = name + } + + public mutating func reset() { + correctGuessCount = 0 + totalGuessCount = 0 + } + + public mutating func accumulate( + loss: Tensor?, predictions: Output?, labels: Target? + ) { + guard let predictions = predictions as? Tensor, let labels = labels as? Tensor + else { + fatalError( + "For accuracy measurements, the model output must be Tensor, and the labels must be Tensor." 
+ ) + } + correctGuessCount += Tensor(predictions.argmax(squeezingAxis: 1) .== labels).sum() + .scalarized() + totalGuessCount += Int32(labels.shape[0]) + } + + public func measure() -> Float { + return Float(correctGuessCount) / Float(totalGuessCount) + } +} diff --git a/TrainingLoop/TrainingLoop.swift b/TrainingLoop/TrainingLoop.swift index fa1953ad2f1..5585b2d6014 100644 --- a/TrainingLoop/TrainingLoop.swift +++ b/TrainingLoop/TrainingLoop.swift @@ -18,7 +18,7 @@ import TensorFlow // Workaround https://bugs.swift.org/browse/TF-1122 that prevents us from registering a // loss function inside our TrainingLoop struct public final class LossFunctionWrapper { - public typealias F = @differentiable (Output, @noDerivative Target) -> Tensor + public typealias F = @differentiable(Output, @noDerivative Target) -> Tensor public var f: F init(_ f: @escaping F) { self.f = f } } @@ -34,25 +34,32 @@ public protocol TrainingLoopProtocol { where Training: Sequence, Training.Element: Collection, Training.Element.Element == LabeledData + /// The type of the collection of batches for the validation data. associatedtype Validation where Validation: Collection, Validation.Element == LabeledData + /// The type of the target of our model. associatedtype Target + /// The type of the optimizer used. associatedtype Opt: Optimizer where Opt.Model: Module // Typealiases /// The type of the model. typealias Model = Opt.Model + /// The type of the input of the model. typealias Input = Opt.Model.Input + /// The type of the output of the model. typealias Output = Opt.Model.Output + /// The type of a batch. typealias Batch = LabeledData + // In a wrapper for now because of TF-1122. /// The type of the loss function. typealias LossFunction = LossFunctionWrapper @@ -60,38 +67,61 @@ public protocol TrainingLoopProtocol { // Data /// The training epochs. var training: Training { get } + /// The validation batches. var validation: Validation { get } // Optimizer and loss function /// The optimizer. 
var optimizer: Opt { get set } + /// The loss function. var lossFunction: LossFunction { get set } + /// The metrics + var metrics: [TrainingMetrics] { get set } + // Callbacks /// The callbacks used to customize the training loop. var callbacks: [TrainingLoopCallback] { get set } // Temporary data + + // MARK: - Step-level data + /// The last input fed to the model. - var lastInput: Input? { get set } + var lastStepInput: Input? { get set } + /// The last target. - var lastTarget: Target? { get set } + var lastStepTarget: Target? { get set } + /// The last predictions of the model. - var lastOutput: Output? { get set } + var lastStepOutput: Output? { get set } + /// The last gradients computed. - var lastGradient: Model.TangentVector? { get set } + var lastStepGradient: Model.TangentVector? { get set } + /// The last loss. - var lastLoss: Tensor? { get set } - /// The number of epochs we are currently fitting for. - var epochCount: Int? { get set } - /// The index of the current epoch. - var epochIndex: Int? { get set } + var lastStepLoss: Tensor? { get set } + /// The number of batches in the current collection of batches. var batchCount: Int? { get set } + /// The index of the current batch. var batchIndex: Int? { get set } + + // MARK: - Epoch-level data + + /// The number of epochs we are currently fitting for. + var epochCount: Int? { get set } + + /// The index of the current epoch. + var epochIndex: Int? { get set } + + // MARK: - Others + + /// The log for last statistics + var lastStatsLog: [(name: String, value: Float)]? { get set } } /// The events that occur during a call to `fit` in the `TrainingLoop` @@ -101,26 +131,37 @@ public protocol TrainingLoopProtocol { public enum TrainingLoopEvent { /// The start of a fit. case fitStart + /// The end of a fit. case fitEnd + /// The start of one epoch (training + validation). case epochStart + /// The start of one epoch (training + validation). case epochEnd + /// The start of a training phase. 
case trainingStart + /// The end of a training phase. case trainingEnd + /// The start of a validation phase. case validationStart + /// The end of a validation phase. case validationEnd + /// The start of a training or inference step on a batch. case batchStart + /// The end of a training or inference step on a batch. case batchEnd + /// At the start of the optimizer update, just after the differentiable step. case updateStart + /// Just after the model prediction at inference, before computing the loss. case inferencePredictionEnd } @@ -146,12 +187,16 @@ where // Typealiases /// The type of the model. public typealias Model = Opt.Model + /// The type of the input of the model. public typealias Input = Opt.Model.Input + /// The type of the output of the model. public typealias Output = Opt.Model.Output + /// The type of a batch. public typealias Batch = LabeledData + // In a wrapper for now because of TF-1122. /// The type of the loss function. public typealias LossFunction = LossFunctionWrapper @@ -159,74 +204,122 @@ where // Data /// The training epochs. public let training: Training + /// The validation batches. public let validation: Validation // Optimizer and loss function /// The optimizer. public var optimizer: Opt + /// The loss function public var lossFunction: LossFunction - // Callbacks - /// The callbacks used to customize the training loop. - public var callbacks: [TrainingLoopCallback] = [] + /// The metrics + public var metrics: [TrainingMetrics] + + /// Callbacks + + // MARK: - The callbacks used to customize the training loop. + + public var callbacks: [TrainingLoopCallback] + + // MARK: - Default callback objects + + public var statisticsRecorder: StatisticsRecorder? = nil + + public var progressPrinter: ProgressPrinter? = nil + + /// Temporary data + + // MARK: - Step-level data - // Temporary data /// The last input fed to the model. - public var lastInput: Input? = nil + public var lastStepInput: Input? = nil + /// The last target. 
- public var lastTarget: Target? = nil + public var lastStepTarget: Target? = nil + /// The last predictions of the model. - public var lastOutput: Output? = nil + public var lastStepOutput: Output? = nil + /// The last gradients computed. - public var lastGradient: Model.TangentVector? = nil + public var lastStepGradient: Model.TangentVector? = nil + /// The last loss. - public var lastLoss: Tensor? = nil - /// The number of epochs we are currently fitting for. - public var epochCount: Int? = nil - /// The index of the current epoch. - public var epochIndex: Int? = nil + public var lastStepLoss: Tensor? = nil + /// The number of batches in the current collection of batches. public var batchCount: Int? = nil + /// The index of the current batch. public var batchIndex: Int? = nil + // MARK: - Epoch-level data + + /// The number of epochs we are currently fitting for. + public var epochCount: Int? = nil + + /// The index of the current epoch. + public var epochIndex: Int? = nil + + // MARK: - Others + + /// The log for last statistics + public var lastStatsLog: [(name: String, value: Float)]? = nil + /// Creates an instance from `training` and `validation` data, a `model`, an `optimizer` and a /// `lossFunction`. /// /// Parameter callbacks: Callbacks that the `TrainingLoop` will use in every call to fit. 
public init( training: Training, validation: Validation, optimizer: Opt, - lossFunction: @escaping LossFunction.F, callbacks: [TrainingLoopCallback] = [] + lossFunction: @escaping LossFunction.F, + metrics: [TrainingMetrics] = [], + callbacks: [TrainingLoopCallback] = [], + includeDefaultCallbacks: Bool = true ) { self.training = training self.validation = validation self.optimizer = optimizer self.lossFunction = LossFunction(lossFunction) - self.callbacks = callbacks + self.metrics = metrics + + if includeDefaultCallbacks { + let statisticsRecorder = StatisticsRecorder(metrics: [.loss] + metrics) + let progressPrinter = ProgressPrinter() + self.statisticsRecorder = statisticsRecorder + self.progressPrinter = progressPrinter + self.callbacks = [ + statisticsRecorder.record, + progressPrinter.print, + ] + callbacks + } else { + self.callbacks = callbacks + } } } extension TrainingLoop { /// The default differentiable step. public mutating func differentiableStep(model: Model) throws { - guard let data = lastInput else { return } - guard let target = lastTarget else { return } - (lastLoss, lastGradient) = valueWithGradient(at: model) { (model: Model) -> Tensor in + guard let data = lastStepInput else { return } + guard let target = lastStepTarget else { return } + (lastStepLoss, lastStepGradient) = valueWithGradient(at: model) { + (model: Model) -> Tensor in let predictions = model(data) - lastOutput = predictions + lastStepOutput = predictions return lossFunction.f(predictions, target) } } /// The step used for inference. 
public mutating func inferenceStep(model: Model) throws { - guard let data = lastInput else { return } - lastOutput = model(data) - guard let target = lastTarget else { return } + guard let data = lastStepInput else { return } + lastStepOutput = model(data) + guard let target = lastStepTarget else { return } try handleEvent(.inferencePredictionEnd) - lastLoss = lossFunction.f(lastOutput!, target) + lastStepLoss = lossFunction.f(lastStepOutput!, target) } /// The step used for training. @@ -235,7 +328,7 @@ extension TrainingLoop { ) throws { try differentiableStep(model, &self) try handleEvent(.updateStart) - optimizer.update(&model, along: lastGradient!) + optimizer.update(&model, along: lastStepGradient!) } } @@ -245,12 +338,16 @@ extension TrainingLoop { public enum TrainingLoopAction: Error { /// Abort actions in the current training/inference step and goes to the next batch. case cancelBatch + /// Abort actions in the current training phase and goes to the validation phase. case cancelTraining + /// Abort actions in the current validation phase and goes to the next epoch. case cancelValidation + /// Abort actions in the current epoch and goes to the next epoch. case cancelEpoch + /// Abort actions in the current fit and ends fitting. 
case cancelFit } @@ -272,7 +369,7 @@ extension TrainingLoop { batchCount = batches.count for (i, batch) in batches.enumerated() { batchIndex = i - (lastInput, lastTarget) = (batch.data, batch.label) + (lastStepInput, lastStepTarget) = (batch.data, batch.label) do { try handleEvent(.batchStart) try step(&self) @@ -294,7 +391,9 @@ extension TrainingLoop { public mutating func fit( _ model: inout Model, epochs: Int, callbacks: [TrainingLoopCallback] = [], on device: Device = Device.default, - differentiableStep: (Model, inout Self) throws -> Void = { try $1.differentiableStep(model: $0) } + differentiableStep: (Model, inout Self) throws -> Void = { + try $1.differentiableStep(model: $0) + } ) throws { let callbacksCount = self.callbacks.count self.callbacks += callbacks diff --git a/TrainingLoop/TrainingProgress.swift b/TrainingLoop/TrainingProgress.swift deleted file mode 100644 index 208760a5ec5..00000000000 --- a/TrainingLoop/TrainingProgress.swift +++ /dev/null @@ -1,110 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -import Foundation - -let progressBarLength = 30 - -/// A progress bar that displays to the console as a model trains, and as validation is performed. -/// It hooks into a TrainingLoop via a callback method. -public class TrainingProgress { - var statistics: TrainingStatistics? 
- let metrics: Set - let liveStatistics: Bool - - /// Initializes the progress bar with the metrics to be displayed (if any), and whether to - /// provide a live update of training and validation metrics as they are calculated. - /// - /// - Parameters: - /// - metrics: A set of TrainingMetrics that specify which metrics to monitor and display - /// during training and validation. By default, all available metrics are selected. - /// - liveStatistics: Whether or not to update the metrics at the command line on every batch - /// as it is processed, or if these values should just be provided at the end of an epoch. - /// This has an impact on performance, due to materialization of tensors, and updating values - /// on every batch can reduce training speed by up to 30%. - public init(metrics: Set = [.accuracy, .loss], liveStatistics: Bool = true) { - self.metrics = metrics - self.liveStatistics = liveStatistics - if !metrics.isEmpty { - statistics = TrainingStatistics(metrics: metrics) - } - } - - func progressBar(progress: Float, length: Int) -> String { - let progressSteps = Int(round(Float(length) * progress)) - let leading = String(repeating: "=", count: progressSteps) - let separator: String - let trailing: String - if progressSteps < progressBarLength { - separator = ">" - trailing = String(repeating: ".", count: progressBarLength - progressSteps - 1) - } else { - separator = "" - trailing = "" - } - return "[\(leading)\(separator)\(trailing)]" - } - - func metricDescription() -> String { - var result: String = "" - if metrics.contains(.loss) { - result += " - loss: \(String(format: "%.4f", statistics!.averageLoss()))" - } - if metrics.contains(.accuracy) { - result += " - accuracy: \(String(format: "%.4f", statistics!.accuracy()))" - } - - return result - } - - /// The callback used to hook into the TrainingLoop. This is updated once per event. - /// - /// - Parameters: - /// - loop: The TrainingLoop where an event has occurred. 
This can be accessed to obtain - /// the last measure loss and other values. - /// - event: The training or validation event that this callback is responding to. - public func update(_ loop: inout L, event: TrainingLoopEvent) throws { - try statistics?.record(&loop, event: event) - - switch event { - case .epochStart: - guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else { - return - } - print("Epoch \(epochIndex + 1)/\(epochCount)") - case .batchEnd: - guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else { - return - } - let epochProgress = Float(batchIndex + 1) / Float(batchCount) - let progressBarComponent = progressBar(progress: epochProgress, length: progressBarLength) - let metricDescriptionComponent: String - if liveStatistics || (batchCount == (batchIndex + 1)) { - metricDescriptionComponent = metricDescription() - } else { - metricDescriptionComponent = "" - } - print( - "\r\(batchIndex + 1)/\(batchCount) \(progressBarComponent)\(metricDescriptionComponent)", - terminator: "" - ) - fflush(stdout) - case .epochEnd: - print("") - case .validationStart: - print("") - default: break - } - } -} diff --git a/TrainingLoop/TrainingStatistics.swift b/TrainingLoop/TrainingStatistics.swift deleted file mode 100644 index c3380a2088d..00000000000 --- a/TrainingLoop/TrainingStatistics.swift +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2020 The TensorFlow Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-// See the License for the specific language governing permissions and -// limitations under the License. - -import TensorFlow - -/// Metrics that can be tracked or displayed during training or validation. -public enum TrainingMetrics { - case accuracy - case loss -} - -/// A callback-based handler of statistics obtained during a training loop. This can be employed -/// by progress bars, recorders, or logging functionality. -public class TrainingStatistics { - let metrics: Set - var totalBatchLoss: Tensor? - var totalBatches: Tensor? - var totalCorrect: Tensor? - var totalExamples: Int32? - - /// Initializes the statistics tracker with - /// - /// - Parameters: - /// - metrics: A set of TrainingMetrics to capture during the training loop. - public init(metrics: Set) { - self.metrics = metrics - } - - /// The current average loss, calculated from the batches seen since the previous start of - /// training or validation. - public func averageLoss() -> Float { - guard let totalBatches = totalBatches, let totalBatchLoss = totalBatchLoss else { - return Float.nan - } - return (totalBatchLoss / totalBatches).scalarized() - } - - /// The current accuracy, calculated from the batches seen since the previous start of - /// training or validation. Not all models support class-based accuracy as a metric. - public func accuracy() -> Float { - guard let totalCorrect = totalCorrect, let totalExamples = totalExamples else { - return Float.nan - } - return Float(totalCorrect.scalarized()) / Float(totalExamples) - } - - /// The callback used to hook into the TrainingLoop. This is updated once per event. - /// - /// - Parameters: - /// - loop: The TrainingLoop where an event has occurred. This can be accessed to obtain - /// the last measure loss and other values. - /// - event: The training or validation event that this callback is responding to. 
- public func record(_ loop: inout L, event: TrainingLoopEvent) throws { - switch event { - case .trainingStart, .validationStart: - totalBatchLoss = nil - totalBatches = nil - totalCorrect = nil - totalExamples = nil - case .batchEnd: - if metrics.contains(.accuracy) { - measureAccuracy(loop) - } - - if let loss = loop.lastLoss, metrics.contains(.loss) { - if let currentTotalBatchLoss = totalBatchLoss { - totalBatchLoss = currentTotalBatchLoss + loss - totalBatches = totalBatches! + 1.0 - } else { - totalBatchLoss = loss - totalBatches = Tensor(1.0, on: loss.device) - } - } - default: - return - } - } - - func measureAccuracy(_ loop: L) { - guard let possibleOutput = loop.lastOutput, let possibleTarget = loop.lastTarget else { return } - guard let output = possibleOutput as? Tensor, - let target = possibleTarget as? Tensor else { - fatalError( - "For accuracy measurements, the model output must be Tensor, and the labels must be Tensor.") - } - - let correct = output.argmax(squeezingAxis: 1) .== target - let correctGuessCount = Tensor(correct).sum() - if let currentTotalCorrect = totalCorrect { - totalCorrect = currentTotalCorrect + correctGuessCount - totalExamples = totalExamples! + Int32(output.shape[0]) - } else { - totalCorrect = correctGuessCount - totalExamples = Int32(output.shape[0]) - } - } -}