-
Notifications
You must be signed in to change notification settings - Fork 5.3k
/
Copy pathDDP-script.py
212 lines (168 loc) · 6.8 KB
/
DDP-script.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
# Copyright (c) Sebastian Raschka under Apache License 2.0 (see LICENSE.txt).
# Source for "Build a Large Language Model From Scratch"
# - https://www.manning.com/books/build-a-large-language-model-from-scratch
# Code: https://github.com/rasbt/LLMs-from-scratch
# Appendix A: Introduction to PyTorch (Part 3)
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
# NEW imports:
import os
import platform
import torch.multiprocessing as mp
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import init_process_group, destroy_process_group
# NEW: function to initialize a distributed process group (1 process / GPU)
# this allows communication among processes
def ddp_setup(rank, world_size):
    """
    Join the distributed process group and select this process's GPU.

    Arguments:
        rank: a unique process ID; also used as the CUDA device index
        world_size: total number of processes in the group
    """
    # Rendezvous endpoint of the rank-0 process. All GPUs are assumed to be
    # on the same machine; the port can be any free port on that machine.
    os.environ.update({"MASTER_ADDR": "localhost", "MASTER_PORT": "12345"})

    on_windows = platform.system() == "Windows"
    if on_windows:
        # Disable libuv because PyTorch for Windows isn't built with support
        os.environ["USE_LIBUV"] = "0"

    # gloo (Facebook Collective Communication Library) is used on Windows,
    # nccl (NVIDIA Collective Communication Library) everywhere else.
    backend = "gloo" if on_windows else "nccl"
    init_process_group(backend=backend, rank=rank, world_size=world_size)

    torch.cuda.set_device(rank)
class ToyDataset(Dataset):
    """Map-style dataset that pairs each feature row with its label."""

    def __init__(self, X, y):
        # Stored as given; both are indexed along their first dimension.
        self.features, self.labels = X, y

    def __len__(self):
        """Return the number of examples (first dimension of the labels)."""
        return self.labels.shape[0]

    def __getitem__(self, index):
        """Return the ``(features, label)`` pair stored at ``index``."""
        return self.features[index], self.labels[index]
class NeuralNetwork(torch.nn.Module):
    """Fully connected classifier with two ReLU hidden layers (30 and 20 units)."""

    def __init__(self, num_inputs, num_outputs):
        super().__init__()
        nn = torch.nn
        hidden_1, hidden_2 = 30, 20

        # Layers are created in forward order so that the Sequential indices
        # (0, 2, 4 for the linear layers) stay the same.
        stack = [
            nn.Linear(num_inputs, hidden_1),   # 1st hidden layer
            nn.ReLU(),
            nn.Linear(hidden_1, hidden_2),     # 2nd hidden layer
            nn.ReLU(),
            nn.Linear(hidden_2, num_outputs),  # output layer
        ]
        self.layers = nn.Sequential(*stack)

    def forward(self, x):
        """Return the raw class scores (logits) for ``x``."""
        return self.layers(x)
def prepare_dataset():
    """
    Build the data loaders for the toy two-feature, two-class problem.

    The training loader takes its indices from a DistributedSampler that is
    created without explicit replica/rank arguments, so this is meant to be
    called after the process group has been initialized.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The test loader is a plain,
        unshuffled loader over the whole test set.
    """
    X_train = torch.tensor([
        [-1.2, 3.1],
        [-0.9, 2.9],
        [-0.5, 2.6],
        [2.3, -1.1],
        [2.7, -1.5],
    ])
    y_train = torch.tensor([0, 0, 0, 1, 1])

    X_test = torch.tensor([
        [-0.8, 2.8],
        [2.6, -1.6],
    ])
    y_test = torch.tensor([0, 1])

    # Uncomment these lines to increase the dataset size to run this script on up to 8 GPUs:
    # factor = 4
    # X_train = torch.cat([X_train + torch.randn_like(X_train) * 0.1 for _ in range(factor)])
    # y_train = y_train.repeat(factor)
    # X_test = torch.cat([X_test + torch.randn_like(X_test) * 0.1 for _ in range(factor)])
    # y_test = y_test.repeat(factor)

    train_ds = ToyDataset(X_train, y_train)
    test_ds = ToyDataset(X_test, y_test)

    # NEW: chunk batches across GPUs without overlapping samples
    train_sampler = DistributedSampler(train_ds)

    train_loader = DataLoader(
        dataset=train_ds,
        sampler=train_sampler,
        shuffle=False,  # NEW: False because ordering is left to the sampler
        batch_size=2,
        drop_last=True,
        pin_memory=True,
    )
    test_loader = DataLoader(dataset=test_ds, batch_size=2, shuffle=False)

    return train_loader, test_loader
# NEW: wrapper
def main(rank, world_size, num_epochs):
    """
    Per-process entry point: train and evaluate the toy model with DDP.

    Arguments:
        rank: process ID passed in by the spawner; also used as the CUDA
            device index for the model and the batches
        world_size: total number of processes in the group
        num_epochs: number of passes over the training loader

    Raises:
        ZeroDivisionError: if accuracy is computed over a loader that yields
            no batches for this process (too little data for the number of
            GPUs); the message explains how to run the script in that case.
    """
    ddp_setup(rank, world_size)  # NEW: initialize process groups

    # Everything after setup runs under try/finally so that the process group
    # is destroyed even when training or evaluation raises.
    try:
        train_loader, test_loader = prepare_dataset()
        model = NeuralNetwork(num_inputs=2, num_outputs=2)
        model.to(rank)
        optimizer = torch.optim.SGD(model.parameters(), lr=0.5)

        model = DDP(model, device_ids=[rank])  # NEW: wrap model with DDP
        # the core model is now accessible as model.module

        for epoch in range(num_epochs):
            # NEW: Set sampler to ensure each epoch has a different shuffle order
            train_loader.sampler.set_epoch(epoch)

            model.train()
            for features, labels in train_loader:
                features, labels = features.to(rank), labels.to(rank)  # New: use rank
                logits = model(features)
                loss = F.cross_entropy(logits, labels)  # Loss function

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                # LOGGING
                print(f"[GPU{rank}] Epoch: {epoch+1:03d}/{num_epochs:03d}"
                      f" | Batchsize {labels.shape[0]:03d}"
                      f" | Train/Val Loss: {loss:.2f}")

        model.eval()

        try:
            train_acc = compute_accuracy(model, train_loader, device=rank)
            print(f"[GPU{rank}] Training accuracy", train_acc)
            test_acc = compute_accuracy(model, test_loader, device=rank)
            print(f"[GPU{rank}] Test accuracy", test_acc)

        ####################################################
        # NEW (not in the book):
        except ZeroDivisionError as e:
            # Chain the original error and point at the function rather than
            # at file line numbers, which go stale when the file is edited.
            raise ZeroDivisionError(
                f"{e}\n\nThis script is designed for 2 GPUs. You can run it as:\n"
                "CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py\n"
                f"Or, to run it on {torch.cuda.device_count()} GPUs, uncomment "
                "the dataset-enlargement lines in prepare_dataset()."
            ) from e
        ####################################################
    finally:
        destroy_process_group()  # NEW: cleanly exit distributed mode
def compute_accuracy(model, dataloader, device):
    """
    Return the fraction of examples in ``dataloader`` that ``model`` classifies
    correctly, as a Python float.

    The model is switched to evaluation mode. Each batch is moved to
    ``device`` and the predicted class is the argmax over dimension 1 of the
    model output. A loader that yields no batches results in a
    ZeroDivisionError from the final division.
    """
    model.eval()

    # Starts as a plain float so that an empty loader divides 0.0 by 0.
    num_correct = 0.0
    num_seen = 0

    for batch_x, batch_y in dataloader:
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        with torch.no_grad():
            predicted = model(batch_x).argmax(dim=1)

        hits = predicted == batch_y
        num_correct += hits.sum()
        num_seen += len(hits)

    accuracy = num_correct / num_seen
    return accuracy.item()
if __name__ == "__main__":
    # This script may not work for GPUs > 2 due to the small dataset
    # Run `CUDA_VISIBLE_DEVICES=0,1 python DDP-script.py` if you have GPUs > 2
    print("PyTorch version:", torch.__version__)
    print("CUDA available:", torch.cuda.is_available())
    print("Number of GPUs available:", torch.cuda.device_count())
    torch.manual_seed(123)

    # NEW: spawn new processes
    # note that spawn will automatically pass the rank
    num_epochs = 3
    world_size = torch.cuda.device_count()

    # One process is spawned per visible GPU, so with no GPU nothing would be
    # trained. Stop with an explicit message instead of finishing silently.
    if world_size < 1:
        raise SystemExit(
            "No CUDA GPUs available: this DDP example needs at least one GPU."
        )

    mp.spawn(main, args=(world_size, num_epochs), nprocs=world_size)
    # nprocs=world_size spawns one process per GPU