
Commit acef15d

support async in nccl pg
Summary:

- In future continuations, set the same stream as the one used for the work, so that unrelated streams don't depend on the pg stream (otherwise those streams can become dependent on the allreduce stream).
- Wait on the work sent to the pg's immediately on the fragment streams (used for allreduce), to make them depend on the pg stream and so that they don't depend on any future work submitted to those streams.
- Copy grads before the allreduce so that the inner optimization can use them and no dependency is created between the default stream and the pg stream.
- Add back support for quantized allreduce in the manager.
- Change return types to be consistent with the pg allreduce.
- The future returned by the quantization collectives hangs (likely because `set_result` is never called?), so return the future from the pg directly instead.

Test Plan:

- Tested the changes with the NCCL pg.
- Synchronizing on the recovery stream sometimes makes the CPU block on the collective (probably because some callback gets scheduled on the recovery stream?); we need to stop synchronizing on the recovery stream when there is no need to.
- Calling `work.wait` on the work returned by the baby NCCL pg makes the CPU block on the collective (because the two contexts can't overlap?).
- The Gloo pg needs us to call `future.wait` in the sync phase instead of the prepare phase, so we probably need a different wrapper.
- The same applies to the baby Gloo pg.

> Without Quantization

<img width="1188" alt="image" src="https://github.com/user-attachments/assets/8f8dd694-a972-4bc6-96a0-8a79627a4d5d" />

> With Quantization

<img width="1123" alt="image" src="https://github.com/user-attachments/assets/b54288a3-9727-4956-89e7-c8b8775a98aa" />
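The first summary bullet boils down to one idiom, shown here as a minimal standalone sketch (the helper name `attach_stream_safe_callback` and the toy normalization are illustrative, not part of this commit): capture the caller's current CUDA stream before chaining a continuation, then re-enter it inside the callback so the continuation is ordered after the caller's work rather than after the process group's allreduce stream.

```python
from contextlib import nullcontext
from typing import Optional

import torch
from torch.futures import Future


def attach_stream_safe_callback(fut: Future, tensor: torch.Tensor) -> Future:
    # Capture the stream the *caller* is using right now. Without this, the
    # continuation would run on whatever stream the process group used for
    # the collective, silently serializing unrelated work behind it.
    stream: Optional[torch.cuda.Stream] = (
        torch.cuda.current_stream() if torch.cuda.is_available() else None
    )

    def callback(fut: Future) -> torch.Tensor:
        # Re-enter the captured stream (or do nothing on CPU-only builds) so
        # post-processing is ordered after the caller's work, not after
        # arbitrary future work on the process group stream.
        with torch.cuda.stream(stream) if stream is not None else nullcontext():
            fut.value()  # surface any exception from the collective
            tensor /= 4  # e.g. normalize by the number of participants
            return tensor

    return fut.then(callback)
```

The same capture-and-re-enter idiom appears in the `torchft/futures.py` and `torchft/manager.py` diffs below.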
1 parent 095e418 commit acef15d

5 files changed (+124, -74 lines)

torchft/collectives.py

Lines changed: 23 additions & 17 deletions
@@ -135,21 +135,24 @@ def allocate_reduce_scatter_output(
     return tensor, padded_sizes


-class _QuantizedOpFuture(Future[None]):
+class _QuantizedOpFuture(Future[list[torch.Tensor]]):
     def __init__(
         self,
         sync_stream: cuda.Stream,
         keep_alive_tensors: list[torch.Tensor],
+        return_tensors: list[torch.Tensor],
     ) -> None:
         super().__init__()
         self._sync_stream = sync_stream
         self._keep_alive_tensors = keep_alive_tensors
+        self._return_tensors = return_tensors

-    def wait(self) -> None:
+    def wait(self) -> list[torch.Tensor]:
         # Wait for the synchronization to complete.
         cuda.current_stream().wait_stream(self._sync_stream)
         # Clean up intermediate buffers.
         del self._keep_alive_tensors
+        return self._return_tensors


 def reduce_scatter_quantized(
@@ -276,6 +279,7 @@ def reduce_scatter_quantized(
                 quantized_inputs,
                 quantized_inputs_out,
             ],
+            [output],
         )


@@ -284,7 +288,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Future[list[torch.Tensor]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -334,7 +338,7 @@ def allreduce_quantized(
         )

     rank = process_group.rank()
-    world_size = process_group.size()
+    world_size: int = process_group.size()

     if sync_stream is None:
         sync_stream = cuda.Stream()
@@ -346,7 +350,7 @@ def allreduce_quantized(
     with cuda.stream(sync_stream):
         # Quantize tensoers and compute their scales, all inlined in the
         # output tensor.
-        quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
+        quantized_tensors: torch.Tensor = fused_quantize_into_fp8(tensors, world_size)

         # Allocate output tensor where all-reduce results will be stored
         quantized_tensors_out = torch.zeros_like(quantized_tensors)
@@ -370,20 +374,22 @@ def allreduce_quantized(
         )

         # Collect reduced chunks from other ranks.
-        process_group.allgather_into_tensor_coalesced(
+        work = process_group.allgather_into_tensor_coalesced(
             [quantized_tensors.view(world_size, -1)],
             [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
             _to_allgather_options(allreduce_opts),
-        ).wait()
+        )
+        work.wait()
+        fut = work.get_future()

-        # Dequantize and copy to output buffer.
-        fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+        def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
+            # Dequantize and copy to output buffer.
+            nonlocal tensors, quantized_tensors, world_size, sync_stream

-        # pyre-ignore[29]
-        return _QuantizedOpFuture(
-            sync_stream,
-            [
-                quantized_tensors,
-                quantized_tensors_out,
-            ],
-        )
+            with torch.cuda.stream(sync_stream):
+                # Dequantize the result back to the original precision
+                fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+                return tensors
+
+        fut = fut.then(callback)
+        return fut
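A hedged call-site sketch of the new `allreduce_quantized` contract (the wrapper function and the sum-then-use flow are illustrative; it assumes an already-initialized `ProcessGroup` from `torchft.process_group` and a CUDA build with Triton available): the returned future now resolves to the dequantized tensors instead of `None`.

```python
from typing import List

import torch
from torch.distributed import ReduceOp
from torch.futures import Future

from torchft.collectives import allreduce_quantized
from torchft.process_group import ProcessGroup


def quantized_sum(pg: ProcessGroup, grads: List[torch.Tensor]) -> List[torch.Tensor]:
    # The allgather future from the pg is returned directly (per the commit
    # summary, the old _QuantizedOpFuture path hung), so wait() now yields
    # list[torch.Tensor] rather than None.
    fut: Future = allreduce_quantized(
        grads, ReduceOp.SUM, pg, torch.cuda.current_stream()
    )
    # The same tensors that were passed in, dequantized back in place.
    return fut.wait()
```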

torchft/futures.py

Lines changed: 15 additions & 10 deletions
@@ -3,7 +3,7 @@
 import sys
 import threading
 import time
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from typing import Callable, Generator, Optional, TypeVar
 from unittest.mock import Mock
@@ -162,17 +162,22 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
             handle,
         )

+        stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
         def callback(fut: Future[T]) -> None:
-            handle.cancel()
-            try:
-                timed_fut.set_result(fut.wait())
-            except Exception as e:
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                handle.cancel()
                 try:
-                    # this can throw if the future is already done
-                    # pyre-fixme[6]: e is not T
-                    timed_fut.set_exception(e)
-                except Exception:
-                    pass
+                    timed_fut.set_result(fut.wait())
+                except Exception as e:
+                    try:
+                        # this can throw if the future is already done
+                        # pyre-fixme[6]: e is not T
+                        timed_fut.set_exception(e)
+                    except Exception:
+                        pass

         fut.add_done_callback(callback)
         return timed_fut
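A small usage sketch of the timeout wrapper this `register()` backs (assuming the module-level `future_timeout` helper in `torchft.futures` keeps its `(future, timedelta)` signature, as the manager's call site suggests): the timed future still forwards the result or a timeout error, but after this change its completion callback re-enters the CUDA stream that was current at registration time instead of whatever stream the source future completes on.

```python
from datetime import timedelta

import torch

from torchft.futures import future_timeout

# A plain future stands in for a collective's future here.
src: torch.futures.Future[int] = torch.futures.Future()
timed = future_timeout(src, timedelta(seconds=5))

src.set_result(42)
print(timed.wait())  # 42, delivered by the callback shown in the diff above
```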

torchft/local_sgd.py

Lines changed: 46 additions & 19 deletions
@@ -203,13 +203,18 @@ def __init__(
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )

+        # Recorded on `_stream` to wait for allreduce to finish
+        self._stop_event: Optional[torch.cuda.Event] = None
+
         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

         self.use_bucketization = use_bucketization
         self.should_quantize = should_quantize

+        self._grads: Dict[str, torch.Tensor] = {}
         self.original_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
@@ -252,17 +257,30 @@ def restore_parameters(self) -> None:
                 else:
                     p.data.copy_(self.original_parameters[name], non_blocking=False)

+    def _set_grads(self) -> None:
+        """
+        Sets the gradients of the model fragment from the allreduce result
+        """
+        for name, p in self._model_fragment.named_parameters():
+            if isinstance(p, DTensor):
+                p.grad._local_tensor = self._grads[name]
+            else:
+                p.grad = self._grads[name]
+
+            del self._grads[name]
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
         """
-
-        for work in self._allreduce_futures:
-            work.wait()
+        if len(self._allreduce_futures) == 0:
+            return

         if self._stream is not None:
-            self._stream.synchronize()
+            assert self._stop_event is not None
+            self._stop_event.synchronize()
+            self._stop_event = None

         self._allreduce_futures = []

@@ -286,24 +304,33 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model_fragment.named_parameters():
+            local_param = extract_local_tensor(p.data)
+            pseudogradient = local_param - self.original_parameters[name].to(p.device)
+            if isinstance(p, DTensor):
+                self._grads[name] = pseudogradient
+            else:
+                self._grads[name] = pseudogradient
+
+        # Make sure tensors are available to `_stream`
+        if self._stream is not None:
+            self._stream.wait_stream(torch.cuda.current_stream())
+
         with (
             torch.cuda.stream(self._stream)
             if self._stream is not None
             else nullcontext()
         ):
-            # Set the .grad field of each parameter to its pseudogradient
-            for name, p in self._model_fragment.named_parameters():
-                local_param = extract_local_tensor(p.data)
-                pseudogradient = local_param - self.original_parameters[name].to(
-                    p.device
-                )
-                if isinstance(p, DTensor):
-                    p.grad._local_tensor = pseudogradient
-                else:
-                    p.grad = pseudogradient
-
             self._average_grads()

+            for work in self._allreduce_futures:
+                work.wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
+
     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
@@ -322,6 +349,7 @@ def perform_sync(self) -> bool:

         if should_commit:
             # Use the outer optimizer to update the model parameters
+            self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
             self._outer_optimizer.zero_grad()
@@ -341,16 +369,16 @@ def _average_grads(self) -> None:

     def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
-        for p in self._model_fragment.parameters():
+        for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
             if isinstance(p, DTensor):
                 work = self._manager.allreduce(
-                    p.grad._local_tensor, should_quantize=self.should_quantize
+                    self._grads[name], should_quantize=self.should_quantize
                 )
             else:
                 work = self._manager.allreduce(
-                    p.grad, should_quantize=self.should_quantize
+                    self._grads[name], should_quantize=self.should_quantize
                 )
             self._allreduce_futures.append(work)

@@ -609,7 +637,6 @@ def _step_post_hook(
         #
         # Both of them will fail because Node A didn't send fragment 2
         # and Node B didn't send fragment 1.
-        step = 0
         with self._lock:
             self._local_step += 1
             step = self._local_step
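A minimal standalone sketch of the new ordering in `local_sgd.py` (illustrative names, a no-op standing in for the manager allreduce, requires a CUDA device): pseudogradients go into a side table instead of `p.grad`, the communication runs on a dedicated stream that first waits on the default stream, a CUDA event is recorded after the waits, and the later `wait()` blocks only on that event before the grads are installed for the outer optimizer.

```python
from typing import Dict, Optional

import torch

params = {"w": torch.randn(8, device="cuda", requires_grad=True)}
original = {k: v.detach().clone() for k, v in params.items()}
grads: Dict[str, torch.Tensor] = {}
stop_event: Optional[torch.cuda.Event] = None

comm_stream = torch.cuda.Stream()

# prepare_sync: compute pseudogradients into a side table so the inner
# optimizer never touches tensors owned by the communication stream.
for name, p in params.items():
    grads[name] = p.detach() - original[name]

# Make the side-table tensors visible to the communication stream, then do the
# "allreduce" (a no-op stand-in here) and record an event marking completion.
comm_stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(comm_stream):
    for g in grads.values():
        g.mul_(1.0)  # stand-in for manager.allreduce(...) + work.wait()
    stop_event = torch.cuda.Event()
    stop_event.record()

# wait(): block the CPU only on the recorded event, not the whole stream.
stop_event.synchronize()

# perform_sync(): install the reduced pseudogradients right before stepping
# the outer optimizer.
for name, p in params.items():
    p.grad = grads.pop(name)
```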

torchft/manager.py

Lines changed: 37 additions & 26 deletions
@@ -332,9 +332,9 @@ def allreduce(
         # Run the allreduce async and save the work object so we can wait on
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
-            assert False, "allreduce_quantized is not supported yet"
-            # TODO: Support `allreduce_quantized`
-            # fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
+            fut = allreduce_quantized(
+                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+            )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
             fut = work.get_future()
@@ -345,22 +345,22 @@ def allreduce(

         # schedule grad normalization as a continuation
         # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
         def callback(
             fut: torch.futures.Future[List[torch.Tensor]],
         ) -> torch.Tensor:
             nonlocal tensor, stream

-            # check for exceptions
-            fut.value()
-
-            tensor /= self.num_participants()
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                fut.value()
+                tensor /= self.num_participants()

-            if stream is not None:
-                stream.wait_stream(torch.cuda.current_stream())
-
-            return tensor
+                return tensor

         fut = fut.then(callback)
+
         fut = self.wrap_future(fut, tensor)
         return fut

@@ -412,23 +412,27 @@ def wrap_future(
             timeout: the timeout for the Future, if None, the manager's timeout will be used
         """

-        # add a timeout to the future
         fut = future_timeout(fut, timeout or self._timeout)

+        stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
         # schedule error handling as a continuation on the Future
         def callback(
             fut: torch.futures.Future[T],
         ) -> T:
-            nonlocal default
+            nonlocal default, stream

-            try:
-                return fut.value()
-            except Exception as e:
-                self._logger.exception(
-                    f"got exception in future -- skipping remaining: {e}"
-                )
-                self.report_error(e)
-                return default
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                try:
+                    return fut.value()
+                except Exception as e:
+                    self._logger.exception(
+                        f"got exception in future -- skipping remaining: {e}"
+                    )
+                    self.report_error(e)
+                    return default

         fut = fut.then(callback)
         return fut
@@ -488,6 +492,7 @@ def start_quorum(
         # and don't need to zero_grad
         self._healing = False

+    @torch.profiler.record_function("torchft::manager::wait_quorum")
     def wait_quorum(self) -> None:
         """
         Wait for the quorum to complete.
@@ -696,11 +701,17 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
         # make sure recovery is complete before committing
-        if self._recovery_stream is not None:
-            self._recovery_stream.synchronize()
-
-        if torch.cuda.is_available():
-            torch.cuda.current_stream().synchronize()
+        with torch.profiler.record_function(
+            "torchft::manager::should_commmit::recovery_stream::synchronize"
+        ):
+            if self._recovery_stream is not None:
+                self._recovery_stream.synchronize()
+
+        with torch.profiler.record_function(
+            "torchft::manager::should_commit::current_stream::synchronize"
+        ):
+            if torch.cuda.is_available():
+                torch.cuda.current_stream().synchronize()

         if err := self._pg.errored():
             self.report_error(err)
0 commit comments