# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
from datetime import timedelta

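# Pin each replica group to one of the (up to 4) local GPUs and give NCCL a
# distinct host ID so multiple replica groups can share a single machine.
# These must be set before torch initializes CUDA below.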
REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4)
os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID)

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.distributed.elastic.multiprocessing.errors import record
from torchdata.stateful_dataloader import StatefulDataLoader
import time
from torchft import (
    DistributedSampler,
    Manager,
    ProcessGroupGloo,
    ProcessGroupNCCL,
)
from torchft.local_sgd import DiLoCo
from torchft.checkpointing.pg_transport import PGTransport

logging.basicConfig(level=logging.INFO)


@record
def main() -> None:
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root="./cifar", train=True, download=True, transform=transform
    )

    # This shards the training set across all ranks and replica groups. We manage
    # the dataloaders on a per replica group basis with the assumption that the
    # majority of groups will be available so few batches will be dropped.
    sampler = DistributedSampler(
        trainset,
        replica_group_id=REPLICA_GROUP_ID,
        num_replica_groups=NUM_REPLICA_GROUPS,
        group_rank=0,
        # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
        num_replicas=1,
        shuffle=True,
    )

    # This uses the torchdata StatefulDataLoader to be able to checkpoint and
    # restore the per worker dataloader position.
    trainloader = StatefulDataLoader(
        trainset, batch_size=64, num_workers=2, sampler=sampler
    )


    device = "cuda" if torch.cuda.is_available() else "cpu"
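    # Fault-tolerant process group used for cross-replica-group communication:
    # NCCL when CUDA is available, Gloo as the CPU fallback.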
    pg = (
        ProcessGroupNCCL(
            timeout=timedelta(seconds=30),
        )
        if torch.cuda.is_available()
        else ProcessGroupGloo(timeout=timedelta(seconds=5))
    )

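    # PGTransport sends checkpoint state directly over the process group so a
    # recovering replica group can pull live state from a healthy peer without
    # touching disk.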
    transport = PGTransport(
        pg,
        timeout=timedelta(seconds=10),
        device=("cuda" if torch.cuda.is_available() else "cpu"),
    )

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnn = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            )

            final_dim = 10
            # We add an unused ~1GB embedding layer so more time is spent in
            # distributed communication, making injected failures more likely to
            # surface issues if they exist.
            target_size = 1_000_000_000
            self.useless = nn.Embedding(target_size // final_dim // 4, final_dim)

            self.classifier = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, final_dim),
            )

        def forward(self, x):
            x = self.cnn(x)
            x = torch.flatten(x, 1)  # flatten all dimensions except batch
            x = self.classifier(x)
            x += self.useless.weight[0]
            return x

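    # DiLoCo uses two optimizers: AdamW for the frequent inner (local) steps and
    # Nesterov SGD as the outer optimizer applied when replica groups synchronize.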
    m = Net().to(device)
    inner_optimizer = optim.AdamW(
        m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
    )
    outer_optimizer = optim.SGD(
        m.parameters(), lr=0.7, momentum=0.9, nesterov=True
    )
    criterion = nn.CrossEntropyLoss()

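    # State dict callbacks used by the Manager for live recovery: a joining
    # replica restores the model, DiLoCo's saved original parameters and both
    # optimizer states from an existing replica group.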
    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        diloco.original_parameters = state_dict["original_params"]
        for name in diloco.original_parameters.keys():
            diloco.original_parameters[name] = diloco.original_parameters[name].to(
                device
            )
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "original_params": diloco.original_parameters,
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

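    # The Manager handles quorum and recovery across replica groups.
    # min_replica_size=1 keeps training going as long as a single group is
    # healthy; use_async_quorum=False makes quorum computation blocking rather
    # than asynchronous.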
    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    print(m)
    num_params = sum(p.numel() for p in m.parameters())
    print(f"Total number of parameters: {num_params}")

    sort_by_keyword = "self_" + device + "_time_total"

    def trace_handler(p):
        output = p.key_averages().table(
            sort_by=sort_by_keyword,
            row_limit=100,
        )
        print(output)
        p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

    # You can use epoch-based training, but with faults it's easier to use
    # step-based training.
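    # The profiler is optional: it records a few active steps per cycle and
    # dumps a Chrome trace via trace_handler above.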
    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
        on_trace_ready=trace_handler,
        record_shapes=True,
        profile_memory=True,
    )

    prof.start()

    num_local_steps = 0
    sync_every = 100
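    # DiLoCo runs sync_every inner steps locally per replica group, then
    # all-reduces the resulting parameter delta across groups and applies it
    # with the outer optimizer.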
    with DiLoCo(
        manager,
        m,
        inner_optimizer,
        outer_optimizer,
        backup_device=device,
        sync_every=sync_every,
    ) as diloco:
        while True:
            for i, (inputs, labels) in enumerate(trainloader):
                prof.step()

                inputs = inputs.to(device)
                labels = labels.to(device)

                # must be called at the beginning of each train loop
                # Quorum computation is triggered here but only needed in the backwards pass.
                inner_optimizer.zero_grad()

                out = m(inputs)
                loss = criterion(out, labels)

                # Gradient allreduce overlaps with the backwards pass.
                loss.backward()

                # must be called at the end of the train loop
                # This may not actually step the optimizer if an error occurred during grad allreduce.
                inner_optimizer.step()
                num_local_steps += 1

                if manager.current_step() % 100 == 0:
                    print(f"[{manager.current_step()}] loss = {loss.item()}")

                if num_local_steps % sync_every == 0:
                    print(f"Number of inner optimizer steps completed: {num_local_steps}")

                # TODO (by the user): periodically checkpoint model, optim, manager and dataloader

                # You typically want to checkpoint the dataloader frequently (every step?)
                # to avoid repeated batches, as it's replica-group specific.

                # Model, optim and manager checkpoints can be done less frequently as
                # they're shared across all groups and will load from existing replicas as
                # long as not every worker goes down.

                if manager.current_step() >= 10000:
                    # complete training
                    prof.stop()
                    exit()
                time.sleep(0.01)


if __name__ == "__main__":
    main()