
Commit adb94c2

integrate quantization in manager
Summary:

- add back support for quantized allreduce in manager
- change return types to be consistent with pg allreduce

## Test Plan

- tested the changes with the nccl pg
- synchronizing on the recovery stream sometimes makes the CPU block on the collective (probably because some callback gets scheduled on the recovery stream?); we should stop synchronizing on the recovery stream when there is no need to
- calling `work.wait` on the work returned by the baby nccl pg makes the CPU block on the collective (plan to make it support async operations in a future diff)

## Without Quantization

<img width="1188" alt="image" src="https://github.com/user-attachments/assets/8f8dd694-a972-4bc6-96a0-8a79627a4d5d" />

## With Quantization

<img width="1123" alt="image" src="https://github.com/user-attachments/assets/b54288a3-9727-4956-89e7-c8b8775a98aa" />
1 parent 095e418 commit adb94c2
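
Note: a minimal sketch of how the re-enabled path is exercised from the training-loop side, assuming an already-constructed torchft `Manager` (`manager`) and a local pseudogradient tensor (`grad`); both names are placeholders, not part of this commit:

```python
# Hypothetical names: `manager` is a constructed torchft Manager, `grad` a CUDA tensor.
# With should_quantize=True and Triton available, this now dispatches to
# allreduce_quantized instead of hitting the old `assert False`.
fut = manager.allreduce(grad, should_quantize=True)

# The future resolves to the averaged tensor after the dequantize +
# divide-by-num_participants continuation has run.
averaged = fut.wait()
```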

5 files changed: +125 −75 lines changed


torchft/collectives.py

Lines changed: 23 additions & 18 deletions
@@ -135,21 +135,23 @@ def allocate_reduce_scatter_output(
     return tensor, padded_sizes


-class _QuantizedOpFuture(Future[None]):
+class _QuantizedOpFuture(Future[list[torch.Tensor]]):
     def __init__(
         self,
         sync_stream: cuda.Stream,
         keep_alive_tensors: list[torch.Tensor],
+        return_tensors: list[torch.Tensor],
     ) -> None:
         super().__init__()
         self._sync_stream = sync_stream
         self._keep_alive_tensors = keep_alive_tensors
+        self._return_tensors = return_tensors

-    def wait(self) -> None:
+    def wait(self) -> list[torch.Tensor]:
         # Wait for the synchronization to complete.
         cuda.current_stream().wait_stream(self._sync_stream)
         # Clean up intermediate buffers.
-        del self._keep_alive_tensors
+        return self._return_tensors


 def reduce_scatter_quantized(
@@ -276,6 +278,7 @@ def reduce_scatter_quantized(
             quantized_inputs,
             quantized_inputs_out,
         ],
+        [output],
     )


@@ -284,7 +287,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Future[list[torch.Tensor]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -334,7 +337,7 @@ def allreduce_quantized(
     )

     rank = process_group.rank()
-    world_size = process_group.size()
+    world_size: int = process_group.size()

     if sync_stream is None:
         sync_stream = cuda.Stream()
@@ -346,7 +349,7 @@ def allreduce_quantized(
     with cuda.stream(sync_stream):
         # Quantize tensoers and compute their scales, all inlined in the
         # output tensor.
-        quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
+        quantized_tensors: torch.Tensor = fused_quantize_into_fp8(tensors, world_size)

         # Allocate output tensor where all-reduce results will be stored
         quantized_tensors_out = torch.zeros_like(quantized_tensors)
@@ -370,20 +373,22 @@ def allreduce_quantized(
         )

         # Collect reduced chunks from other ranks.
-        process_group.allgather_into_tensor_coalesced(
+        work = process_group.allgather_into_tensor_coalesced(
             [quantized_tensors.view(world_size, -1)],
             [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
             _to_allgather_options(allreduce_opts),
-        ).wait()
+        )
+        work.wait()
+        fut = work.get_future()

-        # Dequantize and copy to output buffer.
-        fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+        def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
+            # Dequantize and copy to output buffer.
+            nonlocal tensors, quantized_tensors, world_size, sync_stream

-        # pyre-ignore[29]
-        return _QuantizedOpFuture(
-            sync_stream,
-            [
-                quantized_tensors,
-                quantized_tensors_out,
-            ],
-        )
+            with torch.cuda.stream(sync_stream):
+                # Dequantize the result back to the original precision
+                fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+                return tensors
+
+        fut = fut.then(callback)
+        return fut
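
Note: with this change `allreduce_quantized` follows the same contract as `pg.allreduce(...).get_future()`; the future resolves to the reduced tensors instead of `None`. A hedged sketch of the call shape, assuming an already-initialized CUDA-capable torchft `ProcessGroup` (`pg` below is a placeholder) and tensor sizes the fused FP8 kernels accept:

```python
import torch
from torch.distributed import ReduceOp

from torchft.collectives import allreduce_quantized

# `pg` is assumed to be an initialized torchft ProcessGroup (e.g. NCCL-backed);
# constructing one is outside the scope of this sketch.
tensors = [torch.randn(1024, device="cuda") for _ in range(2)]

fut = allreduce_quantized(tensors, ReduceOp.SUM, pg, torch.cuda.current_stream())

# The reduce runs in quantized form; the continuation dequantizes back into
# `tensors`, and wait() now returns them, mirroring the non-quantized path.
reduced = fut.wait()
```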

torchft/futures.py

Lines changed: 15 additions & 10 deletions
@@ -3,7 +3,7 @@
 import sys
 import threading
 import time
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from typing import Callable, Generator, Optional, TypeVar
 from unittest.mock import Mock
@@ -162,17 +162,22 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
             handle,
         )

+        stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
         def callback(fut: Future[T]) -> None:
-            handle.cancel()
-            try:
-                timed_fut.set_result(fut.wait())
-            except Exception as e:
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                handle.cancel()
                 try:
-                    # this can throw if the future is already done
-                    # pyre-fixme[6]: e is not T
-                    timed_fut.set_exception(e)
-                except Exception:
-                    pass
+                    timed_fut.set_result(fut.wait())
+                except Exception as e:
+                    try:
+                        # this can throw if the future is already done
+                        # pyre-fixme[6]: e is not T
+                        timed_fut.set_exception(e)
+                    except Exception:
+                        pass

         fut.add_done_callback(callback)
         return timed_fut
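
Note: the futures change captures the caller's CUDA stream when the timeout is registered and re-enters it inside the done-callback, so work queued by the callback stays ordered with the caller's stream rather than whatever stream the callback happens to fire on. A standalone sketch of that pattern, independent of torchft:

```python
from contextlib import nullcontext
from typing import Callable, Optional

import torch


def add_stream_pinned_callback(
    fut: torch.futures.Future,
    on_done: Callable[[torch.futures.Future], None],
) -> None:
    # Capture the stream that is current where the callback is registered.
    stream: Optional[torch.cuda.Stream] = (
        torch.cuda.current_stream() if torch.cuda.is_available() else None
    )

    def callback(f: torch.futures.Future) -> None:
        # Re-enter the captured stream so CUDA work issued by `on_done`
        # is ordered after the registering thread's prior work.
        with torch.cuda.stream(stream) if stream is not None else nullcontext():
            on_done(f)

    fut.add_done_callback(callback)
```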

torchft/local_sgd.py

Lines changed: 46 additions & 19 deletions
@@ -203,13 +203,18 @@ def __init__(
             torch.cuda.Stream() if torch.cuda.is_available() else None
         )

+        # Recorded on `_stream` to wait for allreduce to finish
+        self._stop_event: Optional[torch.cuda.Event] = None
+
         if bucket_cap_mb is not None:
             self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

         self.use_bucketization = use_bucketization
         self.should_quantize = should_quantize

+        self._grads: Dict[str, torch.Tensor] = {}
         self.original_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
@@ -252,17 +257,30 @@ def restore_parameters(self) -> None:
             else:
                 p.data.copy_(self.original_parameters[name], non_blocking=False)

+    def _set_grads(self) -> None:
+        """
+        Sets the gradients of the model fragment from the allreduce result
+        """
+        for name, p in self._model_fragment.named_parameters():
+            if isinstance(p, DTensor):
+                p.grad._local_tensor = self._grads[name]
+            else:
+                p.grad = self._grads[name]
+
+            del self._grads[name]
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
         Waits for the previously scheduled allreduce to finish
         """
-
-        for work in self._allreduce_futures:
-            work.wait()
+        if len(self._allreduce_futures) == 0:
+            return

         if self._stream is not None:
-            self._stream.synchronize()
+            assert self._stop_event is not None
+            self._stop_event.synchronize()
+            self._stop_event = None

         self._allreduce_futures = []

@@ -286,24 +304,33 @@ def prepare_sync(self) -> None:
         Calculate the pseugradient, average them across the manager group and starts
         allreduce on the pseudo-gradients but doesn't wait for it to finish.
         """
+        # Set the .grad field of each parameter to its pseudogradient
+        for name, p in self._model_fragment.named_parameters():
+            local_param = extract_local_tensor(p.data)
+            pseudogradient = local_param - self.original_parameters[name].to(p.device)
+            if isinstance(p, DTensor):
+                self._grads[name] = pseudogradient
+            else:
+                self._grads[name] = pseudogradient
+
+        # Make sure tensors are available to `_stream`
+        if self._stream is not None:
+            self._stream.wait_stream(torch.cuda.current_stream())
+
         with (
             torch.cuda.stream(self._stream)
             if self._stream is not None
             else nullcontext()
         ):
-            # Set the .grad field of each parameter to its pseudogradient
-            for name, p in self._model_fragment.named_parameters():
-                local_param = extract_local_tensor(p.data)
-                pseudogradient = local_param - self.original_parameters[name].to(
-                    p.device
-                )
-                if isinstance(p, DTensor):
-                    p.grad._local_tensor = pseudogradient
-                else:
-                    p.grad = pseudogradient
-
             self._average_grads()

+            for work in self._allreduce_futures:
+                work.wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
+
     @torch.profiler.record_function("torchft::local_sgd::perform_sync")
     def perform_sync(self) -> bool:
         """
@@ -322,6 +349,7 @@ def perform_sync(self) -> bool:

         if should_commit:
             # Use the outer optimizer to update the model parameters
+            self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
             self._outer_optimizer.zero_grad()
@@ -341,16 +369,16 @@ def _average_grads(self) -> None:

     def _allreduce_per_param(self) -> None:
         """Performs allreduce on each gradient tensor separately (original method)."""
-        for p in self._model_fragment.parameters():
+        for name, p in self._model_fragment.named_parameters():
             # Perform allreduce on the pseudogradients
             assert p.grad is not None
             if isinstance(p, DTensor):
                 work = self._manager.allreduce(
-                    p.grad._local_tensor, should_quantize=self.should_quantize
+                    self._grads[name], should_quantize=self.should_quantize
                 )
             else:
                 work = self._manager.allreduce(
-                    p.grad, should_quantize=self.should_quantize
+                    self._grads[name], should_quantize=self.should_quantize
                 )
             self._allreduce_futures.append(work)

@@ -609,7 +637,6 @@ def _step_post_hook(
     #
     # Both of them will fail because Node A didn't send fragment 2
     # and Node B didn't send fragment 1.
-    step = 0
     with self._lock:
         self._local_step += 1
         step = self._local_step
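
Note: `wait()` previously synchronized the whole allreduce stream; it now blocks only until a `torch.cuda.Event` recorded on `_stream` right after the allreduce work, so later callbacks scheduled onto that stream cannot extend the wait. A standalone sketch of the event-based wait (assumes a CUDA device; the stream and kernels are illustrative):

```python
import torch

comm_stream = torch.cuda.Stream()

# Make the side stream wait for work already queued on the current stream.
comm_stream.wait_stream(torch.cuda.current_stream())

with torch.cuda.stream(comm_stream):
    # ... enqueue allreduce / communication kernels here ...
    stop_event = torch.cuda.Event()
    stop_event.record()  # recorded on comm_stream, right after the queued work

# Block the CPU only until the recorded point has been reached, rather than on
# anything that later gets scheduled onto comm_stream (e.g. future callbacks).
stop_event.synchronize()
```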

torchft/manager.py

Lines changed: 38 additions & 26 deletions
@@ -332,9 +332,9 @@ def allreduce(
         # Run the allreduce async and save the work object so we can wait on
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
-            assert False, "allreduce_quantized is not supported yet"
-            # TODO: Support `allreduce_quantized`
-            # fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
+            fut = allreduce_quantized(
+                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+            )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
             fut = work.get_future()
@@ -345,22 +345,22 @@ def allreduce(

         # schedule grad normalization as a continuation
         # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
         def callback(
             fut: torch.futures.Future[List[torch.Tensor]],
         ) -> torch.Tensor:
             nonlocal tensor, stream

-            # check for exceptions
-            fut.value()
-
-            tensor /= self.num_participants()
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                fut.value()
+                tensor /= self.num_participants()

-            if stream is not None:
-                stream.wait_stream(torch.cuda.current_stream())
-
-            return tensor
+            return tensor

         fut = fut.then(callback)
+
         fut = self.wrap_future(fut, tensor)
         return fut

@@ -409,26 +409,31 @@ def wrap_future(
         Args:
             fut: the Future to wrap
             default: the default value to complete the Future with if an error occurs
+            stream: the stream to run the continuation on, if None, the default stream will be used
             timeout: the timeout for the Future, if None, the manager's timeout will be used
         """

-        # add a timeout to the future
         fut = future_timeout(fut, timeout or self._timeout)

+        stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
         # schedule error handling as a continuation on the Future
         def callback(
             fut: torch.futures.Future[T],
         ) -> T:
-            nonlocal default
+            nonlocal default, stream

-            try:
-                return fut.value()
-            except Exception as e:
-                self._logger.exception(
-                    f"got exception in future -- skipping remaining: {e}"
-                )
-                self.report_error(e)
-                return default
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                try:
+                    return fut.value()
+                except Exception as e:
+                    self._logger.exception(
+                        f"got exception in future -- skipping remaining: {e}"
+                    )
+                    self.report_error(e)
+                    return default

         fut = fut.then(callback)
         return fut
@@ -488,6 +493,7 @@ def start_quorum(
         # and don't need to zero_grad
         self._healing = False

+    @torch.profiler.record_function("torchft::manager::wait_quorum")
     def wait_quorum(self) -> None:
         """
         Wait for the quorum to complete.
@@ -696,11 +702,17 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
         # make sure recovery is complete before committing
-        if self._recovery_stream is not None:
-            self._recovery_stream.synchronize()
-
-        if torch.cuda.is_available():
-            torch.cuda.current_stream().synchronize()
+        with torch.profiler.record_function(
+            "torchft::manager::should_commmit::recovery_stream::synchronize"
+        ):
+            if self._recovery_stream is not None:
+                self._recovery_stream.synchronize()
+
+        with torch.profiler.record_function(
+            "torchft::manager::should_commit::current_stream::synchronize"
+        ):
+            if torch.cuda.is_available():
+                torch.cuda.current_stream().synchronize()

         if err := self._pg.errored():
             self.report_error(err)
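
Note: the manager changes also wrap the blocking synchronize calls and the allreduce continuation in `torch.profiler.record_function`, which is what makes them show up as named ranges in traces like the screenshots above; it works both as a decorator and as a context manager. A minimal, self-contained illustration (the names below are made up for the example):

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function


@record_function("example::decorated_step")
def decorated_step(x: torch.Tensor) -> torch.Tensor:
    # Appears as a named range in the collected trace.
    return x * 2


with profile(activities=[ProfilerActivity.CPU]) as prof:
    t = torch.ones(4)
    decorated_step(t)
    with record_function("example::explicit_block"):
        t.sum()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```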
