Commit 88cf2cb

format
xihui-wu committed Sep 15, 2020
1 parent 565fff8 commit 88cf2cb
Showing 6 changed files with 208 additions and 144 deletions.
2 changes: 1 addition & 1 deletion TrainingLoop/CMakeLists.txt
@@ -3,7 +3,7 @@ add_library(TrainingLoop
Metrics.swift
TrainingLoop.swift
Callbacks/StatisticsRecorder.swift
Callbacks/ProgressPrinter.swift,
Callbacks/ProgressPrinter.swift
Callbacks/CSVLogger.swift)
target_link_libraries(TrainingLoop PUBLIC
ModelSupport)
131 changes: 67 additions & 64 deletions TrainingLoop/Callbacks/CVSLogger.swift
@@ -3,74 +3,77 @@ import ModelSupport

/// A callback-based handler for logging statistics to a CSV file.
public class CVSLogger {
public var path: String
public var liveStatistics: Bool

var foundationFS: FoundationFileSystem
var foundationFile: FoundationFile
public var path: String
public var liveStatistics: Bool

/// Creates an instance that logs statistics during the training loop.
///
/// - Parameters:
///   - liveStatistics: whether to log statistics on every batch (true) or only on the last batch of each epoch (false).
public init(withPath path: String = "run/log.csv", liveStatistics: Bool = true) {
self.path = path
self.liveStatistics = liveStatistics
self.foundationFS = FoundationFileSystem()
self.foundationFile = FoundationFile(path: path)
}
let foundationFS: FoundationFileSystem
let foundationFile: FoundationFile

/// The callback used to hook into the TrainingLoop for logging statistics.
///
/// - Parameters:
/// - loop: The TrainingLoop where an event has occurred.
/// - event: The training or validation event that this callback is responding to.
public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
switch event {
case .batchEnd:
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount else {
break
}
/// Creates an instance that logs statistics during the training loop.
///
/// - Parameters:
///   - liveStatistics: whether to log statistics on every batch (true) or only on the last batch of each epoch (false).
public init(withPath path: String = "run/log.csv", liveStatistics: Bool = true) {
self.path = path
self.liveStatistics = liveStatistics
self.foundationFS = FoundationFileSystem()
self.foundationFile = FoundationFile(path: path)
}

if !liveStatistics && (batchIndex + 1 != batchCount) {
break
}
/// The callback used to hook into the TrainingLoop for logging statistics.
///
/// - Parameters:
/// - loop: The TrainingLoop where an event has occurred.
/// - event: The training or validation event that this callback is responding to.
public func log<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws {
switch event {
case .batchEnd:
guard let epochIndex = loop.epochIndex, let epochCount = loop.epochCount,
let batchIndex = loop.batchIndex, let batchCount = loop.batchCount
else {
break
}

guard let stats = loop.lastStatsLog else {
break
}
if !liveStatistics && (batchIndex + 1 != batchCount) {
break
}

if !FileManager.default.fileExists(atPath: path) {
try foundationFS.createDirectoryIfMissing(at: String(path[..<path.lastIndex(of: "/")!]))
try writeHeader(stats: stats)
}
try writeDataRow(
epoch: "\(epochIndex + 1)/\(epochCount)",
batch: "\(batchIndex + 1)/\(batchCount)",
stats: stats)
default:
break
}
}
guard let stats = loop.lastStatsLog else {
break
}

func writeHeader(stats: [(String, Float)]) throws {
let head: String = (["epoch", "batch"] + stats.map { $0.0 }).joined(separator: ", ")
do {
try head.write(toFile: path, atomically: true, encoding: .utf8)
} catch {
print("Unexpected error in writing header line: \(error).")
throw error
}
}
if !FileManager.default.fileExists(atPath: path) {
try foundationFS.createDirectoryIfMissing(at: String(path[..<path.lastIndex(of: "/")!]))
try writeHeader(stats: stats)
}
try writeDataRow(
epoch: "\(epochIndex + 1)/\(epochCount)",
batch: "\(batchIndex + 1)/\(batchCount)",
stats: stats)
default:
break
}
}

func writeDataRow(epoch: String, batch: String, stats: [(String, Float)]) throws {
let dataRow: Data = ("\n" + ([epoch, batch] + stats.map { String($0.1) }).joined(separator: ", ")).data(using: .utf8)!
do {
try foundationFile.append(dataRow)
} catch {
print("Unexpected error in writing data row: \(error).")
throw error
}
}
}
func writeHeader(stats: [(String, Float)]) throws {
let head: String = (["epoch", "batch"] + stats.map { $0.0 }).joined(separator: ", ")
do {
try head.write(toFile: path, atomically: true, encoding: .utf8)
} catch {
print("Unexpected error in writing header line: \(error).")
throw error
}
}

func writeDataRow(epoch: String, batch: String, stats: [(String, Float)]) throws {
let dataRow: Data = (
"\n" + ([epoch, batch] + stats.map { String($0.1) }).joined(separator: ", ")
).data(using: .utf8)!
do {
try foundationFile.append(dataRow)
} catch {
print("Unexpected error in writing data row: \(error).")
throw error
}
}
}
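
For context, here is a usage sketch of the logger above (not part of this commit): its log method is passed to a TrainingLoop as a callback. The TrainingLoop(training:validation:optimizer:lossFunction:metrics:callbacks:) initializer, the softmaxCrossEntropy loss, the fit(_:epochs:on:) call, and the dataset, model, optimizer, and device names are all assumptions for illustration, not something this diff establishes.

import TrainingLoop

// Hypothetical wiring; `dataset`, `model`, `optimizer`, and `device` are placeholders.
var trainingLoop = TrainingLoop(
  training: dataset.training,
  validation: dataset.validation,
  optimizer: optimizer,
  lossFunction: softmaxCrossEntropy,
  metrics: [.loss, .accuracy],
  callbacks: [CVSLogger(withPath: "run/log.csv", liveStatistics: false).log])

// Every batchEnd event reaches CVSLogger.log; with liveStatistics == false,
// a row is appended only for the last batch of each epoch.
try! trainingLoop.fit(&model, epochs: 10, on: device)

With the loss and accuracy metrics, writeHeader and writeDataRow above would produce a file whose first line is "epoch, batch, loss, accuracy", followed by rows such as "1/10, 469/469, 0.35, 0.91".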
2 changes: 1 addition & 1 deletion TrainingLoop/Callbacks/ProgressPrinter.swift
@@ -68,7 +68,7 @@ public class ProgressPrinter {
Swift.print("")
case .validationStart:
Swift.print("")
default:
default:
return
}
}
8 changes: 5 additions & 3 deletions TrainingLoop/Callbacks/StatisticsRecorder.swift
@@ -49,10 +49,12 @@ public class StatisticsRecorder {
case .trainingStart, .validationStart:
resetMetricMeasurers()
case .batchEnd:
if let loss = loop.lastStepLoss, let output = loop.lastStepOutput, let target = loop.lastStepTarget {
if let loss = loop.lastStepLoss, let output = loop.lastStepOutput,
let target = loop.lastStepTarget
{
accumulateMetrics(loss: loss, predictions: output, labels: target)
}

if let batchIndex = loop.batchIndex, let batchCount = loop.batchCount {
if liveStatistics || (batchCount == (batchIndex + 1)) {
loop.lastStatsLog = computeMetrics()
@@ -71,7 +73,7 @@

func accumulateMetrics<Output, Target>(loss: Tensor<Float>, predictions: Output, labels: Target) {
for index in metricMeasurers.indices {
metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
metricMeasurers[index].accumulate(loss: loss, predictions: predictions, labels: labels)
}
}

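
The batchEnd branch above gates metric publication on liveStatistics: when it is false, loop.lastStatsLog is refreshed only on the final batch of an epoch. A minimal standalone sketch of that condition, with hypothetical liveStatistics and batchCount values and no dependency on the TrainingLoop types:

let liveStatistics = false
let batchCount = 100

var refreshCount = 0
for batchIndex in 0..<batchCount {
  // Mirrors the condition in StatisticsRecorder above: publish on every batch
  // when liveStatistics is true, otherwise only on the last batch.
  if liveStatistics || batchCount == batchIndex + 1 {
    refreshCount += 1
  }
}
print(refreshCount)  // 1 here; flipping liveStatistics to true gives 100

This is the same trade-off the liveStatistics flag controls in CVSLogger above: per-batch live updates versus a single update per epoch.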
141 changes: 75 additions & 66 deletions TrainingLoop/Metrics.swift
@@ -6,86 +6,95 @@ public enum TrainingMetrics {
case accuracy

public var name: String {
switch self {
case .loss:
return "loss"
case .accuracy:
return "accuracy"
}
switch self {
case .loss:
return "loss"
case .accuracy:
return "accuracy"
}
}

public var measurer: MetricsMeasurer {
switch self {
case .loss:
return LossMeasurer(self.name)
case .accuracy:
return AccuracyMeasurer(self.name)
}
switch self {
case .loss:
return LossMeasurer(self.name)
case .accuracy:
return AccuracyMeasurer(self.name)
}
}
}

/// A protocol defining the functionality of a metrics measurer.
public protocol MetricsMeasurer {
var name: String { get set }
mutating func reset()
mutating func accumulate<Output, Target>(loss: Tensor<Float>?, predictions: Output?, labels: Target?)
func measure() -> Float
var name: String { get set }
mutating func reset()
mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
)
func measure() -> Float
}

/// A measurer for measuring loss.
public struct LossMeasurer: MetricsMeasurer {
public var name: String

private var totalBatchLoss: Float = 0
private var batchCount: Int32 = 0

public init(_ name: String = "loss") {
self.name = name
}

public mutating func reset() {
totalBatchLoss = 0
batchCount = 0
}

public mutating func accumulate<Output, Target>(loss: Tensor<Float>?, predictions: Output?, labels: Target?) {
if let newBatchLoss = loss {
totalBatchLoss += newBatchLoss.scalarized()
batchCount += 1
}
}

public func measure() -> Float {
return totalBatchLoss / Float(batchCount)
}
public var name: String

private var totalBatchLoss: Float = 0
private var batchCount: Int32 = 0

public init(_ name: String = "loss") {
self.name = name
}

public mutating func reset() {
totalBatchLoss = 0
batchCount = 0
}

public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
if let newBatchLoss = loss {
totalBatchLoss += newBatchLoss.scalarized()
batchCount += 1
}
}

public func measure() -> Float {
return totalBatchLoss / Float(batchCount)
}
}

/// A measurer for measuring accuracy.
public struct AccuracyMeasurer: MetricsMeasurer {
public var name: String

private var correctGuessCount: Int32 = 0
private var totalGuessCount: Int32 = 0

public init(_ name: String = "accuracy") {
self.name = name
}

public mutating func reset() {
correctGuessCount = 0
totalGuessCount = 0
}

public mutating func accumulate<Output, Target>(loss: Tensor<Float>?, predictions: Output?, labels: Target?) {
guard let predictions = predictions as? Tensor<Float>, let labels = labels as? Tensor<Int32> else {
fatalError(
"For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>.")
}
correctGuessCount += Tensor<Int32>(predictions.argmax(squeezingAxis: 1) .== labels).sum().scalarized()
totalGuessCount += Int32(labels.shape[0])
}

public func measure() -> Float {
return Float(correctGuessCount) / Float(totalGuessCount)
}
public var name: String

private var correctGuessCount: Int32 = 0
private var totalGuessCount: Int32 = 0

public init(_ name: String = "accuracy") {
self.name = name
}

public mutating func reset() {
correctGuessCount = 0
totalGuessCount = 0
}

public mutating func accumulate<Output, Target>(
loss: Tensor<Float>?, predictions: Output?, labels: Target?
) {
guard let predictions = predictions as? Tensor<Float>, let labels = labels as? Tensor<Int32>
else {
fatalError(
"For accuracy measurements, the model output must be Tensor<Float>, and the labels must be Tensor<Int>."
)
}
correctGuessCount += Tensor<Int32>(predictions.argmax(squeezingAxis: 1) .== labels).sum()
.scalarized()
totalGuessCount += Int32(labels.shape[0])
}

public func measure() -> Float {
return Float(correctGuessCount) / Float(totalGuessCount)
}
}
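
The MetricsMeasurer protocol above is what StatisticsRecorder drives through reset, accumulate, and measure. As an illustration of extending it, here is a sketch of a custom measurer that reports the largest per-batch loss seen since the last reset; the MaxLossMeasurer type is hypothetical (not part of this commit) and assumes the Tensor type from TensorFlow for Swift used above.

import TensorFlow
import TrainingLoop

/// A hypothetical measurer that reports the largest per-batch loss seen since the last reset.
public struct MaxLossMeasurer: MetricsMeasurer {
  public var name: String

  private var maxLoss: Float = 0

  public init(_ name: String = "maxLoss") {
    self.name = name
  }

  public mutating func reset() {
    maxLoss = 0
  }

  public mutating func accumulate<Output, Target>(
    loss: Tensor<Float>?, predictions: Output?, labels: Target?
  ) {
    // Only the loss is needed here; predictions and labels are ignored.
    if let loss = loss {
      maxLoss = max(maxLoss, loss.scalarized())
    }
  }

  public func measure() -> Float {
    return maxLoss
  }
}

Because TrainingMetrics is an enum, plugging such a measurer into the existing recorder would also need either a new enum case or a way to hand measurers to StatisticsRecorder directly; the sketch only demonstrates the protocol conformance.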