Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TrainingLoop: refactor progress printer and add CSVLogger #668

Merged
merged 5 commits into from
Sep 24, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 19 additions & 10 deletions Examples/LeNet-MNIST/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -31,24 +31,33 @@ let dataset = MNIST(batchSize: batchSize, on: device)

// The LeNet-5 model, equivalent to `LeNet` in `ImageClassificationModels`.
var classifier = Sequential {
Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
Flatten<Float>()
Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
Dense<Float>(inputSize: 84, outputSize: 10)
Conv2D<Float>(filterShape: (5, 5, 1, 6), padding: .same, activation: relu)
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
Conv2D<Float>(filterShape: (5, 5, 6, 16), activation: relu)
AvgPool2D<Float>(poolSize: (2, 2), strides: (2, 2))
Flatten<Float>()
Dense<Float>(inputSize: 400, outputSize: 120, activation: relu)
Dense<Float>(inputSize: 120, outputSize: 84, activation: relu)
Dense<Float>(inputSize: 84, outputSize: 10)
}

var optimizer = SGD(for: classifier, learningRate: 0.1)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy],
callbacks: [try! CSVLogger().log])

// Compute statistics only when last batch ends.
trainingLoop.statisticsRecorder!.shouldCompute = {
(
_ batchIndex: Int, _ batchCount: Int, _ epochIndex: Int, _ epochCount: Int,
_ event: TrainingLoopEvent
) -> Bool in
return event == .batchEnd && batchIndex + 1 == batchCount
}

try! trainingLoop.fit(&classifier, epochs: epochCount, on: device)
3 changes: 1 addition & 2 deletions Examples/MobileNetV1-Imagenette/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
var model = MobileNetV1(classCount: 10)
let optimizer = SGD(for: model, learningRate: 0.02, momentum: 0.9)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy])

try! trainingLoop.fit(&model, epochs: 10, on: device)
3 changes: 1 addition & 2 deletions Examples/MobileNetV2-Imagenette/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ let dataset = Imagenette(batchSize: 64, inputSize: .resized320, outputSize: 224,
var model = MobileNetV2(classCount: 10)
let optimizer = SGD(for: model, learningRate: 0.002, momentum: 0.9)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy])

try! trainingLoop.fit(&model, epochs: 10, on: device)
3 changes: 1 addition & 2 deletions Examples/ResNet-CIFAR10/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,11 @@ let dataset = CIFAR10(batchSize: 10, on: device)
var model = ResNet(classCount: 10, depth: .resNet56, downsamplingInFirstStage: false)
var optimizer = SGD(for: model, learningRate: 0.001)

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update])
metrics: [.accuracy])

try! trainingLoop.fit(&model, epochs: 10, on: device)
4 changes: 2 additions & 2 deletions Examples/VGG-Imagewoof/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,12 @@ public func scheduleLearningRate<L: TrainingLoopProtocol>(
}
}

let trainingProgress = TrainingProgress()
var trainingLoop = TrainingLoop(
training: dataset.training,
validation: dataset.validation,
optimizer: optimizer,
lossFunction: softmaxCrossEntropy,
callbacks: [trainingProgress.update, scheduleLearningRate])
metrics: [.accuracy],
callbacks: [scheduleLearningRate])

try! trainingLoop.fit(&model, epochs: 90, on: device)
1 change: 1 addition & 0 deletions Support/FileSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,5 @@ public protocol File {
func read(position: Int, count: Int) throws -> Data
func write(_ value: Data) throws
func write(_ value: Data, position: Int) throws
func append(_ value: Data) throws
}
10 changes: 10 additions & 0 deletions Support/FoundationFileSystem.swift
Original file line number Diff line number Diff line change
Expand Up @@ -58,4 +58,14 @@ public struct FoundationFile: File {
// TODO: Incorporate file offset.
try value.write(to: location)
}

