open_lm.patch
diff --git a/open_lm/data.py b/open_lm/data.py
index 107ff6e..05f6d08 100644
--- a/open_lm/data.py
+++ b/open_lm/data.py
@@ -38,6 +38,34 @@ from webdataset.tariterators import (
)
from webdataset.mix import RandomMix
+class MyRandomMix(RandomMix):
+ def __init__(self, datasets, probs=None, longest=False, seed=42):
+ super().__init__(datasets, probs=probs, longest=longest)
+ self.rng = random.Random()
+ self.rng.seed(seed)
+
+ def __iter__(self):
+ """Return an iterator over the sources."""
+ sources = [iter(d) for d in self.datasets]
+ return self.random_samples(sources, self.probs)
+
+ def random_samples(self, sources, probs=None):
+ if probs is None:
+ probs = [1] * len(sources)
+ else:
+ probs = list(probs)
+ while len(sources) > 0:
+ cum = (np.array(probs) / np.sum(probs)).cumsum()
+ r = self.rng.random()
+ i = np.searchsorted(cum, r)
+ try:
+ yield next(sources[i])
+ except StopIteration:
+ if self.longest:
+ del sources[i]
+ del probs[i]
+ else:
+ break
def seed_worker(worker_id):
worker_seed = torch.initial_seed() % 2**32
@@ -344,7 +372,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
all_num_samples = []
shared_epoch = SharedEpoch(epoch=epoch) # create a shared epoch store to sync epoch to dataloader worker proc
- for ii, input_shards in enumerate(input_shards_):
+ for ii, (input_shards, batch_mult) in enumerate(zip(input_shards_, args.dataset_batch_mult)):
resampled = getattr(args, "dataset_resampled", False) and is_train
num_shards = None
if is_train:
@@ -421,7 +449,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
)
map_handler = {"handler": log_and_continue} if args.ignore_parse_errors else {}
- batch_size = args.per_gpu_batch_size if is_train else args.per_gpu_val_batch_size
+ batch_size = int(batch_mult * args.per_gpu_batch_size) if is_train else args.per_gpu_val_batch_size
if data_key == "json" or data_key == "json.gz":
pipeline.extend(
@@ -430,7 +458,6 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
wds.rename(json=data_key),
wds.map_dict(json=partial(preprocess_json, vocab_size=args.vocab_size), **map_handler),
wds.to_tuple("json", **map_handler),
- wds.select(partial(filter_lt_seqlen, args.seq_len)),
wds.batched(batch_size, partial=not is_train),
]
)
@@ -439,7 +466,6 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
[
wds.map_dict(txt=partial(preprocess_txt, vocab_size=args.vocab_size), **map_handler),
wds.to_tuple("txt", **map_handler),
- wds.select(partial(filter_lt_seqlen, args.seq_len)),
wds.batched(batch_size, partial=not is_train),
]
)
@@ -451,8 +477,8 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
all_num_samples.append(num_samples)
if is_train:
- # TODO: why did we previoulsy wrap with RandomMix_
- dataset = RandomMix(datasets, probs=args.train_data_mix_weights, longest=True)
+ # Use our RandomMix with a deterministic random seed to make sure all nodes choose the same bucket.
+ dataset = MyRandomMix(datasets, probs=args.train_data_mix_weights, longest=True, seed=args.seed)
if len(datasets) > 1:
logging.warning("Source mixing is happening during training. It is preferred to mix during tokenization.")
else:
@@ -461,17 +487,18 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
# dataset = datasets[0]
if is_train:
if not resampled:
- num_shards = num_shards or len(expand_urls(input_shards)[0])
- if num_shards < args.workers * args.world_size:
+ shards_per_source_avail = [len(expand_urls(shard_string)[0]) for shard_string in input_shards_]
+ print(f"Number of shards available from each source = {shards_per_source_avail}")
+ min_num_shards = min(shards_per_source_avail)
+ if min_num_shards < args.workers * args.world_size:
print("Please increase --train-num-samples or decrease workers or world size")
- print(f"num_shards: {num_shards}, workers: {args.workers}, world_size: {args.world_size}")
- assert num_shards >= args.workers * args.world_size, "number of shards must be >= total workers"
- # roll over and repeat a few samples to get same number of full batches on each node
+ print(f"min num_shards: {min_num_shards}, workers: {args.workers}, world_size: {args.world_size}")
+ assert min_num_shards >= args.workers * args.world_size, "number of shards must be >= total workers"
round_fn = math.floor if floor else math.ceil
- global_batch_size = batch_size * args.world_size
total_num_batches = 0
total_num_samples = 0
- for ii in range(len(datasets)):
+ for ii, batch_mult in enumerate(args.dataset_batch_mult):
+ global_batch_size = int(batch_mult * args.global_batch_size)
# Calculate batches per worker, round as little as possible.
num_workers_per_gpu = max(1, args.workers)
num_worker_batches = round_fn(all_num_samples[ii] / (global_batch_size * num_workers_per_gpu))
@@ -484,7 +511,7 @@ def get_wds_dataset(args, is_train, epoch=0, floor=True, tokenizer=None, data_ke
)
num_batches = num_worker_batches * num_workers_per_gpu
- num_samples = num_batches * global_batch_size
+ num_samples = num_batches * args.global_batch_size # Number of sequences as if all were the longest (8k)
# This forces the dataloader to take num_worker_batches steps per worker, so num_batches total.
datasets[ii] = datasets[ii].repeat(nepochs=1, nbatches=num_worker_batches)
@@ -704,18 +731,6 @@ def mask_sequence(chunk, start_idx, args, ignore_tok=-100):
def sample_chunk(chunk, args):
- if chunk.shape[1] == args.seq_len + 1:
- start_idx = 0
- elif chunk.shape[1] > args.seq_len + 1:
- start_idx = torch.randint(0, chunk.shape[1] - args.seq_len, (1,)).item()
- else:
- raise Exception(f"Invalid sequence length: Sequence length {args.seq_len} > {chunk.shape[1]} Chunk size")
-
- inputs = chunk[:, start_idx : start_idx + args.seq_len]
- targets = chunk[:, start_idx + 1 : start_idx + args.seq_len + 1]
-
- # replace elements to be masked with with -100 (pytorch default xent ignore value)
- if args.target_mask_left is not None or args.target_mask_individual is not None:
- inputs, targets = mask_sequence(chunk, start_idx, args)
-
+ inputs = chunk[:, :-1]
+ targets = chunk[:, 1:]
return inputs, targets
diff --git a/open_lm/file_utils.py b/open_lm/file_utils.py
index f91919b..fe729fa 100644
--- a/open_lm/file_utils.py
+++ b/open_lm/file_utils.py
@@ -134,14 +134,22 @@ def check_exists(file_path):
return True
-def get_metadata_file(path, shard_shuffle_seed=None):
+def get_metadata_file(path, shard_shuffle_seed=None, append_a_copy=4):
of = fsspec.open(path, "rb")
with of as f:
out = f.read()
out = [json.loads(o) for o in out.decode("utf-8").split("\n")[:-1]]
+ if append_a_copy > 0:
+ out_copy = [copy.deepcopy(out) for _ in range(append_a_copy)]
if shard_shuffle_seed is not None:
rng_gen = np.random.default_rng(shard_shuffle_seed)
rng_gen.shuffle(out)
+ if append_a_copy > 0:
+ for a_copy in out_copy:
+ rng_gen.shuffle(a_copy)
+ if append_a_copy > 0:
+ for a_copy in out_copy:
+ out = out + a_copy
return out
@@ -218,7 +226,7 @@ def count_small_shards(path, ratio=0.9):
shard_sizes = np.array(shard_sizes)
- return np.sum(shard_sizes < ratio * max(shard_sizes))
+ return np.sum(shard_sizes < ratio * max(shard_sizes)), max(shard_sizes)
def are_sources_imbalanced_with_each_other(paths, ratio=2):
@@ -262,9 +270,11 @@ def log_num_checkpoints(total_steps, args):
args.world_size,
multi_epoch=args.multiple_data_passes,
shard_shuffle_seed=args.shard_shuffle_seed,
+ source_num_seq_per_epoch=args.source_num_seq_per_epoch,
)
steps_epoch = sum(
- [(n // (args.workers * args.global_batch_size)) * args.workers for n in num_samples_per_source]
+ [(n // (args.workers * args.global_batch_size * batch_mult)) * args.workers for n, batch_mult in
+ zip(num_samples_per_source, args.dataset_batch_mult)]
)
steps_done += steps_epoch
if steps_done > total_steps:
@@ -300,15 +310,18 @@ def get_string_for_epoch(
world_size: int,
multi_epoch=False,
shard_shuffle_seed=None,
+ source_num_seq_per_epoch=None,
):
"""See _single_epoch_string for full docstring."""
if multi_epoch:
return _multi_epoch_string(
- num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed
+ num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed,
+ source_num_seq_per_epoch
)
else:
return _single_epoch_string(
- num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed
+ num_samples, starting_points, paths, weights, num_workers_per_gpu, world_size, shard_shuffle_seed,
+ source_num_seq_per_epoch
)
@@ -370,6 +383,7 @@ def _single_epoch_string(
num_workers_per_gpu: int,
world_size: int,
shard_shuffle_seed: Optional[int],
+ source_num_seq_per_epoch: Optional[List[int]] = None,
):
"""Retrieve shards to train on for a particular checkpoint.
@@ -383,38 +397,25 @@ def _single_epoch_string(
num_workers_per_gpu: Number of workers per gpu process.
world_size: Total number of gpus used for training.
shard_shuffle_seed: Seed to shuffle shards before checkpoint assignment
+ source_num_seq_per_epoch: List of the number of sequences to use from each source (bucket) per epoch.
"""
num_sources = len(paths)
-
- if num_sources > 1:
- logging.warning(
- "Multiple sources are not supported fully as of now. It is advised to combine the data into a single "
- "source, by using datapreprocess/ray/tokenize_shuffle.py. Best effort will be done to mix data at the "
- "desired ratio."
- )
- if are_sources_imbalanced_with_each_other(paths):
- logging.warning(
- "Sources contain highly imbalanced shards (largest median shard size of a source is >2x the smallest "
- "median size of a source). This will lead to deteriorated performance (less frequent checkpoints, "
- "data being skipped, and inaccurate mixing). It is STRONGLY advised to combine into one source."
- )
+ expected_num_sequence_per_shard = []
for path in paths:
- num_small_shards = count_small_shards(path)
- if num_small_shards > 0:
- logging.warning(
- f"Source defined by {path} contains {num_small_shards} shards that are smaller than 90% the size of "
- f"the largest shard. These shards might cause deterioration in performance, with more samples being "
- f"skipped than necessary. It is advised to make the shards more uniform."
- )
+ num_small_shards, expected_num_seq = count_small_shards(path)
+ expected_num_sequence_per_shard.append(expected_num_seq)
if weights is None:
weights = [1.0 / num_sources for _ in range(num_sources)]
assert len(weights) == num_sources, "One weight is needed per source."
- needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)]
+ if source_num_seq_per_epoch is None:
+ needed_samples_per_source = [int(np.ceil(weights[i] * num_samples / sum(weights))) for i in range(num_sources)]
+ else:
+ needed_samples_per_source = source_num_seq_per_epoch
manifests = [get_metadata_file(path, shard_shuffle_seed=shard_shuffle_seed) for path in paths]
@@ -424,32 +425,38 @@ def _single_epoch_string(
num_samples_per_source = [[] for _ in range(num_sources)]
total_num_workers = num_workers_per_gpu * world_size
- while not enough_shards(shard_list_per_source, total_num_workers) or not enough_samples(
- num_samples_per_source, needed_samples_per_source
- ):
- try:
- for i in range(num_sources):
+ try:
+ for i in range(num_sources):
+ while len(shard_list_per_source[i]) < total_num_workers or sum(num_samples_per_source[i]) < \
+ needed_samples_per_source[i]:
# Add shards incrementally
shard_name = manifests[i][next_shard_per_source[i]]["shard"]
try:
num_samples_shard = manifests[i][next_shard_per_source[i]]["num_sequences"]
except KeyError:
num_samples_shard = manifests[i][next_shard_per_source[i]]["num_chunks"]
-
- shard_list_per_source[i].append(shard_name)
- num_samples_per_source[i].append(num_samples_shard)
+ if num_samples_shard == expected_num_sequence_per_shard[i]:
+ shard_list_per_source[i].append(shard_name)
+ num_samples_per_source[i].append(num_samples_shard)
+ else:
+ print(
+ f"Dropping shard = {shard_name} with {num_samples_shard} samples != {expected_num_sequence_per_shard[i]}")
next_shard_per_source[i] += 1
-
- except IndexError as e:
- logging.error(
- "Number of shards requested for a single epoch is more than the number of shards available. This means "
- "that the amount of data requested to train on is more than the dataloader can serve. This can either "
- "happen because there are not enough data to begin with, or data being skipped due to rounding errors. "
- "To alleviate the latter, consider making more uniform shards, and using less workers/GPUs. This will "
- "allow for better use of the dataset."
- )
- raise e
+ except IndexError as e:
+ print(f"For Source = {i}")
+ print(f"Need samples = {needed_samples_per_source[i]}, collected {sum(num_samples_per_source[i])}")
+ print(f"Total shards so far = {next_shard_per_source[i]}")
+ print(f"len(shard_list_per_source[i]) = {len(shard_list_per_source[i])}")
+ print(f"total_num_workers = {total_num_workers}")
+ logging.error(
+ "Number of shards requested for a single epoch is more than the number of shards available. This means "
+ "that the amount of data requested to train on is more than the dataloader can serve. This can either "
+ "happen because there are not enough data to begin with, or data being skipped due to rounding errors. "
+ "To alleviate the latter, consider making more uniform shards, and using less workers/GPUs. This will "
+ "allow for better use of the dataset."
+ )
+ raise e
for i in range(num_sources):
# Ensure the number of shards is a multiple of number of workers, so each worker has the same
@@ -458,6 +465,9 @@ def _single_epoch_string(
# This is a heuristic to minimize how much data we discard when trying to ensure each worker has
# the same number of samples. Shards tend to have similar number of samples, so an extra shard
# in a worker will likely get discarded.
+ if not len(shard_list_per_source[i]) % total_num_workers == 0:
+ print(
+ f"For source {i} number of shards = {len(shard_list_per_source[i])} is not multiple of total workers = {total_num_workers}")
num_multiples = len(shard_list_per_source[i]) // total_num_workers
shard_list_per_source[i] = shard_list_per_source[i][: num_multiples * total_num_workers]
diff --git a/open_lm/main.py b/open_lm/main.py
index 7c80f55..0da7edc 100644
--- a/open_lm/main.py
+++ b/open_lm/main.py
@@ -793,6 +793,7 @@ def main(args):
args.world_size,
multi_epoch=args.multiple_data_passes,
shard_shuffle_seed=args.shard_shuffle_seed,
+ source_num_seq_per_epoch=args.source_num_seq_per_epoch,
)
# In the distributed case, make sure that all nodes receive the same string
diff --git a/open_lm/model_configs/open_lm_160m.json b/open_lm/model_configs/open_lm_160m.json
index ea4fe6e..944faf0 100644
--- a/open_lm/model_configs/open_lm_160m.json
+++ b/open_lm/model_configs/open_lm_160m.json
@@ -2,7 +2,7 @@
"hidden_dim": 768,
"n_layers": 12,
"n_heads": 12,
- "seq_len": 2048,
+ "seq_len": 8192,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
diff --git a/open_lm/model_configs/open_lm_1b.json b/open_lm/model_configs/open_lm_1b.json
index fc1878e..774fc9b 100644
--- a/open_lm/model_configs/open_lm_1b.json
+++ b/open_lm/model_configs/open_lm_1b.json
@@ -2,7 +2,7 @@
"hidden_dim": 2048,
"n_layers": 24,
"n_heads": 16,
- "seq_len": 2048,
+ "seq_len": 8192,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
diff --git a/open_lm/model_configs/open_lm_3b.json b/open_lm/model_configs/open_lm_3b.json
index 64ec0a4..57cc24a 100644
--- a/open_lm/model_configs/open_lm_3b.json
+++ b/open_lm/model_configs/open_lm_3b.json
@@ -2,7 +2,7 @@
"hidden_dim": 2560,
"n_layers": 32,
"n_heads": 32,
- "seq_len": 2048,
+ "seq_len": 8192,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
diff --git a/open_lm/model_configs/open_lm_410m.json b/open_lm/model_configs/open_lm_410m.json
index 8532173..1010cf7 100644
--- a/open_lm/model_configs/open_lm_410m.json
+++ b/open_lm/model_configs/open_lm_410m.json
@@ -2,7 +2,7 @@
"hidden_dim": 1024,
"n_layers": 24,
"n_heads": 16,
- "seq_len": 2048,
+ "seq_len": 8192,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
diff --git a/open_lm/model_configs/open_lm_7b.json b/open_lm/model_configs/open_lm_7b.json
index e662dab..b9178d0 100644
--- a/open_lm/model_configs/open_lm_7b.json
+++ b/open_lm/model_configs/open_lm_7b.json
@@ -2,7 +2,7 @@
"hidden_dim": 4096,
"n_layers": 32,
"n_heads": 32,
- "seq_len": 2048,
+ "seq_len": 8192,
"vocab_size": 50432,
"post_embed_norm": false,
"weight_tying": false
diff --git a/open_lm/params.py b/open_lm/params.py
index 0a7a3f6..389b805 100644
--- a/open_lm/params.py
+++ b/open_lm/params.py
@@ -787,6 +787,20 @@ def parse_args(args):
default=0,
help="This is the maximum number of failed checkpoints (due to not having seen enough tokens) that are allowed",
)
+ parser.add_argument(
+ "--dataset-batch-mult",
+ type=float,
+ nargs="+",
+ default=None,
+ help="Batch size multiplier to be used for each dataset (with respect to the base batch size).",
+ )
+ parser.add_argument(
+ "--source-num-seq-per-epoch",
+ type=int,
+ nargs="+",
+ default=None,
+ help="Number of sequences to be used per epoch from each source.",
+ )
add_model_args(parser)
diff --git a/open_lm/positional_embedding/rotary.py b/open_lm/positional_embedding/rotary.py
index b48ed89..d5c1af0 100644
--- a/open_lm/positional_embedding/rotary.py
+++ b/open_lm/positional_embedding/rotary.py
@@ -57,7 +57,7 @@ class RotaryEmbedding(torch.nn.Module):
self.reset_parameters()
def reset_parameters(self):
- self.inv_freq = 1.0 / (10000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
+ self.inv_freq = 1.0 / (100000 ** (torch.arange(0, self.dim_model, 2).float() / self.dim_model))
self._update_cos_sin_tables(self.seq_len)
def _update_cos_sin_tables(self, seq_len: int = None, device: torch.device = None, dtype: torch.dtype = None):
diff --git a/open_lm/train.py b/open_lm/train.py
index 0d54bf7..eccf708 100644
--- a/open_lm/train.py
+++ b/open_lm/train.py
@@ -110,13 +110,17 @@ def train_one_epoch(
try:
batch = next(data_iterator)
- has_data = torch.tensor(1, dtype=torch.long, device=device)
+ has_data = torch.tensor([1, len(batch[0])], dtype=torch.long, device=device)
except StopIteration:
- has_data = torch.tensor(0, dtype=torch.long, device=device)
+ logging.warning("Could not get a batch!!!")
+ has_data = torch.tensor([0, 0], dtype=torch.long, device=device)
if args.world_size > 1:
dist.all_reduce(has_data, op=ReduceOp.SUM)
- if has_data < args.world_size:
+ if has_data[1] != len(batch[0]) * args.world_size:
+ logging.warning("Same global sequence length consistency broke! This can reduce performance.")
+ if has_data[0] != args.world_size:
+ logging.warning("At least one gpu could not get a batch.")
break
(texts,) = batch
@@ -153,12 +157,12 @@ def train_one_epoch(
# save the loss for the average model for logging
total_loss_avg[key] = loss(out_avg.reshape(-1, args.vocab_size), targets.reshape(-1))
else:
+ inputs, targets = sample_chunk(texts, args)
+
# split up batch into accum_freq chunks -- if you have --batch-size 8 and --accum-freq 4
# then you only process 2 items at a time. batch-size must be divisible by accume-freq.
- assert args.per_gpu_batch_size % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
- per_batch = args.per_gpu_batch_size // args.accum_freq
-
- inputs, targets = sample_chunk(texts, args)
+ assert inputs.shape[0] % args.accum_freq == 0, "Per-GPU batch size must be divisible by accum_freq"
+ per_batch = inputs.shape[0] // args.accum_freq
forward_total_time = 0
backward_total_time = 0
@@ -291,7 +295,7 @@ def train_one_epoch(
for key, value in total_loss_avg.items():
losses_avg_m[key].update(value.item(), batch_size)
if i % args.log_every_n_steps == 0 or batch_count == num_batches_per_epoch or step == total_steps - 1:
- num_samples = batch_count * batch_size * args.world_size
+ num_samples = batch_count * args.global_batch_size # Number of sequences seen as if all were the longest
samples_per_epoch = dataloader.num_samples
percent_complete = 100.0 * batch_count / num_batches_per_epoch
@@ -332,6 +336,7 @@ def train_one_epoch(
"tokens": (step + 1) * args.global_batch_size * args.seq_len,
"expected_steps_epoch": data["train"].dataloader.num_batches,
"seen_steps_epoch": batch_count,
+ "seq_len": inputs.shape[1],
}
if averagers is not None and args.log_avg_model_training_loss: