diff --git a/Support/FoundationFileSystem.swift b/Support/FoundationFileSystem.swift index 57ec0f96ddd..4cf4a2401d7 100644 --- a/Support/FoundationFileSystem.swift +++ b/Support/FoundationFileSystem.swift @@ -59,10 +59,8 @@ public struct FoundationFile: File { try value.write(to: location) } - /// Append data to the file. - /// - /// Parameter value: data to be appended at the end. - public func append(_ value: Data) throws { + /// Appends the bytes in `suffix` to the file. + public func append(_ suffix: Data) throws { let fileHandler = try FileHandle(forUpdating: location) #if os(macOS) // The following are needed in order to build on macOS 10.15 (Catalina). They can be removed @@ -72,7 +70,7 @@ public struct FoundationFile: File { fileHandler.closeFile() #else try fileHandler.seekToEnd() - try fileHandler.write(contentsOf: value) + try fileHandler.write(contentsOf: suffix) try fileHandler.close() #endif } diff --git a/TrainingLoop/Callbacks/CSVLogger.swift b/TrainingLoop/Callbacks/CSVLogger.swift index efdb9a28da5..50966ac4914 100644 --- a/TrainingLoop/Callbacks/CSVLogger.swift +++ b/TrainingLoop/Callbacks/CSVLogger.swift @@ -7,13 +7,13 @@ public enum CSVLoggerError: Error { /// A handler for logging training and validation statistics to a CSV file. public class CSVLogger { - /// The path of the file that statistics are logged to. + /// The path of the file to which statistics are logged. public var path: String // True iff the header of the CSV file has been written. fileprivate var headerWritten: Bool - /// Creates an instance that logs to a file with the given path. + /// Creates an instance that logs to a file with the given `path`. /// /// Throws: File system errors. public init(path: String = "run/log.csv") throws { @@ -32,7 +32,7 @@ public class CSVLogger { self.headerWritten = false } - /// Logs the statistics for the 'loop' when 'batchEnd' event happens; + /// Logs the statistics for `loop` when a `batchEnd` event happens; /// ignoring other events. 
/// /// Throws: File system errors. @@ -43,7 +43,6 @@ public class CSVLogger { let batchIndex = loop.batchIndex, let batchCount = loop.batchCount, let stats = loop.lastStatsLog else { - // No-Op if trainingLoop doesn't set the required values for stats logging. return } @@ -61,11 +60,18 @@ public class CSVLogger { } } + /// Writes a row of column names to the file. + /// + /// Column names are "epoch", "batch" and the `name` of each element of `stats`, + /// in that order. func writeHeader(stats: [(name: String, value: Float)]) throws { let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n" try FoundationFile(path: path).append(header.data(using: .utf8)!) } + /// Appends a row of statistics to the file, writing `epoch` in the + /// "epoch" column, `batch` in the "batch" column, and each `value` of `stats` in the + /// column identified by the corresponding `name`. func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws { let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ") + "\n" diff --git a/TrainingLoop/Callbacks/ProgressPrinter.swift b/TrainingLoop/Callbacks/ProgressPrinter.swift index f2ff48eb9f0..0867086e251 100644 --- a/TrainingLoop/Callbacks/ProgressPrinter.swift +++ b/TrainingLoop/Callbacks/ProgressPrinter.swift @@ -14,28 +14,37 @@ import Foundation -let progressBarLength = 30 - /// A handler for printing the training and validation progress. +/// +/// The progress includes the epoch and batch index the training is currently +/// in, how much of the full training/validation set has been processed, +/// and metric statistics. public class ProgressPrinter { - /// Print training or validation progress in response of the 'event'. + /// Length of the complete progress bar measured in count of `=` signs.
+ public var progressBarLength: Int + + /// Creates an instance that prints training progress with the complete + /// progress bar to be `progressBarLength` characters long. + public init(progressBarLength: Int = 30) { + self.progressBarLength = progressBarLength + } + + /// Prints training or validation progress in response to `event`. /// /// An example of the progress would be: /// Epoch 1/12 /// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513 - /// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521 - public func print(_ loop: inout L, event: TrainingLoopEvent) throws { + /// 58/79 [======================>.......] - loss: 0.1520 - accuracy: 0.9521 + public func printProgress(_ loop: inout L, event: TrainingLoopEvent) throws { switch event { case .epochStart: guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else { - // No-Op if trainingLoop doesn't set the required values for progress printing. return } - Swift.print("Epoch \(epochIndex + 1)/\(epochCount)") + print("Epoch \(epochIndex + 1)/\(epochCount)") case .batchEnd: guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else { - // No-Op if trainingLoop doesn't set the required values for progress printing. return } @@ -46,15 +55,15 @@ public class ProgressPrinter { stats = formatStats(lastStatsLog) } - Swift.print( + print( "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)", terminator: "" ) fflush(stdout) case .epochEnd: - Swift.print("") + print("") case .validationStart: - Swift.print("") + print("") default: return } diff --git a/TrainingLoop/Callbacks/StatisticsRecorder.swift b/TrainingLoop/Callbacks/StatisticsRecorder.swift index 59628172577..7f33b494b90 100644 --- a/TrainingLoop/Callbacks/StatisticsRecorder.swift +++ b/TrainingLoop/Callbacks/StatisticsRecorder.swift @@ -17,31 +17,35 @@ import TensorFlow /// /// Data produced by this handler can be used by ProgressPrinter, CVSLogger, etc.
public class StatisticsRecorder { - /// A Closure that returns if should call 'reset' on metricMeasurers. + /// A function that returns `true` iff the recorder should call `reset` + /// on `metricMeasurers`. public var shouldReset: ( _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int, _ event: TrainingLoopEvent ) -> Bool - /// A Closure that returns if should call 'accumulate' on metricMeasurers. + /// A function that returns `true` iff the recorder should call `accumulate` + /// on `metricMeasurers`. public var shouldAccumulate: ( _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int, _ event: TrainingLoopEvent ) -> Bool - /// A Closure that returns if should call 'compute' on metricMeasurers. + /// A function that returns `true` iff the recorder should call `measure` + /// on `metricMeasurers`. public var shouldCompute: ( _ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int, _ event: TrainingLoopEvent ) -> Bool - /// Instances of MetricsMeasurers. + /// Instances of `MetricsMeasurer` that are reset, accumulate data, and compute + /// statistics periodically. fileprivate var metricMeasurers: [MetricsMeasurer] - /// Create an instance that records 'metrics' during the training loop. + /// Creates an instance that records `metrics` during the training loop. public init(metrics: [TrainingMetrics]) { metricMeasurers = metrics.map { $0.measurer } @@ -70,9 +74,9 @@ public class StatisticsRecorder { } } - /// Recording statistics in response of the 'event'. + /// Records statistics in response to `event`. /// - /// It will record the statistics into 'lastStatsLog' in the loop where other + /// It will record the statistics into the `lastStatsLog` property of `loop`, where other /// callbacks can consume from.
public func record(_ loop: inout L, event: TrainingLoopEvent) throws { guard let batchIndex = loop.batchIndex, @@ -83,7 +87,6 @@ public class StatisticsRecorder { let output = loop.lastStepOutput, let target = loop.lastStepTarget else { - // No-Op if trainingLoop doesn't set the required values for stats recording. return } @@ -101,18 +104,22 @@ public class StatisticsRecorder { } } + /// Resets each of the metricMeasurers. func resetMetricMeasurers() { for index in metricMeasurers.indices { metricMeasurers[index].reset() } } + /// Lets each of the metricMeasurers accumulate data from + /// `loss`, `predictions`, `labels`. func accumulateMetrics(loss: Tensor, predictions: Output, labels: Target) { for index in metricMeasurers.indices { metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels) } } + /// Lets each of the metricMeasurers compute metrics on cumulated data. func computeMetrics() -> [(String, Float)] { var result: [(String, Float)] = [] for measurer in metricMeasurers { diff --git a/TrainingLoop/Metrics.swift b/TrainingLoop/Metrics.swift index 48bc3a87be3..79585ea7f21 100644 --- a/TrainingLoop/Metrics.swift +++ b/TrainingLoop/Metrics.swift @@ -24,31 +24,46 @@ public enum TrainingMetrics { } } -/// A protocal defining functionalities of a metrics measurer. +/// An accumulator of statistics. public protocol MetricsMeasurer { + /// Name of the metrics. var name: String { get set } + + /// Clears accumulated data up and resets measurer to initial state. mutating func reset() + + /// Accumulates data from `loss`, `predictions`, `labels`. mutating func accumulate( - loss: Tensor?, predictions: Output?, labels: Target?) + loss: Tensor?, predictions: Output?, labels: Target? + ) + + /// Computes metrics from cumulated data. func measure() -> Float } /// A measurer for measuring loss. public struct LossMeasurer: MetricsMeasurer { + /// Name of the LossMeasurer. public var name: String + /// Sum of losses cumulated from batches. 
private var totalBatchLoss: Float = 0 + + /// Count of batches accumulated so far. private var batchCount: Int32 = 0 + /// Creates an instance with the LossMeasurer named `name`. public init(_ name: String = "loss") { self.name = name } + /// Resets totalBatchLoss and batchCount to zero. public mutating func reset() { totalBatchLoss = 0 batchCount = 0 } + /// Adds `loss` to totalBatchLoss and increases batchCount by one. public mutating func accumulate( loss: Tensor?, predictions: Output?, labels: Target? ) { @@ -58,6 +73,7 @@ public struct LossMeasurer: MetricsMeasurer { } } + /// Computes the average loss over the accumulated batches. public func measure() -> Float { return totalBatchLoss / Float(batchCount) } @@ -65,20 +81,29 @@ public struct LossMeasurer: MetricsMeasurer { /// A measurer for measuring accuracy public struct AccuracyMeasurer: MetricsMeasurer { + /// Name of the AccuracyMeasurer. public var name: String + /// Count of correct guesses. private var correctGuessCount: Int32 = 0 + + /// Count of total guesses. private var totalGuessCount: Int32 = 0 + /// Creates an instance with the AccuracyMeasurer named `name`. public init(_ name: String = "accuracy") { self.name = name } + /// Resets correctGuessCount and totalGuessCount to zero. public mutating func reset() { correctGuessCount = 0 totalGuessCount = 0 } + /// Computes correct guess count from `loss`, `predictions` and `labels` + /// and adds it to correctGuessCount; computes total guess count from + /// `labels` shape and adds it to totalGuessCount. public mutating func accumulate( loss: Tensor?, predictions: Output?, labels: Target? ) { @@ -93,6 +118,7 @@ public struct AccuracyMeasurer: MetricsMeasurer { totalGuessCount += Int32(labels.shape.reduce(1, *)) } + /// Computes accuracy as the fraction of guesses that were correct.
public func measure() -> Float { return Float(correctGuessCount) / Float(totalGuessCount) } diff --git a/TrainingLoop/TrainingLoop.swift b/TrainingLoop/TrainingLoop.swift index 5585b2d6014..0f0794a8be8 100644 --- a/TrainingLoop/TrainingLoop.swift +++ b/TrainingLoop/TrainingLoop.swift @@ -78,7 +78,7 @@ public protocol TrainingLoopProtocol { /// The loss function. var lossFunction: LossFunction { get set } - /// The metrics + /// The metrics on which training is measured. var metrics: [TrainingMetrics] { get set } // Callbacks @@ -220,14 +220,15 @@ where /// Callbacks - // MARK: - The callbacks used to customize the training loop. - + /// The callbacks used to customize the training loop. public var callbacks: [TrainingLoopCallback] // MARK: - Default callback objects + /// The callback that records the training statistics. public var statisticsRecorder: StatisticsRecorder? = nil + /// The callback that prints the training progress. public var progressPrinter: ProgressPrinter? = nil /// Temporary data @@ -292,7 +293,7 @@ where self.progressPrinter = progressPrinter self.callbacks = [ statisticsRecorder.record, - progressPrinter.print, + progressPrinter.printProgress, ] + callbacks } else { self.callbacks = callbacks