/// Append data to the file.
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
///
/// - Parameter value: The data to be appended at the end of the file.
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
public func append(_ value: Data) throws {
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
let fileHandler = try FileHandle(forUpdating: location)
try fileHandler.seekToEnd()
try fileHandler.write(contentsOf: value)
try fileHandler.close()
}
}
6 changes: 4 additions & 2 deletions TrainingLoop/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
add_library(TrainingLoop
LossFunctions.swift
Metrics.swift
TrainingLoop.swift
TrainingProgress.swift
TrainingStatistics.swift)
Callbacks/StatisticsRecorder.swift
Callbacks/ProgressPrinter.swift
Callbacks/CSVLogger.swift)
target_link_libraries(TrainingLoop PUBLIC
ModelSupport)
set_target_properties(TrainingLoop PROPERTIES
Expand Down
68 changes: 68 additions & 0 deletions TrainingLoop/Callbacks/CSVLogger.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import Foundation
import ModelSupport

/// A handler for logging training and validation statistics to a CSV file.
public class CSVLogger {
/// The path of the file that statistics are logged to.
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
public var path: String
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved

// Whether the header row of the CSV file has been written yet.
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
fileprivate var headerWritten: Bool

/// Creates an instance that logs to a file with the given path.
///
/// Throws: File system errors.
public init(path: String = "run/log.csv") throws {
self.path = path
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems very unlikely to me that we actually need to store path in addition to foundationFile. Consider whether it can/should be dropped because you can get it from foundationFile.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same. Removed foundationFile and kept the path.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I'm not sure you want to open and close the file for every line logged, though. (I am presuming that the foundationFile object keeps the file open)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How did you “resolve” this comment?


// Create the containing directory if it is missing.
let containingDir = String(path[..<path.lastIndex(of: "/")!])
if containingDir != "" {
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
try FoundationFileSystem().createDirectoryIfMissing(at: containingDir)
}
// Initialize the file with empty string.
try FoundationFile(path: path).write(Data())

self.headerWritten = false
}

/// Logs the statistics for the 'loop' when 'batchEnd' event happens;
/// ignoring other events.
///
/// Throws: File system errors.
public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
switch event {
case .batchEnd:
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount,
let stats = loop.lastStatsLog
else {
// No-Op if trainingLoop doesn't set the required values for stats logging.
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why would any of these things be nil, and why are we bailing out when they're nil? That should be explained in a comment. Ditto for stats below. Also, why are these two separate guard statements?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Merged stats into the same guard statements.
These are designed as optionals in existing TrainingLoop. I think that's because it's not the must-have values to complete a training process. I added an inline comment.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, first, I would rather see something like

// These properties will be `nil` unless stats logging was requested

because it describes the situation at a semantic level rather than at the level of what some code did.

(also writing "No-Op" adds nothing to what is already very obvious from the code)

But that said, it seems very unlikely that the comment I'd like to see is true of any but the last property. All the others refer to values that have nothing to do with logging. So I want to know what causes epochIndex to be nil, for example.

It's a big design flaw in trainingLoop that it has so many optionals, and that makes this task more difficult, but I believe that's not the code you're working on(?)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These variables were originally designed to be optionals to store temporary data in protocol: https://github.com/tensorflow/swift-models/blob/master/TrainingLoop/TrainingLoop.swift#L76

The generic TrainingLoop that implements the protocol does set all these optionals. My guess on why it was designed so is that it allows other TrainingLoops not setting them.

So the point on the comment is NOT "if stats logging doesn't request it then these properties will be nil", but "if these properties are nil No-op on the CSVLogger".

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then you should delete the comment. The code very clearly says that all by itself, so the comment explains nothing.

I don't know if you're missing the point I'm trying to make, or you just disagree with it, but I'm doing this code review 100% for your benefit as a programmer. If the review process is blocking your progress, please feel free to just commit the changes, and decide separately about whether you want the feedback I'm giving you here. If you do, we can continue to discuss it.

}

if !headerWritten {
dabrahams marked this conversation as resolved.
Show resolved Hide resolved
try writeHeader(stats: stats)
headerWritten = true
}

try writeDataRow(
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
epoch: "\(epochIndex + 1)/\(epochCount)",
batch: "\(batchIndex + 1)/\(batchCount)",
stats: stats)
default:
return
}
}

func writeHeader(stats: [(name: String, value: Float)]) throws {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doc comments are missing on this and the following method.

let header: String = (["epoch", "batch"] + stats.map { $0.name }).joined(separator: ", ") + "\n"
xihui-wu marked this conversation as resolved.
Show resolved Hide resolved
try FoundationFile(path: path).append(header.data(using: .utf8)!)
}

/// Appends one data row to the CSV file at `path`.
///
/// The row contains the epoch and batch progress strings followed by the
/// value of each statistic, separated by ", " and terminated by a newline.
///
/// - Parameters:
///   - epoch: Epoch progress formatted as "current/total".
///   - batch: Batch progress formatted as "current/total".
///   - stats: Named statistic values to record, in column order.
/// - Throws: File system errors from appending to the file.
func writeDataRow(epoch: String, batch: String, stats: [(name: String, value: Float)]) throws {
  var fields = [epoch, batch]
  fields.append(contentsOf: stats.map { String($0.value) })
  let row = fields.joined(separator: ", ") + "\n"
  try FoundationFile(path: path).append(row.data(using: .utf8)!)
}
}
85 changes: 85 additions & 0 deletions TrainingLoop/Callbacks/ProgressPrinter.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Foundation

