Commit 2b43bfc
wait for futures while syncing fragments
Summary:
- We currently wait for the PG work's future when preparing a fragment; with gloo this blocks the CPU.
- Move the wait call to when we perform the actual sync of the fragment.
- The manager's allreduce now also returns the work object, so we can wait on it when performing the sync.
- Use the HTTP transport instead of the PG transport -- the PG transport fails to resolve the address when running locally.
- Deep copy the state dict when sending a checkpoint, because if the replica moves on to the next step, the state dict can change before the checkpoint is sent.

Test Plan:
- gloo now overlaps communication with compute (screenshot: https://github.com/user-attachments/assets/e9b88e52-8053-432b-83a3-e689bcc4f9d4)
- nccl still overlaps (screenshot: https://github.com/user-attachments/assets/cbd0a352-1529-42f7-b8d9-d45bd0e84a97)
1 parent 949a981 commit 2b43bfc

13 files changed: +129 −95 lines
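Before the per-file diffs, a minimal sketch (not part of the commit) of the call pattern this change moves to: Manager.allreduce now returns a c10d Work object instead of a Future, so callers can schedule the collective early and defer the blocking wait until the synced value is actually needed. The helper function below is hypothetical; only Manager.allreduce, Work, and Work.get_future come from the code in this commit.

# Hypothetical usage sketch -- `overlapped_allreduce` is not part of torchft.
import torch
from torch.distributed.distributed_c10d import Work

from torchft.manager import Manager


def overlapped_allreduce(manager: Manager, grad: torch.Tensor) -> torch.Tensor:
    # Schedule the fault-tolerant allreduce; with this commit the call returns
    # a Work object immediately instead of waiting on the future (which used to
    # block the CPU under gloo).
    work: Work = manager.allreduce(grad)

    # ... other compute can overlap with the collective here ...

    # Block only at the actual sync point.
    work.get_future().wait()
    return grad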

torchft/collectives.py

Lines changed: 11 additions & 2 deletions

@@ -18,6 +18,7 @@
     AllreduceOptions,
     AllToAllOptions,
     ReduceScatterOptions,
+    Work,
 )
 from torch.futures import Future
 
@@ -288,7 +289,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[list[torch.Tensor]]:
+) -> Work:
     """
     Performs a quantized all-reduce operation on a list of tensors.
 
@@ -379,6 +380,14 @@ def allreduce_quantized(
         [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
         _to_allgather_options(allreduce_opts),
     )
+
+    # NOTE: This is not supposed to be used with gloo, only with NCCL.
+    # So we setup the stream dependency here by calling work.wait(),
+    # which doesn't block the CPU.
+    #
+    # The future callback below will run after the work has been
+    # completed.
+
     work.wait()
     fut = work.get_future()
 
@@ -392,4 +401,4 @@ def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
         return tensors
 
     fut = fut.then(callback)
-    return fut
+    return work

torchft/collectives_test.py

Lines changed: 2 additions & 2 deletions

@@ -94,8 +94,8 @@ def _run_all_reduce_collective(
             )
         ]
 
-        fut = allreduce_quantized(tensors, reduce_op, pg)
-        fut.wait()
+        work = allreduce_quantized(tensors, reduce_op, pg)
+        work.get_future().wait()
 
         work = pg.allreduce([expected], reduce_op)
         work.get_future().wait()

torchft/ddp.py

Lines changed: 2 additions & 1 deletion

@@ -68,7 +68,8 @@ def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> N
         def _comm_hook(
             state: "Manager", bucket: dist.GradBucket
         ) -> torch.futures.Future[torch.Tensor]:
-            return state.allreduce(bucket.buffer())
+            work = state.allreduce(bucket.buffer())
+            return work.get_future()
 
 
 class PureDistributedDataParallel(nn.Module):

torchft/ddp_test.py

Lines changed: 6 additions & 4 deletions

@@ -10,11 +10,13 @@
 import torch
 import torch.distributed as dist
 from torch import nn
+from torch.distributed.distributed_c10d import Work
 from torch.futures import Future
 
 from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
 from torchft.manager import Manager
 from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
+from torchft.work import _DummyWork
 
 
 class TestDDP(TestCase):
@@ -39,14 +41,14 @@ def test_ddp(self) -> None:
 
         call_count = 0
 
-        def allreduce(tensor: torch.Tensor) -> Future[torch.Tensor]:
+        def allreduce(
+            tensor: torch.Tensor,
+        ) -> Work:
             nonlocal call_count
 
             call_count += 1
 
-            fut = Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce = allreduce
 
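The tests above swap the old Future-based fake for `_DummyWork`, imported from `torchft.work`. That module is not shown in this commit; as a rough sketch (the class name comes from the imports above, everything else is an assumption), an already-completed Work shim could look like:

# Assumed sketch of a completed-Work shim; torchft's actual `_DummyWork` may differ.
import torch
from torch.distributed.distributed_c10d import Work
from torch.futures import Future


class _DummyWork(Work):
    """A Work that is already finished and simply carries a fixed result."""

    def __init__(self, result: object) -> None:
        super().__init__()
        self.result_ = result
        self.future_: Future[object] = torch.futures.Future()  # pyre-fixme[29]: not a function
        self.future_.set_result(result)

    def wait(self, timeout: object = None) -> bool:
        # Nothing to wait on -- the "collective" completed at construction time.
        return True

    def get_future(self) -> Future[object]:
        return self.future_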
torchft/local_sgd.py

Lines changed: 40 additions & 22 deletions

@@ -18,6 +18,7 @@
 import torch
 import torch.distributed as dist
 from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 from torch.nn.parameter import Parameter
 from torch.optim.optimizer import Optimizer
@@ -154,7 +155,8 @@ def _average(self) -> list[torch.Tensor]:
         for p in self._model.parameters():
             # Create a new tensor to store the averaged parameter
             avg_param = extract_local_tensor(p)
-            works.append(self._manager.allreduce(avg_param))
+            work = self._manager.allreduce(avg_param)
+            works.append(work)
             averaged_parameters.append(avg_param)
         for work in works:
             work.wait()
@@ -200,7 +202,7 @@ def __init__(
         self._outer_optimizer = outer_optimizer
 
         # Stores pending all reduce
-        self._allreduce_futures: list[torch.futures.Future[torch.Tensor]] = []
+        self._allreduce_work: list[Work] = []
         self._stream: Optional[torch.cuda.Stream] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )
@@ -368,15 +370,15 @@ def wait(self) -> None:
        """
        Waits for the previously scheduled allreduce to finish
        """
-        if len(self._allreduce_futures) == 0:
+        if len(self._allreduce_work) == 0:
            return
 
        if self._stream is not None:
            assert self._stop_event is not None
            self._stop_event.synchronize()
            self._stop_event = None
 
-        self._allreduce_futures = []
+        self._allreduce_work = []
 
    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
    def prepare_sync(self) -> None:
@@ -386,7 +388,7 @@ def prepare_sync(self) -> None:
        """
        self._save_grads()
 
-        assert len(self._allreduce_futures) == 0
+        assert len(self._allreduce_work) == 0
 
        # Make sure tensors are available to `_stream`
        if self._stream is not None:
@@ -399,21 +401,28 @@ def prepare_sync(self) -> None:
        ):
            self._average_grads()
 
-            for work in self._allreduce_futures:
-                work.wait()
-
-            if self._stream is not None:
-                self._stop_event = torch.cuda.Event()
-                self._stop_event.record()
-
    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
    def perform_sync(self) -> bool:
        """
        Overrides the sync method to wait for the scheduled allreduce to finish and
        steps using the outer optimizer.
        """
        # Waiting for an allreduce before it has been sent is currently not supported.
-        assert len(self._allreduce_futures) > 0
+        assert len(self._allreduce_work) > 0
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            for work in self._allreduce_work:
+                # TODO: Setup proper stream dependency on the future
+                # attached to this work
+                work.wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
 
        self.wait()
 
@@ -467,7 +476,8 @@ def _allreduce_per_param(self) -> None:
            work = self._manager.allreduce(
                self._grads[name], should_quantize=self.should_quantize
            )
-            self._allreduce_futures.append(work)
+
+            self._allreduce_work.append(work)
 
    def _bucketize_and_allreduce(
        self,
@@ -508,17 +518,25 @@ def _bucketize_and_allreduce(
                pack_offset += numel
                flat_index += 1
 
-            work = self._manager.allreduce(
-                flat_buffer, should_quantize=self.should_quantize
-            )
+            with torch.cuda.stream(self._stream) if self._stream else nullcontext():
+                work = self._manager.allreduce(
+                    flat_buffer, should_quantize=self.should_quantize
+                )
 
            def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
-                nonlocal bucket_tensors, flat_buffer
-                for t, pack_offset, numel in bucket_tensors:
-                    t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
+                with torch.cuda.stream(self._stream) if self._stream else nullcontext():
+                    nonlocal bucket_tensors, flat_buffer
+                    for t, pack_offset, numel in bucket_tensors:
+                        t.copy_(
+                            flat_buffer[pack_offset : pack_offset + numel].view_as(t)
+                        )
+
+            fut = work.get_future()
+            # TODO(tushar00jain): We need to call work.wait() here to ensure callback
+            # runs after work has been completed
+            fut = fut.then(callback)
 
-            work = work.then(callback)
-            self._allreduce_futures.append(work)
+            self._allreduce_work.append(work)
 
            offset += chunk_size

torchft/local_sgd_test.py

Lines changed: 12 additions & 13 deletions

@@ -11,10 +11,12 @@
 import torch
 from parameterized import parameterized
 from torch import Tensor, nn, optim
+from torch.distributed.distributed_c10d import Work
 from torch.distributed.tensor import DTensor
 
 from torchft.local_sgd import DiLoCo, LocalSGD, extract_local_tensor
 from torchft.manager import Manager
+from torchft.work import _DummyWork
 
 
 def create_manager() -> MagicMock:
@@ -26,6 +28,11 @@ def create_manager() -> MagicMock:
 
     manager.errored.return_value = None
 
+    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+        return _DummyWork(tensor)
+
+    manager.allreduce.side_effect = mock_allreduce
+
     return manager
 
 
@@ -66,7 +73,7 @@ class LocalSGDTest(TestCase):
     def test_local_sgd_healthy(self) -> None:
         model = SimpleModel()
         optimizer = optim.SGD(model.parameters())
-        manager = create_autospec(Manager)
+        manager = create_manager()
         with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
             self.assertEqual(local_sgd._local_step, 0)
             inp = torch.rand(2, 3)
@@ -240,13 +247,9 @@ def test_bucketization_correctness(self) -> None:
         manager.should_commit.return_value = True
 
         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce.side_effect = fake_allreduce
 
@@ -284,13 +287,9 @@ def test_gradient_correctness(self) -> None:
         manager.should_commit.return_value = True
 
         # Define fake allreduce: multiplies buffer by 2
-        def fake_allreduce(
-            tensor: Tensor, should_quantize: bool
-        ) -> torch.futures.Future[Tensor]:
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
             tensor.mul_(2)
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         manager.allreduce.side_effect = fake_allreduce

torchft/manager.py

Lines changed: 12 additions & 13 deletions

@@ -26,6 +26,7 @@
 """
 
 import concurrent.futures
+import copy
 import logging
 import os
 import socket
@@ -39,11 +40,12 @@
 
 import torch
 from torch.distributed import ReduceOp, TCPStore
-from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp
+from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
 
 from torchft._torchft import ManagerClient, ManagerServer
 from torchft.checkpointing import CheckpointTransport, HTTPTransport
 from torchft.futures import future_timeout
+from torchft.work import _DummyWork
 
 if TYPE_CHECKING:
     from torchft.process_group import ProcessGroup
@@ -343,9 +345,7 @@ def shutdown(self, wait: bool = True) -> None:
         self._executor.shutdown(wait=wait)
 
     @torch.profiler.record_function("torchft::manager::allreduce")
-    def allreduce(
-        self, tensor: torch.Tensor, should_quantize: bool = False
-    ) -> torch.futures.Future[torch.Tensor]:
+    def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work:
         """
         Fault tolerant allreduce the tensor and return a Future that will be completed when
         the tensor is ready.
@@ -367,7 +367,7 @@ def allreduce(
         if self.errored():
             fut = torch.futures.Future()  # pyre-fixme[29]: not a function
             fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
         self.wait_quorum()
         num_participants: int = self.num_participants()
@@ -380,13 +380,14 @@ def allreduce(
             # Run the allreduce async and save the work object so we can wait on
             # it later.
             if should_quantize and IS_TRITON_AVAILABLE:
-                fut = allreduce_quantized(
+                work = allreduce_quantized(
                     [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
                 )
             else:
                 work = self._pg.allreduce([tensor], ReduceOp.SUM)
-                work.wait()
-                fut = work.get_future()
+
+            # TODO(tushar00jain): Set up the stream dependency correctly
+            fut = work.get_future()
 
             stream: Optional[torch.cuda.Stream] = (
                 torch.cuda.current_stream() if torch.cuda.is_available() else None
@@ -411,17 +412,15 @@ def callback(
             fut = fut.then(callback)
 
             fut = self.wrap_future(fut, tensor)
-            return fut
+            return work
 
         except Exception as e:
             self._logger.exception(
                 f"got exception in all reduce -- skipping remaining: {e}"
             )
             self.report_error(e)
 
-            fut = torch.futures.Future()  # pyre-fixme[29]: not a function
-            fut.set_result(tensor)
-            return fut
+            return _DummyWork(tensor)
 
     def report_error(self, e: Exception) -> None:
         """
@@ -646,7 +645,7 @@ def _async_quorum(
                 self._checkpoint_transport.send_checkpoint(
                     dst_ranks=quorum.recover_dst_replica_ranks,
                     step=max_step,
-                    state_dict=self._manager_state_dict(),
+                    state_dict=copy.deepcopy(self._manager_state_dict()),
                     timeout=self._timeout,
                 )

torchft/manager_integ_test.py

Lines changed: 2 additions & 2 deletions

@@ -634,7 +634,7 @@ def all_reduce_callback(
 
         manager.start_quorum()
         t1 = torch.ones((1, 3), device=device)
-        fut = manager.allreduce(t1)
-        fut.wait()
+        work = manager.allreduce(t1)
+        work.get_future().wait()
         return t1
     return None
