Commit 3641dca

localsgd examples added
1 parent 7b550aa commit 3641dca

File tree

2 files changed (+451 −0 lines)

examples/train_diloco.py

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
from datetime import timedelta

REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4)
os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID)

import torch
import torchvision
import torchvision.transforms as transforms
from torch import nn, optim
from torch.distributed.elastic.multiprocessing.errors import record
from torchdata.stateful_dataloader import StatefulDataLoader
import time
from torchft import (
    DistributedSampler,
    Manager,
    ProcessGroupGloo,
    ProcessGroupNCCL,
)
from torchft.local_sgd import DiLoCo
from torchft.checkpointing.pg_transport import PGTransport

logging.basicConfig(level=logging.INFO)


@record
def main() -> None:
    REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0))
    NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2))

    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )
    trainset = torchvision.datasets.CIFAR10(
        root="./cifar", train=True, download=True, transform=transform
    )

    # This shards the training set across all ranks and replica groups. We manage
    # the dataloaders on a per-replica-group basis with the assumption that the
    # majority of groups will be available, so few batches will be dropped.
    sampler = DistributedSampler(
        trainset,
        replica_group_id=REPLICA_GROUP_ID,
        num_replica_groups=NUM_REPLICA_GROUPS,
        group_rank=0,
        # For DDP we can use replica groups of size 1; FSDP/PP/CP would need more.
        num_replicas=1,
        shuffle=True,
    )

    # This uses the torchdata StatefulDataLoader to be able to checkpoint and
    # restore the per-worker dataloader position.
    trainloader = StatefulDataLoader(
        trainset, batch_size=64, num_workers=2, sampler=sampler
    )

    device = "cuda" if torch.cuda.is_available() else "cpu"
    pg = (
        ProcessGroupNCCL(
            timeout=timedelta(seconds=30),
        )
        if torch.cuda.is_available()
        else ProcessGroupGloo(timeout=timedelta(seconds=5))
    )

    transport = PGTransport(
        pg,
        timeout=timedelta(seconds=10),
        device=("cuda" if torch.cuda.is_available() else "cpu"),
    )

    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.cnn = nn.Sequential(
                nn.Conv2d(3, 6, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
                nn.Conv2d(6, 16, 5),
                nn.ReLU(),
                nn.MaxPool2d(2, 2),
            )

            final_dim = 10
            # We add a useless 1GB intermediate layer so that we spend more time in
            # distributed communication, making injected failures more likely to
            # surface issues if they exist.
            target_size = 1_000_000_000
            self.useless = nn.Embedding(target_size // final_dim // 4, final_dim)

            self.classifier = nn.Sequential(
                nn.Linear(16 * 5 * 5, 120),
                nn.ReLU(),
                nn.Linear(120, 84),
                nn.ReLU(),
                nn.Linear(84, final_dim),
            )

        def forward(self, x):
            x = self.cnn(x)
            x = torch.flatten(x, 1)  # flatten all dimensions except batch
            x = self.classifier(x)
            x += self.useless.weight[0]
            return x

    m = Net().to(device)
    inner_optimizer = optim.AdamW(
        m.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
    )
    outer_optimizer = optim.SGD(
        m.parameters(), lr=0.7, momentum=0.9, nesterov=True
    )
    criterion = nn.CrossEntropyLoss()

    def load_state_dict(state_dict):
        m.load_state_dict(state_dict["model"])
        m.to(device)
        diloco.original_parameters = state_dict["original_params"]
        for name in diloco.original_parameters.keys():
            diloco.original_parameters[name] = diloco.original_parameters[name].to(
                device
            )
        inner_optimizer.load_state_dict(state_dict["inner_optim"])
        outer_optimizer.load_state_dict(state_dict["outer_optim"])

    def state_dict():
        return {
            "model": m.state_dict(),
            "original_params": diloco.original_parameters,
            "inner_optim": inner_optimizer.state_dict(),
            "outer_optim": outer_optimizer.state_dict(),
        }

    manager = Manager(
        pg=pg,
        min_replica_size=1,
        load_state_dict=load_state_dict,
        state_dict=state_dict,
        replica_id=f"train_ddp_{REPLICA_GROUP_ID}",
        timeout=timedelta(seconds=30),
        checkpoint_transport=transport,
        use_async_quorum=False,
    )

    print(m)
    num_params = sum(p.numel() for p in m.parameters())
    print(f"Total number of parameters: {num_params}")

    sort_by_keyword = "self_" + device + "_time_total"

    def trace_handler(p):
        output = p.key_averages().table(
            sort_by=sort_by_keyword,
            row_limit=100,
        )
        print(output)
        p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")

    # You can use epoch-based training, but with faults it's easier to use
    # step-based training.
    prof = torch.profiler.profile(
        schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
        on_trace_ready=trace_handler,
        record_shapes=True,
        profile_memory=True,
    )

    prof.start()

    num_local_steps = 0
    sync_every = 100
    with DiLoCo(
        manager,
        m,
        inner_optimizer,
        outer_optimizer,
        backup_device=device,
        sync_every=sync_every,
    ) as diloco:
        while True:
            for i, (inputs, labels) in enumerate(trainloader):
                prof.step()

                inputs = inputs.to(device)
                labels = labels.to(device)

                # Must be called at the beginning of each train loop iteration.
                # Quorum computation is triggered here but only needed in the backwards pass.
                inner_optimizer.zero_grad()

                out = m(inputs)
                loss = criterion(out, labels)

                # Gradient allreduce overlaps with the backwards pass.
                loss.backward()

                # Must be called at the end of the train loop iteration.
                # This may not actually step the optimizer if an error occurred during grad allreduce.
                inner_optimizer.step()
                num_local_steps += 1

                if manager.current_step() % 100 == 0:
                    print(f"[{manager.current_step()}] loss = {loss.item()}")

                if num_local_steps % sync_every == 0:
                    print(f"Number of inner optimizer steps completed: {num_local_steps}")

                # TODO (by the user): periodically checkpoint model, optim, manager and dataloader.

                # You typically want to checkpoint the dataloader frequently (every step?)
                # to avoid repeated batches, as its state is replica-group specific.

                # Model, optim and manager checkpoints can be done less frequently, as
                # they're shared across all groups and will be loaded from existing
                # replicas as long as not every worker goes down.

                if manager.current_step() >= 10000:
                    # complete training
                    prof.stop()
                    exit()

                time.sleep(0.01)


if __name__ == "__main__":
    main()
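
The training loop above leaves checkpointing as a TODO for the user. A minimal sketch of what that could look like is below; it is not part of this commit, and the function name save_checkpoint, the checkpoint directory, and the save cadence are illustrative assumptions. The idea follows the comments in the example: the replica-group-specific dataloader state is written on every call via the StatefulDataLoader's state_dict(), while the shared model/optimizer/DiLoCo state returned by the example's state_dict() helper is written less often.

    # Illustrative sketch only -- not part of this commit. The function name,
    # paths and cadence are assumptions; it expects the state_dict() helper and
    # trainloader defined in examples/train_diloco.py to be passed in.
    import os
    from typing import Callable

    import torch


    def save_checkpoint(
        step: int,
        state_dict: Callable[[], dict],
        trainloader,
        ckpt_dir: str = "/tmp/torchft_diloco",
        model_every: int = 1000,
    ) -> None:
        os.makedirs(ckpt_dir, exist_ok=True)
        # The dataloader position is specific to this replica group, so save it
        # on every call to avoid replaying batches after a restart.
        torch.save(trainloader.state_dict(), os.path.join(ckpt_dir, "dataloader.pt"))
        if step % model_every == 0:
            # Model, optimizer and DiLoCo backup parameters are shared across
            # groups, so they can be written out less frequently.
            torch.save(state_dict(), os.path.join(ckpt_dir, f"train_state_{step}.pt"))

Inside the loop this might be called as save_checkpoint(manager.current_step(), state_dict, trainloader) right after inner_optimizer.step(). Recovery of the shared state from live replicas is already handled by the Manager through the configured PGTransport, so a sketch like this only covers persistence to disk for full restarts.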

0 commit comments