let progressBarLength = 30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

length in what unit? I suppose it's probably characters. Being a top-level declaration this should have a doc comment, and that would be a perfect place to put the answer. Is this the number of = signs, or the whole length printed, or…?


/// A handler for printing the training and validation progress.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't know what a "handler" is. That makes me suspect this class is not really representing any abstraction.

Since it has no storage, there's no reason to make it a class rather than a struct, but I don't think it should exist as a type. This thing is just a generic function over TrainingLoopProtocol instances.

print not a great name for the method, because pp.print(myLoop, someEvent) looks like it's printing myLoop and someEvent, especially given the precedent set by Swift.print(myLoop, someEvent).

A more appropriate idiom would be:

extension TrainingLoopProtocol {
  public mutating func printProgress(event: TrainingLoopEvent) throws { ... }
}

/// A handler for printing the training and validation progress of a
/// `TrainingLoop` to standard output.
public class ProgressPrinter {
  /// Prints training or validation progress in response to `event`.
  ///
  /// An example of the printed progress:
  /// Epoch 1/12
  /// 468/468 [==============================] - loss: 0.4819 - accuracy: 0.8513
  /// 79/79 [==============================] - loss: 0.1520 - accuracy: 0.9521
  ///
  /// - Parameters:
  ///   - loop: The training loop whose progress is printed.
  ///   - event: The event that triggered this callback; events other than
  ///     `epochStart`, `batchEnd`, `epochEnd`, and `validationStart` are ignored.
  public func print<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
    switch event {
    case .epochStart:
      // The loop's bookkeeping properties are optional; a loop that does not
      // populate them gets no progress output rather than a crash.
      guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount else {
        return
      }

      Swift.print("Epoch \(epochIndex + 1)/\(epochCount)")
    case .batchEnd:
      guard let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
        return
      }

      let progressBar = formatProgressBar(
        progress: Float(batchIndex + 1) / Float(batchCount), length: progressBarLength)
      var stats: String = ""
      if let lastStatsLog = loop.lastStatsLog {
        stats = formatStats(lastStatsLog)
      }

      // "\r" rewrites the current line in place; flush so the partial line
      // appears immediately even though it has no trailing newline.
      Swift.print(
        "\r\(batchIndex + 1)/\(batchCount) \(progressBar)\(stats)",
        terminator: ""
      )
      fflush(stdout)
    case .epochEnd:
      Swift.print("")
    case .validationStart:
      Swift.print("")
    default:
      return
    }
  }

  /// Returns a textual progress bar such as "[=========>....................]".
  ///
  /// - Parameters:
  ///   - progress: Fraction complete, nominally in 0...1.
  ///   - length: Total number of characters between the brackets.
  func formatProgressBar(progress: Float, length: Int) -> String {
    // Clamp to 0...length so out-of-range progress values can never produce a
    // negative repeat count or an over-long bar.
    let progressSteps = min(max(Int(round(Float(length) * progress)), 0), length)
    let leading = String(repeating: "=", count: progressSteps)
    let separator: String
    let trailing: String
    // Fixed: compare and pad against the `length` parameter, not the global
    // `progressBarLength`, so callers may request bars of any width.
    if progressSteps < length {
      separator = ">"
      trailing = String(repeating: ".", count: length - progressSteps - 1)
    } else {
      separator = ""
      trailing = ""
    }
    return "[\(leading)\(separator)\(trailing)]"
  }

  /// Formats statistics as " - name: value" segments with four decimal places,
  /// e.g. " - loss: 0.4819 - accuracy: 0.8513".
  func formatStats(_ stats: [(String, Float)]) -> String {
    var result = ""
    for stat in stats {
      result += " - \(stat.0): \(String(format: "%.4f", stat.1))"
    }
    return result
  }
}
Loading