main.swift
// Copyright 2020 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
import Datasets
import Foundation
import ModelSupport
import TensorFlow
import TextModels
import TrainingLoop
import x10_optimizers_optimizer
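// Run on the default X10 (XLA) device; the model and the data pipeline are placed on this device below.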
let device = Device.defaultXLA
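// Choose the pretrained encoder to fine-tune from the first command-line argument ("albert",
// "roberta", or "electra"); any other value, or no argument, falls back to uncased BERT-base.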
var bertPretrained: BERT.PreTrainedModel
if CommandLine.arguments.count >= 2 {
  if CommandLine.arguments[1].lowercased() == "albert" {
    bertPretrained = BERT.PreTrainedModel.albertBase
  } else if CommandLine.arguments[1].lowercased() == "roberta" {
    bertPretrained = BERT.PreTrainedModel.robertaBase
  } else if CommandLine.arguments[1].lowercased() == "electra" {
    bertPretrained = BERT.PreTrainedModel.electraBase
  } else {
    bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
  }
} else {
  bertPretrained = BERT.PreTrainedModel.bertBase(cased: false, multilingual: false)
}
let bert = try bertPretrained.load()
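// CoLA is a binary (acceptable / unacceptable) single-sentence task, so the classifier uses a
// single-logit head, paired with the sigmoid cross-entropy loss defined further below.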
var bertClassifier = BERTClassifier(bert: bert, classCount: 1)
bertClassifier.move(to: device)
// Regarding the batch size: batching is currently performed by bucketing input sequences by
// length (e.g., the first bucket contains sequences of length 1 to 10, the second 11 to 20,
// and so on). Examples keep flowing through the input data pipeline until a bucket contains
// enough sequences to form a batch. The batch size passed to the task constructor is the
// *total number of tokens in the batch*, not the number of sequences. So, with a batch size
// of 1024, the first bucket (lengths 1 to 10) needs 1024 / 10 = 102 examples to form a batch
// (every sentence in a bucket is padded to the bucket's maximum length). This kind of bucketing
// is common practice for NLP models and improves memory usage and computational efficiency when
// dealing with sequences of varied lengths. Note that the original BERT implementation released
// by Google does not use it, so the batch size setting here is not directly comparable to that
// implementation's.
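// For illustration (assuming the 10-token bucket widths from the note above, a cap at
// maxSequenceLength, and batchSize = 1024 tokens):
//   bucket for lengths  1...10  -> padded to  10 tokens -> 1024 / 10  = 102 sequences per batch
//   bucket for lengths 11...20  -> padded to  20 tokens -> 1024 / 20  =  51 sequences per batch
//   longest bucket, capped at maxSequenceLength = 128   -> 1024 / 128 =   8 sequences per batch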
let maxSequenceLength = 128
let batchSize = 1024
let epochCount = 3
let stepsPerEpoch = 1068 // function of training set size and batching configuration
let peakLearningRate: Float = 2e-5
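// Workspace directory under the system temporary directory, used by the CoLA task
// (e.g., for its downloaded data).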
let workspaceURL = URL(
  fileURLWithPath: "bert_models", isDirectory: true,
  relativeTo: URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true))
var cola = try CoLA(
  taskDirectoryURL: workspaceURL,
  maxSequenceLength: maxSequenceLength,
  batchSize: batchSize,
  entropy: SystemRandomNumberGenerator(),
  on: device
) { example in
  // In this closure, both the input and output text batches must be eager
  // since the text is not padded and x10 requires stable shapes.
  let classifier = bertClassifier
  let textBatch = classifier.bert.preprocess(
    sequences: [example.sentence],
    maxSequenceLength: maxSequenceLength)
  return LabeledData(data: textBatch, label: Tensor<Int32>(example.isAcceptable! ? 1 : 0))
}
print("Dataset acquired.")
let beta1: Float = 0.9
let beta2: Float = 0.999
let useBiasCorrection = true
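// The optimizer is x10's GeneralOptimizer wrapping Adam with the betas above; the
// TensorVisitorPlan describes how to traverse the classifier's differentiable vector view.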
var optimizer = x10_optimizers_optimizer.GeneralOptimizer(
  for: bertClassifier,
  TensorVisitorPlan(bertClassifier.differentiableVectorView),
  defaultOptimizer: makeAdam(
    learningRate: peakLearningRate,
    beta1: beta1,
    beta2: beta2
  )
)
/// Computes the sigmoid cross-entropy loss from `logits` and `labels`.
///
/// This is the loss function used by the TrainingLoop. It wraps the standard
/// sigmoidCrossEntropy, reshaping the logits to the required shape before
/// calling it.
@differentiable
public func sigmoidCrossEntropyReshaped<Scalar>(logits: Tensor<Scalar>, labels: Tensor<Int32>)
  -> Tensor<Scalar> where Scalar: TensorFlowFloatingPoint
{
  return sigmoidCrossEntropy(
    logits: logits.squeezingShape(at: -1),
    labels: Tensor<Scalar>(labels),
    reduction: _mean)
}
/// Clips the gradients by global norm.
///
/// This is defined as a callback registered with the TrainingLoop.
func clipGradByGlobalNorm<L: TrainingLoopProtocol>(_ loop: inout L, event: TrainingLoopEvent) throws
{
  if event == .updateStart {
    var gradients = loop.lastStepGradient!
    gradients.clipByGlobalNorm(clipNorm: 1)
    loop.lastStepGradient = gradients
  }
}
/// A linear shape for the learning rate schedule, used in both the warmup and decay phases.
let linear = Shape({ $0 })
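// Assemble the training loop: CoLA training and validation data, the optimizer, the reshaped
// sigmoid cross-entropy loss, Matthews correlation coefficient as the metric, and callbacks
// for gradient clipping and a learning rate schedule that warms up linearly to peakLearningRate
// over the first 10 steps and then decays linearly to 0.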
var trainingLoop: TrainingLoop = TrainingLoop(
  training: cola.trainingEpochs,
  validation: cola.validationBatches,
  optimizer: optimizer,
  lossFunction: sigmoidCrossEntropyReshaped,
  metrics: [.matthewsCorrelationCoefficient],
  callbacks: [
    clipGradByGlobalNorm,
    learningRateScheduler(
      schedule: makeSchedule(
        [
          ScheduleSegment(shape: linear, startRate: 0, endRate: peakLearningRate, stepCount: 10),
          ScheduleSegment(shape: linear, endRate: 0)
        ]
      ),
      biasCorrectionBeta: (beta1, beta2)
    ),
  ])
print("Training \(bertPretrained.name) for the CoLA task!")
try! trainingLoop.fit(&bertClassifier, epochs: epochCount, on: device)