0001-DeMo.patch
From 45d0a1c7ad6286078b2fbd274b31cfe05b990b9f Mon Sep 17 00:00:00 2001
From: redacted <[email protected]>
Date: Wed, 2 Oct 2024 01:26:11 +0000
Subject: [PATCH] DeMo
---
.gitignore | 2 +
olmo/config.py | 23 ++++
olmo/demo_utils.py | 286 +++++++++++++++++++++++++++++++++++++++++++++
olmo/optim.py | 187 ++++++++++++++++++++++++++++-
olmo/train.py | 10 +-
scripts/train.py | 5 +
6 files changed, 509 insertions(+), 4 deletions(-)
create mode 100644 olmo/demo_utils.py
diff --git a/.gitignore b/.gitignore
index 9b1e9978..68739e81 100644
--- a/.gitignore
+++ b/.gitignore
@@ -56,3 +56,5 @@ site/
/wandb/
/scratch/
core
+slurm*
+checkpoints/
\ No newline at end of file
diff --git a/olmo/config.py b/olmo/config.py
index ae454bb1..64abeb0b 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -498,6 +498,7 @@ class ModelConfig(BaseConfig):
class OptimizerType(StrEnum):
lionw = "lionw"
adamw = "adamw"
+ demo = "demo"
@dataclass
@@ -533,6 +534,20 @@ class OptimizerConfig(BaseConfig):
of the update with AdamW.
"""
+ ### DeMo parameters
+ compression_decay: float = 0.999
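+ """
+ Decay factor applied to the accumulated local residual (delta) before each new gradient is added.
+ """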
+
+ compression_topk: int = 32
+ """
+ Number of top-k values to transmit per chunk. If dynamic top-k is enabled, this is the initial value.
+ """
+
+ compression_chunk: int = 64
+ """
+ Chunk size for the gradients. Note that 2D gradients are chunked in 2D, so the effective top-k sparsity is squared compared to the 1D case.
+ """
+
+
def __post_init__(self):
self.betas = tuple(self.betas) # type: ignore[assignment]
@@ -724,6 +739,12 @@ class DDPGradSyncMode(StrEnum):
set to True, to prevent errors.
"""
+ none = "none"
+ """
+ Completely disables gradient synchronization within the distributed model.
+ This should only be used with explicit external synchronization (e.g. DeMo), or if you just like spinning your wheels.
+ """
+
@dataclass
class DDPConfig(BaseConfig):
@@ -818,6 +839,8 @@ class FSDPConfig(BaseConfig):
PyTorch's default HSDP behavior matches this default behavior.
"""
+ disable_grad_sync: bool = False
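+ """
+ Disable FSDP gradient synchronization entirely. This must be set to true when using the DeMo
+ optimizer, which performs its own gradient communication.
+ """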
+
class CheckpointType(StrEnum):
sharded = "sharded"
diff --git a/olmo/demo_utils.py b/olmo/demo_utils.py
new file mode 100644
index 00000000..316586ca
--- /dev/null
+++ b/olmo/demo_utils.py
@@ -0,0 +1,286 @@
+import math
+import torch
+import torch.fft
+import torch.distributed as dist
+
+from einops import rearrange
+
+
+class TransformDCT:
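+ # Chunked DCT transform: each tensor dimension is split into chunks whose size is the largest
+ # divisor of that dimension not exceeding target_chunk, and a DCT is applied along the chunked axes.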
+ @torch.no_grad()
+ def __init__(self, param_groups, target_chunk, norm="ortho"):
+ self.target_chunk = target_chunk
+
+ self.shape_dict = dict()
+ self.f_dict = dict()
+ self.b_dict = dict()
+
+ # Get all variants of model tensor sizes
+ # Generate all possible valid DCT sizes for model tensors
+ for group in param_groups:
+ for p in group["params"]:
+ if not p.requires_grad:
+ continue
+ for s in p.shape:
+ # Get the closest smallest divisor to the targeted DCT size
+ sc = _get_smaller_split(s, self.target_chunk)
+ self.shape_dict[s] = sc
+
+ # Pregenerate DCT basis matrices
+ if sc not in self.f_dict:
+ I = torch.eye(sc)
+ self.f_dict[sc] = _dct(I, norm=norm).to(p.dtype).to(p.device)
+ self.b_dict[sc] = _idct(I, norm=norm).to(p.dtype).to(p.device)
+
+ @torch.no_grad()
+ def einsum_2d(self, x, b, d=None):
+ if d is None:
+ return torch.einsum("...ij, jb -> ...ib", x, b)
+ else:
+ # Note: b-c axis output is transposed to chunk DCT in 2D
+ return torch.einsum("...ijkl, jb, ld -> ...ikbd", x, b, d)
+
+ @torch.no_grad()
+ def einsum_2d_t(self, x, b, d=None):
+ if d is None:
+ return torch.einsum("...ij, jb -> ...ib", x, b)
+ else:
+ # Note: b-c axis output is transposed to chunk DCT in 2D
+ return torch.einsum("...ijkl, kb, ld -> ...ibjd", x, b, d)
+
+ @torch.no_grad()
+ def encode(self, x):
+ if len(x.shape) > 1: # 2D weights
+ n1 = self.shape_dict[x.shape[0]]
+ n2 = self.shape_dict[x.shape[1]]
+ n1w = self.f_dict[n1].to(x.device)
+ n2w = self.f_dict[n2].to(x.device)
+ self.f_dict[n1] = n1w
+ self.f_dict[n2] = n2w
+
+ x = rearrange(x, "(y h) (x w) -> y h x w", h=n1, w=n2)
+ x = self.einsum_2d(x, n1w, n2w)
+
+ else: # 1D weights
+ n1 = self.shape_dict[x.shape[0]]
+ n1w = self.f_dict[n1].to(x.device)
+ self.f_dict[n1] = n1w
+
+ x = rearrange(x, "(x w) -> x w", w=n1)
+ x = self.einsum_2d(x, n1w)
+
+ return x
+
+ @torch.no_grad()
+ def decode(self, x):
+ if len(x.shape) > 2: # 2D weights
+ n1 = x.shape[2]
+ n2 = x.shape[3]
+ n1w = self.b_dict[n1].to(x.device)
+ n2w = self.b_dict[n2].to(x.device)
+ self.b_dict[n1] = n1w
+ self.b_dict[n2] = n2w
+
+ x = self.einsum_2d_t(x, n1w, n2w)
+ x = rearrange(x, "y h x w -> (y h) (x w)")
+
+ else: # 1D weights
+ n1 = x.shape[1]
+ n1w = self.b_dict[n1].to(x.device)
+ self.b_dict[n1] = n1w
+
+ x = self.einsum_2d_t(x, n1w)
+ x = rearrange(x, "x w -> (x w)")
+
+ return x
+
+
+class CompressDCT:
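+ # Top-k sparsification of DCT coefficients: keeps only the k largest-magnitude entries of each
+ # chunk and records their indices for reconstruction.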
+ @torch.no_grad()
+ def __init__(self):
+ pass
+
+ def _clamp_topk(self, x, topk):
+ if topk > x.shape[-1]:
+ topk = x.shape[-1]
+ if topk < 1:
+ topk = 1
+ return topk
+
+ @torch.no_grad()
+ def compress(self, x, topk):
+ xshape = x.shape
+ if len(x.shape) > 2: # 2D weights
+ x = rearrange(x, "y x h w -> y x (h w)")
+
+ # Limit topk to max size
+ totalk = x.shape[-1]
+ topk = self._clamp_topk(x, topk)
+
+ idx = torch.topk(x.abs(), k=topk, dim=-1, largest=True, sorted=False).indices
+ val = torch.gather(x, dim=-1, index=idx)
+
+ return idx, val, xshape, totalk
+
+ @torch.no_grad()
+ def decompress(self, p, idx, val, xshape, totalk):
+ x = torch.zeros(xshape, device=p.device, dtype=p.dtype)
+
+ if len(xshape) > 2: # 2D weights
+ x = rearrange(x, "y x h w -> y x (h w)")
+
+ # TODO: Careful, this is nondeterministic across different CUDA devices and might cause errors to accumulate between nodes!
+ x.scatter_reduce_(dim=-1, index=idx, src=val, reduce="mean", include_self=False).reshape(xshape)
+
+ if len(x.shape) > 2: # 2D weights
+ x = rearrange(x, "y x (h w) -> y x h w", h=xshape[2])
+
+ return x
+
+ @torch.no_grad()
+ def batch_decompress(self, p, idx, val, xshape, totalk):
+ idx = torch.concatenate(idx, dim=-1).to(device=p.device)
+ val = torch.concatenate(val, dim=-1).to(device=p.device)
+ return self.decompress(p, idx, val, xshape, totalk)
+
+
+# Code modified and sourced from https://github.com/zh217/torch-dct
+def _dct_fft_impl(v):
+ return torch.view_as_real(torch.fft.fft(v, dim=1))
+
+
+def _idct_irfft_impl(V):
+ return torch.fft.irfft(torch.view_as_complex(V), n=V.shape[1], dim=1)
+
+
+def _dct(x, norm=None):
+ """
+ Discrete Cosine Transform, Type II (a.k.a. the DCT)
+
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+ :param x: the input signal
+ :param norm: the normalization, None or 'ortho'
+ :return: the DCT-II of the signal over the last dimension
+ """
+ x_shape = x.shape
+ N = x_shape[-1]
+ x = x.contiguous().view(-1, N)
+
+ v = torch.cat([x[:, ::2], x[:, 1::2].flip([1])], dim=1)
+
+ Vc = _dct_fft_impl(v)
+
+ k = -torch.arange(N, dtype=x.dtype, device=x.device)[None, :] * math.pi / (2 * N)
+ W_r = torch.cos(k)
+ W_i = torch.sin(k)
+
+ V = Vc[:, :, 0] * W_r - Vc[:, :, 1] * W_i
+
+ if norm == "ortho":
+ V[:, 0] /= math.sqrt(N) * 2
+ V[:, 1:] /= math.sqrt(N / 2) * 2
+
+ V = 2 * V.view(*x_shape)
+
+ return V
+
+
+def _idct(X, norm=None):
+ """
+ The inverse to DCT-II, which is a scaled Discrete Cosine Transform, Type III
+
+ Our definition of idct is that idct(dct(x)) == x
+
+ For the meaning of the parameter `norm`, see:
+ https://docs.scipy.org/doc/scipy-0.14.0/reference/generated/scipy.fftpack.dct.html
+
+ :param X: the input signal
+ :param norm: the normalization, None or 'ortho'
+ :return: the inverse DCT-II of the signal over the last dimension
+ """
+
+ x_shape = X.shape
+ N = x_shape[-1]
+
+ X_v = X.contiguous().view(-1, x_shape[-1]) / 2
+
+ if norm == "ortho":
+ X_v[:, 0] *= math.sqrt(N) * 2
+ X_v[:, 1:] *= math.sqrt(N / 2) * 2
+
+ k = torch.arange(x_shape[-1], dtype=X.dtype, device=X.device)[None, :] * math.pi / (2 * N)
+ W_r = torch.cos(k)
+ W_i = torch.sin(k)
+
+ V_t_r = X_v
+ V_t_i = torch.cat([X_v[:, :1] * 0, -X_v.flip([1])[:, :-1]], dim=1)
+
+ V_r = V_t_r * W_r - V_t_i * W_i
+ V_i = V_t_r * W_i + V_t_i * W_r
+
+ V = torch.cat([V_r.unsqueeze(2), V_i.unsqueeze(2)], dim=2)
+
+ v = _idct_irfft_impl(V)
+ x = v.new_zeros(v.shape)
+ x[:, ::2] += v[:, : N - (N // 2)]
+ x[:, 1::2] += v.flip([1])[:, : N // 2]
+
+ return x.view(*x_shape)
+
+
+def _get_prime_divisors(n):
+ divisors = []
+ while n % 2 == 0:
+ divisors.append(2)
+ n //= 2
+ while n % 3 == 0:
+ divisors.append(3)
+ n //= 3
+ i = 5
+ while i * i <= n:
+ for k in (i, i + 2):
+ while n % k == 0:
+ divisors.append(k)
+ n //= k
+ i += 6
+ if n > 1:
+ divisors.append(n)
+ return divisors
+
+
+def _get_divisors(n):
+ divisors = []
+ if n == 1:
+ divisors.append(1)
+ elif n > 1:
+ prime_factors = _get_prime_divisors(n)
+ divisors = [1]
+ last_prime = 0
+ factor = 0
+ slice_len = 0
+ # Find all the products that are divisors of n
+ for prime in prime_factors:
+ if last_prime != prime:
+ slice_len = len(divisors)
+ factor = prime
+ else:
+ factor *= prime
+ for i in range(slice_len):
+ divisors.append(divisors[i] * factor)
+ last_prime = prime
+ divisors.sort()
+ return divisors
+
+
+def _get_smaller_split(n, close_to):
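+ # Return the largest divisor of n that does not exceed close_to.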
+ all_divisors = _get_divisors(n)
+ for ix, val in enumerate(all_divisors):
+ if val == close_to:
+ return val
+ if val > close_to:
+ if ix == 0:
+ return val
+ return all_divisors[ix - 1]
+ return n
diff --git a/olmo/optim.py b/olmo/optim.py
index 5460ccee..bbbb102a 100644
--- a/olmo/optim.py
+++ b/olmo/optim.py
@@ -1,8 +1,9 @@
+import math
import logging
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, replace
from math import cos, pi, sqrt
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union, Callable
import torch
import torch.distributed as dist
@@ -14,11 +15,13 @@ from torch.optim.optimizer import Optimizer as OptimizerBase
from . import LayerNormBase
from .config import OptimizerType, SchedulerConfig, SchedulerType, TrainConfig
from .torch_util import get_default_device, is_distributed
+from .demo_utils import TransformDCT, CompressDCT
__all__ = [
"Optimizer",
"LionW",
"AdamW",
+ "DeMo",
"Scheduler",
"CosWithWarmup",
"LinearWithWarmup",
@@ -647,6 +650,177 @@ class AdamW(torch.optim.AdamW, Optimizer):
return metrics
+class DeMo(torch.optim.SGD, Optimizer):
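+ # DeMo: maintains a per-parameter residual ("delta"), transmits only its top-k DCT coefficients via
+ # all-gather instead of a dense gradient all-reduce, and applies a sign-SGD update.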
+ def __init__(
+ self,
+ params,
+ compression_decay: float = 0.999,
+ compression_topk: int = 32,
+ compression_chunk: int = 64,
+ weight_decay: float = 0.0,
+ process_group: Optional[dist.ProcessGroup] = None,
+ record_update_metrics: bool = False,
+ selective_updates: bool = False,
+ **kwargs,
+ ):
+ super().__init__(
+ params,
+ foreach=False,
+ momentum=0.0,
+ dampening=0.0,
+ nesterov=False,
+ maximize=False,
+ weight_decay=0.0,
+ **kwargs,
+ )
+
+ # Need to set these here just like in our base `Optimizer` class since our `Optimizer.__init__`
+ # won't be called.
+ self._record_update_metrics = record_update_metrics
+ self._collecting_metrics = False
+ self._selective_updates = selective_updates
+
+ self.compression_decay = compression_decay
+ self.compression_chunk = compression_chunk
+ self.compression_topk = compression_topk
+ self.process_group = process_group
+ self.weight_decay = weight_decay
+
+ if self.compression_topk <= 0:
+ raise ValueError("compression_topk has to be positive")
+ if self.compression_chunk <= 0:
+ raise ValueError("compression_chunk has to be positive")
+ if self.compression_decay < 0:
+ raise ValueError("Negative compression_decay is currently not supported")
+ if self.compression_decay >= 1:
+ raise ValueError("compression_decay values greater than or equal to 1.0 are currently not supported")
+
+ self.demo_state = {}
+ self._init_demo_states()
+ self._init_opt_parameters()
+
+ self.default_dtype = self._find_dtype()
+ self.transform = TransformDCT(self.param_groups, self.compression_chunk)
+ self.compress = CompressDCT()
+
+ def _find_dtype(self):
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.requires_grad:
+ return p.dtype
+ return torch.float32
+
+ def _init_demo_states(self):
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.requires_grad:
+ self.demo_state[p] = {}
+
+ def _state_parameter(self, p):
+ if p not in self.demo_state:
+ self.demo_state[p] = {}
+ return self.demo_state[p]
+
+ def _init_opt_parameters(self):
+ for group in self.param_groups:
+ for p in group["params"]:
+ if p.requires_grad:
+ state = self._state_parameter(p)
+
+ state["step"] = 0
+ state["delta"] = torch.zeros_like(p)
+
+ def _demo_all_gather(self, sparse_idx, sparse_val):
+ world_size = dist.get_world_size() if self.process_group is None else self.process_group.size()
+
+ # Gather all the idx and vals
+ sparse_idx_list = [torch.zeros_like(sparse_idx) for wi in range(world_size)]
+ sparse_val_list = [torch.zeros_like(sparse_val) for wi in range(world_size)]
+
+ sparse_idx_handle = dist.all_gather(sparse_idx_list, sparse_idx, group=self.process_group, async_op=True)
+ sparse_val_handle = dist.all_gather(sparse_val_list, sparse_val, group=self.process_group, async_op=True)
+
+ sparse_idx_handle.wait()
+ sparse_val_handle.wait()
+
+ return sparse_idx_list, sparse_val_list
+
+
+ @torch.no_grad()
+ def step(self, closure: Callable | None = None):
+
+ self.data_transmit = 0
+ self.data_receive = 0
+
+ for group in self.param_groups:
+ lr = group["lr"]
+ for p in group["params"]:
+ if not p.requires_grad:
+ continue
+ state = self._state_parameter(p)
+
+ # Update step
+ state["step"] += 1
+
+ # Step-Weight decay
+ if self.weight_decay != 0.0:
+ p.data.mul_(1.0 - lr * self.weight_decay)
+
+ # Decay delta
+ if self.compression_decay != 1:
+ state["delta"].mul_(self.compression_decay)
+
+ # Add delta to new gradient
+ state["delta"].add_(p.grad, alpha=lr)
+
+ # Compress delta
+ sparse_idx, sparse_val, xshape, totalk = self.compress.compress(
+ self.transform.encode(state["delta"]), self.compression_topk
+ )
+
+ # Estimate transmitted delta
+ transmit_grad = self.transform.decode(
+ self.compress.decompress(p, sparse_idx, sparse_val, xshape, totalk)
+ )
+
+ # Remove transmitted from delta
+ state["delta"].sub_(transmit_grad)
+
+ # All-gather
+ sparse_idx_gather, sparse_val_gather = self._demo_all_gather(sparse_idx, sparse_val)
+
+ # Log I/O data size
+ self.data_transmit += sparse_idx.nbytes + sparse_val.nbytes
+ for si, v in zip(sparse_idx_gather, sparse_val_gather):
+ self.data_receive += si.nbytes + v.nbytes
+
+ # Decode grad from all nodes
+ new_grad = self.transform.decode(
+ self.compress.batch_decompress(p, sparse_idx_gather, sparse_val_gather, xshape, totalk)
+ )
+
+ # Set grad to values
+ if p.grad is None:
+ p.grad = new_grad
+ else:
+ p.grad.copy_(new_grad)
+
+ # Sign-SGD
+ p.grad.sign_()
+
+ # SGD step
+ return super().step(closure)
+
+
+ def get_post_step_metrics(
+ self, module: nn.Module, process_group: Optional[dist.ProcessGroup] = None
+ ) -> Dict[str, torch.Tensor]:
+ return {
+ "data_receive": torch.tensor(self.data_receive, device=get_default_device()),
+ "data_transmit": torch.tensor(self.data_transmit, device=get_default_device()),
+ }
+
+
@dataclass
class Scheduler(metaclass=ABCMeta):
# NOTE: these fields are not given default values because otherwise dataclasses complains
@@ -950,6 +1124,17 @@ def build_optimizer(cfg: TrainConfig, model: nn.Module) -> Optimizer:
selective_updates=cfg.optimizer.selective_updates,
eps=cfg.optimizer.eps,
)
+ elif cfg.optimizer.name == OptimizerType.demo:
+ return DeMo(
+ param_groups,
+ compression_decay=cfg.optimizer.compression_decay,
+ compression_topk=cfg.optimizer.compression_topk,
+ compression_chunk=cfg.optimizer.compression_chunk,
+ weight_decay=cfg.optimizer.weight_decay,
+ process_group=None, # TODO: fix for hybrid sharding
+ record_update_metrics=cfg.optimizer.record_update_metrics,
+ selective_updates=cfg.optimizer.selective_updates,
+ )
else:
raise NotImplementedError
diff --git a/olmo/train.py b/olmo/train.py
index 34105500..77e758b9 100644
--- a/olmo/train.py
+++ b/olmo/train.py
@@ -35,6 +35,7 @@ from .config import (
CheckpointType,
DDPGradSyncMode,
DistributedStrategy,
+ OptimizerType,
SchedulerUnits,
ShardedCheckpointerType,
SpeedMonitorConfig,
@@ -44,7 +45,7 @@ from .data import IterableDataset
from .eval import Evaluator
from .exceptions import OLMoConfigurationError
from .model import OLMo
-from .optim import Optimizer, Scheduler
+from .optim import DeMo, Optimizer, Scheduler
from .torch_util import (
barrier,
gc_cuda,
@@ -785,10 +786,13 @@ class Trainer:
if (
self.cfg.distributed_strategy == DistributedStrategy.ddp
and self.cfg.ddp is not None
- and self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch
+ and self.cfg.ddp.grad_sync_mode != DDPGradSyncMode.micro_batch
):
- if micro_batch_idx != num_micro_batches - 1:
+ if (self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.batch and micro_batch_idx != num_micro_batches - 1) \
+ or self.cfg.ddp.grad_sync_mode == DDPGradSyncMode.none:
grad_sync_context = self.dist_model.no_sync
+ elif self.cfg.distributed_strategy == DistributedStrategy.fsdp and self.cfg.fsdp is not None and self.cfg.fsdp.disable_grad_sync:
+ grad_sync_context = self.dist_model.no_sync
# Register output hooks
output_hooks: List[torch.utils.hooks.RemovableHandle] = []
diff --git a/scripts/train.py b/scripts/train.py
index 1f735309..d20d0092 100644
--- a/scripts/train.py
+++ b/scripts/train.py
@@ -20,6 +20,7 @@ from olmo.config import (
CheckpointType,
DDPGradSyncMode,
DistributedStrategy,
+ OptimizerType,
TrainConfig,
)
from olmo.data import build_train_dataloader
@@ -138,6 +139,8 @@ def main(cfg: TrainConfig) -> None:
if cfg.distributed_strategy == DistributedStrategy.ddp:
log.info("Wrapping model with DDP...")
assert cfg.ddp is not None, "DistributedStrategy ddp needs cfg.ddp to be set!"
+ if cfg.optimizer.name == OptimizerType.demo and cfg.ddp.grad_sync_mode != DDPGradSyncMode.none:
+ raise OLMoConfigurationError("DeMo requires that `ddp.grad_sync_mode` be set to `none`.")
if cfg.model.init_device != "cuda":
raise OLMoConfigurationError("DDP does not work with init_device set to anything other than `cuda`.")
@@ -155,6 +158,8 @@ def main(cfg: TrainConfig) -> None:
# Wrap the model in FSDP.
log.info("Wrapping model with FSDP...")
assert cfg.fsdp is not None, "DistributedStrategy fsdp needs cfg.fsdp to be set!"
+ if cfg.optimizer.name == OptimizerType.demo and not cfg.fsdp.disable_grad_sync:
+ raise OLMoConfigurationError("DeMo requires that `fsdp.disable_grad_sync` be set to `true`.")
wrap_policy = olmo_model.get_fsdp_wrap_policy(cfg.fsdp.wrapping_strategy)
if version.parse(torch.__version__) >= version.parse("2.1.0"):
--
2.34.1
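
For reference, the following is a minimal, single-process sketch (not part of the patch) of the transform/compress round trip that DeMo.step() performs on each parameter's residual. It assumes the patch above has been applied so that olmo.demo_utils is importable; the tensor sizes, the chunk size of 64, and the top-k of 32 are illustrative values only.

import torch
from olmo.demo_utils import TransformDCT, CompressDCT

# A single 2D parameter stands in for a model; sizes are arbitrary multiples of the chunk size.
p = torch.nn.Parameter(torch.randn(256, 512))
param_groups = [{"params": [p]}]

transform = TransformDCT(param_groups, target_chunk=64)
compress = CompressDCT()

# `delta` plays the role of DeMo's accumulated error-feedback buffer for this parameter.
delta = torch.randn_like(p)

encoded = transform.encode(delta)                                # chunked DCT of the residual
idx, val, xshape, totalk = compress.compress(encoded, topk=32)   # keep top-k coefficients per chunk
decoded = transform.decode(compress.decompress(p, idx, val, xshape, totalk))

# Error feedback: whatever was not transmitted stays in the local residual for the next step.
residual = delta - decoded
print(delta.shape, decoded.shape, idx.shape, val.shape)

In the optimizer itself, only (idx, val) are all-gathered across ranks; each rank then batch-decompresses the gathered coefficients into a shared gradient estimate and applies its sign as the SGD update.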