Skip to content
This repository was archived by the owner on Apr 23, 2025. It is now read-only.

Address doc comments in TrainingLoop Callbacks #670

Merged
merged 4 commits into from
Oct 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions Support/FoundationFileSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,8 @@ public struct FoundationFile: File {
try value.write(to: location)
}

/// Append data to the file.
///
/// Parameter value: data to be appended at the end.
public func append(_ value: Data) throws {
/// Appends the bytes in `suffix` to the file.
public func append(_ suffix: Data) throws {
let fileHandler = try FileHandle(forUpdating: location)
#if os(macOS)
// The following are needed in order to build on macOS 10.15 (Catalina). They can be removed
Expand All @@ -72,7 +70,7 @@ public struct FoundationFile: File {
fileHandler.closeFile()
#else
try fileHandler.seekToEnd()
try fileHandler.write(contentsOf: value)
try fileHandler.write(contentsOf: suffix)
try fileHandler.close()
#endif
}
Expand Down
14 changes: 10 additions & 4 deletions TrainingLoop/Callbacks/CSVLogger.swift
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,13 @@ public enum CSVLoggerError: Error {

/// A handler for logging training and validation statistics to a CSV file.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still don't know what a handler is. Still don't think this is a strong abstraction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's an observer callback where TrainingLoop is designed upon. https://docs.google.com/document/d/1CtVFhV8OcQ4E7CmNyfFeZu0IgnUPx86tfQGXiHJUXz0/edit?ts=5ebef977#heading=h.b2so9ayrnyyp

You proposed to make it a function in TrainingLoop. Here are some points I'm more in favor of making callbacks wrapped in a separate classes:

  1. Decouple from TrainingLoop
  2. Use stored properties to share callback settings
  3. Follow this pattern for all callbacks

Let's discuss more offline or in seminar meeting !?!

public class CSVLogger {
/// The path of the file that statistics are logged to.
/// The path of the file to which statistics are logged.
public var path: String

// True iff the header of the CSV file has been written.
fileprivate var headerWritten: Bool

/// Creates an instance that logs to a file with the given path.
/// Creates an instance that logs to a file with the given `path`.
///
/// Throws: File system errors.
public init(path: String = "run/log.csv") throws {
Expand All @@ -32,7 +32,7 @@ public class CSVLogger {
self.headerWritten = false
}

/// Logs the statistics for the 'loop' when 'batchEnd' event happens;
/// Logs the statistics for `loop` when a `batchEnd` event happens;
/// ignoring other events.
///
/// Throws: File system errors.
Expand All @@ -43,7 +43,6 @@ public class CSVLogger {
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
let stats = loop.lastStatsLog
else {
// No-Op if trainingLoop doesn't set the required values for stats logging.
return
}

Expand All @@ -61,11 +60,18 @@ public class CSVLogger {
}
}

/// Writes a row of column names to the file.
///
/// Column names are "epoch", "batch" and the `name` of each element of `stats`,
/// in that order.
func writeHeader(stats: [(name: String, value: Float)]) throws {
let header = (["epoch", "batch"] + stats.lazy.map { $0.name }).joined(separator: ", ") + "\n"
try FoundationFile(path: path).append(header.data(using: .utf8)!)
}

/// Appends a row of statistics log to file with the given value `epoch` for
/// "epoch" column, `batch` for "batch" column, and `value`s of `stats` for corresponding
/// columns indicated by `stats` `name`s.
func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
let dataRow = ([epoch, batch] + stats.lazy.map { String($0.value) }).joined(separator: ", ")
+ "\n"
Expand Down
31 changes: 20 additions & 11 deletions TrainingLoop/Callbacks/ProgressPrinter.swift
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,37 @@

import Foundation

let progressBarLength = 30

/// A handler for printing the training and validation progress.
///
/// The progress includes epoch and batch index the training is currently
/// in, how many percentages of a full training/validation set has been done,
/// and metric statistics.
public class ProgressPrinter {
/// Print training or validation progress in response of the 'event'.
/// Length of the complete progress bar measured in count of `=` signs.
public var progressBarLength: Int

/// Creates an instance that prints training progress with the complete
/// progress bar to be `progressBarLength` characters long.
public init(progressBarLength: Int = 30) {
self.progressBarLength = progressBarLength
}

/// Prints training or validation progress in response of the `event`.
///
/// An example of the progress would be:
/// Epoch 1/12
/// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
/// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
/// 58/79 [======================>.......] - loss: 0.1520 - accuracy: 0.9521
public func printProgress<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
switch event {
case .epochStart:
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
// No-Op if trainingLoop doesn't set the required values for progress printing.
return
}

Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
print("Epoch \(epochIndex + 1)/\(epochCount)")
case .batchEnd:
guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
// No-Op if trainingLoop doesn't set the required values for progress printing.
return
}

Expand All @@ -46,15 +55,15 @@ public class ProgressPrinter {
stats = formatStats(lastStatsLog)
}

Swift.print(
print(
"\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
terminator: ""
)
fflush(stdout)
case .epochEnd:
Swift.print("")
print("")
case .validationStart:
Swift.print("")
print("")
default:
return
}
Expand Down
23 changes: 15 additions & 8 deletions TrainingLoop/Callbacks/StatisticsRecorder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,35 @@ import TensorFlow
///
/// Data produced by this handler can be used by ProgressPrinter, CVSLogger, etc.
public class StatisticsRecorder {
/// A Closure that returns if should call 'reset' on metricMeasurers.
/// A function that returns `true` iff recorder should call `reset`
/// on `metricMeasurers`.
public var shouldReset:
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool

/// A Closure that returns if should call 'accumulate' on metricMeasurers.
/// A function that returns `true` iff recorder should call `accumulate`
/// on `metricMeasurers`.
public var shouldAccumulate:
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool

/// A Closure that returns if should call 'compute' on metricMeasurers.
/// A function that returns `true` iff recorder should call `measure`
/// on `metricMeasurers`.
public var shouldCompute:
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool

/// Instances of MetricsMeasurers.
/// Instances of MetricsMeasurers that you can reset accumulate and compute
/// statistics periodically.
fileprivate var metricMeasurers: [MetricsMeasurer]

/// Create an instance that records 'metrics' during the training loop.
/// Creates an instance that records `metrics` during the training loop.
public init(metrics: [TrainingMetrics]) {
metricMeasurers = metrics.map { $0.measurer }

Expand Down Expand Up @@ -70,9 +74,9 @@ public class StatisticsRecorder {
}
}

/// Recording statistics in response of the 'event'.
/// Records statistics in response of the `event`.
///
/// It will record the statistics into 'lastStatsLog' in the loop where other
/// It will record the statistics into lastStatsLog property in the `loop` where other
/// callbacks can consume from.
public func record<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
guard let batchIndex = loop.batchIndex,
Expand All @@ -83,7 +87,6 @@ public class StatisticsRecorder {
let output = loop.lastStepOutput,
let target = loop.lastStepTarget
else {
// No-Op if trainingLoop doesn't set the required values for stats recording.
return
}

Expand All @@ -101,18 +104,22 @@ public class StatisticsRecorder {
}
}

/// Resets each of the metricMeasurers.
func resetMetricMeasurers() {
for index in metricMeasurers.indices {
metricMeasurers[index].reset()
}
}

/// Lets each of the metricMeasurers accumulate data from
/// `loss`, `predictions`, `labels`.
func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
for index in metricMeasurers.indices {
metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
}
}

/// Lets each of the metricMeasurers compute metrics on cumulated data.
func computeMetrics() -> [(String, Float)] {
var result: [(String, Float)] = []
for measurer in metricMeasurers {
Expand Down
30 changes: 28 additions & 2 deletions TrainingLoop/Metrics.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,31 +24,46 @@ public enum TrainingMetrics {
}
}

/// A protocal defining functionalities of a metrics measurer.
/// An accumulator of statistics.
public protocol MetricsMeasurer {
/// Name of the metrics.
var name: String { get set }

/// Clears accumulated data up and resets measurer to initial state.
mutating func reset()

/// Accumulates data from `loss`, `predictions`, `labels`.
mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?)
loss: Tensor<Float>?, predictions: Output?, labels: Target?
)

/// Computes metrics from cumulated data.
func measure() -> Float
}

/// A measurer for measuring loss.
public struct LossMeasurer: MetricsMeasurer {
/// Name of the LossMeasurer.
public var name: String

/// Sum of losses cumulated from batches.
private var totalBatchLoss: Float = 0

/// Count of batchs cumulated so far.
private var batchCount: Int32 = 0

/// Creates an instance with the LossMeasurer named `name`.
public init(_ name: String = "loss") {
self.name = name
}

/// Resets totalBatchLoss and batchCount to zero.
public mutating func reset() {
totalBatchLoss = 0
batchCount = 0
}

/// Adds `loss` to totalBatchLoss and increases batchCount by one.
public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
Expand All @@ -58,27 +73,37 @@ public struct LossMeasurer: MetricsMeasurer {
}
}

/// Computes averaged loss.
public func measure() -> Float {
return totalBatchLoss / Float(batchCount)
}
}

/// A measurer for measuring accuracy
public struct AccuracyMeasurer: MetricsMeasurer {
/// Name of the AccuracyMeasurer.
public var name: String

/// Count of correct guesses.
private var correctGuessCount: Int32 = 0

/// Count of total guesses.
private var totalGuessCount: Int32 = 0

/// Creates an instance with the AccuracyMeasurer named `name`.
public init(_ name: String = "accuracy") {
self.name = name
}

/// Resets correctGuessCount and totalGuessCount to zero.
public mutating func reset() {
correctGuessCount = 0
totalGuessCount = 0
}

/// Computes correct guess count from `loss`, `predictions` and `labels`
/// and adds it to correctGuessCount; Computes total guess count from
/// `labels` shape and adds it to totalGuessCount.
public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
Expand All @@ -93,6 +118,7 @@ public struct AccuracyMeasurer: MetricsMeasurer {
totalGuessCount += Int32(labels.shape.reduce(1, *))
}

/// Computes accuracy as percentage of correct guesses.
public func measure() -> Float {
return Float(correctGuessCount) / Float(totalGuessCount)
}
Expand Down
9 changes: 5 additions & 4 deletions TrainingLoop/TrainingLoop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public protocol TrainingLoopProtocol {
/// The loss function.
var lossFunction: LossFunction { get set }

/// The metrics
/// The metrics on which training is measured.
var metrics: [TrainingMetrics] { get set }

// Callbacks
Expand Down Expand Up @@ -220,14 +220,15 @@ where

/// Callbacks

// MARK: - The callbacks used to customize the training loop.

/// The callbacks used to customize the training loop.
public var callbacks: [TrainingLoopCallback<Self>]

// MARK: - Default callback objects

/// The callback that records the training statistics.
public var statisticsRecorder: StatisticsRecorder? = nil

/// The callback that prints the training progress.
public var progressPrinter: ProgressPrinter? = nil

/// Temporary data
Expand Down Expand Up @@ -292,7 +293,7 @@ where
self.progressPrinter = progressPrinter
self.callbacks = [
statisticsRecorder.record,
progressPrinter.print,
progressPrinter.printProgress,
] + callbacks
} else {
self.callbacks = callbacks
Expand Down