Skip to content

Commit 538ea8a

Browse files
committed
Option 1: use block_current_stream() to overlap compute and communication
1 parent 2b43bfc commit 538ea8a

File tree

3 files changed

+4
-7
lines changed

3 files changed

+4
-7
lines changed

torchft/local_sgd.py

Lines changed: 2 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -416,9 +416,7 @@ def perform_sync(self) -> bool:
416416
else nullcontext()
417417
):
418418
for work in self._allreduce_work:
419-
# TODO: Setup proper stream dependency on the future
420-
# attached to this work
421-
work.wait()
419+
work.get_future().wait()
422420

423421
if self._stream is not None:
424422
self._stop_event = torch.cuda.Event()
@@ -522,6 +520,7 @@ def _bucketize_and_allreduce(
522520
work = self._manager.allreduce(
523521
flat_buffer, should_quantize=self.should_quantize
524522
)
523+
work.block_current_stream()
525524

526525
def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
527526
with torch.cuda.stream(self._stream) if self._stream else nullcontext():
@@ -532,8 +531,6 @@ def callback(fut: torch.futures.Future[torch.Tensor]) -> None:
532531
)
533532

534533
fut = work.get_future()
535-
# TODO(tushar00jain): We need to call work.wait() here to ensure callback
536-
# runs after work has been completed
537534
fut = fut.then(callback)
538535

539536
self._allreduce_work.append(work)

torchft/manager.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -386,7 +386,7 @@ def allreduce(self, tensor: torch.Tensor, should_quantize: bool = False) -> Work
386386
else:
387387
work = self._pg.allreduce([tensor], ReduceOp.SUM)
388388

389-
# TODO(tushar00jain): Set up the stream dependency correctly
389+
work.block_current_stream()
390390
fut = work.get_future()
391391

392392
stream: Optional[torch.cuda.Stream] = (

torchft/process_group.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1072,7 +1072,7 @@ def wait(self, timeout: Optional[timedelta] = None) -> bool:
10721072
if timeout is not None:
10731073
self._work.wait(timeout)
10741074
else:
1075-
self._work.wait()
1075+
self._work.block_current_stream()
10761076
except Exception as e:
10771077
self._manager.report_error(e)
10781078

0 commit comments

Comments
 (0)