add debuggability for baby pg #213


Merged 2 commits on Jun 12, 2025
40 changes: 23 additions & 17 deletions torchft/collectives.py
@@ -135,21 +135,24 @@ def allocate_reduce_scatter_output(
return tensor, padded_sizes


class _QuantizedOpFuture(Future[None]):
class _QuantizedOpFuture(Future[list[torch.Tensor]]):
def __init__(
self,
sync_stream: cuda.Stream,
keep_alive_tensors: list[torch.Tensor],
return_tensors: list[torch.Tensor],
) -> None:
super().__init__()
self._sync_stream = sync_stream
self._keep_alive_tensors = keep_alive_tensors
self._return_tensors = return_tensors

def wait(self) -> None:
def wait(self) -> list[torch.Tensor]:
# Wait for the synchronization to complete.
cuda.current_stream().wait_stream(self._sync_stream)
# Clean up intermediate buffers.
del self._keep_alive_tensors
return self._return_tensors


def reduce_scatter_quantized(
@@ -276,6 +279,7 @@ def reduce_scatter_quantized(
quantized_inputs,
quantized_inputs_out,
],
[output],
)


@@ -284,7 +288,7 @@ def allreduce_quantized(
opts: AllreduceOptions | ReduceOp,
process_group: "ProcessGroup",
sync_stream: cuda.Stream | None = None,
) -> Future[None]:
) -> Future[list[torch.Tensor]]:
"""
Performs a quantized all-reduce operation on a list of tensors.

@@ -334,7 +338,7 @@ def allreduce_quantized(
)

rank = process_group.rank()
world_size = process_group.size()
world_size: int = process_group.size()

if sync_stream is None:
sync_stream = cuda.Stream()
@@ -346,7 +350,7 @@
with cuda.stream(sync_stream):
# Quantize tensors and compute their scales, all inlined in the
# output tensor.
quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
quantized_tensors: torch.Tensor = fused_quantize_into_fp8(tensors, world_size)

# Allocate output tensor where all-reduce results will be stored
quantized_tensors_out = torch.zeros_like(quantized_tensors)
@@ -370,20 +374,22 @@
)

# Collect reduced chunks from other ranks.
process_group.allgather_into_tensor_coalesced(
work = process_group.allgather_into_tensor_coalesced(
[quantized_tensors.view(world_size, -1)],
[torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
_to_allgather_options(allreduce_opts),
).wait()
)
work.wait()
fut = work.get_future()

# Dequantize and copy to output buffer.
fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
# Dequantize and copy to output buffer.
nonlocal tensors, quantized_tensors, world_size, sync_stream

# pyre-ignore[29]
return _QuantizedOpFuture(
sync_stream,
[
quantized_tensors,
quantized_tensors_out,
],
)
with torch.cuda.stream(sync_stream):
# Dequantize the result back to the original precision
fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
return tensors

fut = fut.then(callback)
return fut
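
To make the new return type concrete: `allreduce_quantized` now hands back a `Future[list[torch.Tensor]]` whose value is produced by a continuation chained with `then`. Below is a minimal, self-contained sketch of that chaining pattern, not torchft's implementation; `fake_collective` and the divide-by-two step are stand-ins for the real allgather work object and `fused_dequantize_from_fp8`.

```python
from typing import List

import torch
from torch.futures import Future


def fake_collective(t: torch.Tensor) -> Future:
    # Stand-in for the async allgather: a Future that is already completed.
    fut: Future = Future()
    fut.set_result([t])
    return fut


def allreduce_like(quantized: torch.Tensor) -> Future:
    fut = fake_collective(quantized)

    def callback(f: Future) -> List[torch.Tensor]:
        # The real code dequantizes back to the original dtype on a dedicated
        # CUDA stream; dividing by two is just a visible placeholder.
        (result,) = f.value()
        return [result / 2]

    # Chain the post-processing so callers receive the final tensors.
    return fut.then(callback)


out = allreduce_like(torch.ones(4))
print(out.wait())  # [tensor([0.5000, 0.5000, 0.5000, 0.5000])]
```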
25 changes: 15 additions & 10 deletions torchft/futures.py
@@ -3,7 +3,7 @@
import sys
import threading
import time
from contextlib import contextmanager
from contextlib import contextmanager, nullcontext
from datetime import timedelta
from typing import Callable, Generator, Optional, TypeVar
from unittest.mock import Mock
@@ -162,17 +162,22 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
handle,
)

stream: Optional[torch.cuda.Stream] = (
torch.cuda.current_stream() if torch.cuda.is_available() else None
)

def callback(fut: Future[T]) -> None:
handle.cancel()
try:
timed_fut.set_result(fut.wait())
except Exception as e:
with torch.cuda.stream(stream) if stream is not None else nullcontext():
handle.cancel()
try:
# this can throw if the future is already done
# pyre-fixme[6]: e is not T
timed_fut.set_exception(e)
except Exception:
pass
timed_fut.set_result(fut.wait())
except Exception as e:
try:
# this can throw if the future is already done
# pyre-fixme[6]: e is not T
timed_fut.set_exception(e)
except Exception:
pass

fut.add_done_callback(callback)
return timed_fut
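
The key idea in this change is capturing the caller's CUDA stream when the future is registered and re-entering it inside the done-callback, falling back to `nullcontext` on CPU-only machines. A rough standalone sketch of that pattern, assuming the callback may fire on a different thread than the caller:

```python
from contextlib import nullcontext
from typing import Optional

import torch
from torch.futures import Future


def with_caller_stream(fut: Future) -> Future:
    # Capture the caller's stream now; the callback may run elsewhere later.
    stream: Optional[torch.cuda.Stream] = (
        torch.cuda.current_stream() if torch.cuda.is_available() else None
    )

    def callback(f: Future):
        # Any CUDA work issued here is ordered on the caller's stream.
        with torch.cuda.stream(stream) if stream is not None else nullcontext():
            return f.value()

    return fut.then(callback)


f: Future = Future()
wrapped = with_caller_stream(f)
f.set_result(torch.arange(3))
print(wrapped.wait())  # tensor([0, 1, 2])
```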
65 changes: 46 additions & 19 deletions torchft/local_sgd.py
Expand Up @@ -203,13 +203,18 @@ def __init__(
torch.cuda.Stream() if torch.cuda.is_available() else None
)

# Recorded on `_stream` to wait for allreduce to finish
self._stop_event: Optional[torch.cuda.Event] = None

if bucket_cap_mb is not None:
self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

self.use_bucketization = use_bucketization
self.should_quantize = should_quantize

self._grads: Dict[str, torch.Tensor] = {}
self.original_parameters: Dict[str, torch.Tensor] = {}

for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
p = extract_local_tensor(p.data)
@@ -252,17 +257,30 @@ def restore_parameters(self) -> None:
else:
p.data.copy_(self.original_parameters[name], non_blocking=False)

def _set_grads(self) -> None:
"""
Sets the gradients of the model fragment from the allreduce result
"""
for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
p.grad._local_tensor = self._grads[name]
else:
p.grad = self._grads[name]

del self._grads[name]

@torch.profiler.record_function("torchft::local_sgd::wait")
def wait(self) -> None:
"""
Waits for the previously scheduled allreduce to finish
"""

for work in self._allreduce_futures:
work.wait()
if len(self._allreduce_futures) == 0:
return

if self._stream is not None:
self._stream.synchronize()
assert self._stop_event is not None
self._stop_event.synchronize()
self._stop_event = None

self._allreduce_futures = []
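
The `wait()` above now blocks on a recorded CUDA event rather than synchronizing the whole side stream. A rough sketch of that pattern, assuming a CUDA device is present; the `add_` call stands in for the asynchronous allreduce:

```python
import torch

if torch.cuda.is_available():
    side_stream = torch.cuda.Stream()
    x = torch.ones(1024, device="cuda")

    # Make tensors produced on the default stream visible to the side stream.
    side_stream.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(side_stream):
        x.add_(1)  # stand-in for the asynchronous allreduce
        stop_event = torch.cuda.Event()
        stop_event.record()  # recorded on side_stream

    # Later, in wait(): block only until the recorded work has finished.
    stop_event.synchronize()
    print(x.mean().item())  # 2.0
```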

@@ -286,24 +304,33 @@ def prepare_sync(self) -> None:
Calculate the pseudogradients and start an allreduce to average them across the
manager group, without waiting for it to finish.
"""
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model_fragment.named_parameters():
local_param = extract_local_tensor(p.data)
pseudogradient = local_param - self.original_parameters[name].to(p.device)
if isinstance(p, DTensor):
self._grads[name] = pseudogradient
else:
self._grads[name] = pseudogradient

# Make sure tensors are available to `_stream`
if self._stream is not None:
self._stream.wait_stream(torch.cuda.current_stream())

with (
torch.cuda.stream(self._stream)
if self._stream is not None
else nullcontext()
):
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model_fragment.named_parameters():
local_param = extract_local_tensor(p.data)
pseudogradient = local_param - self.original_parameters[name].to(
p.device
)
if isinstance(p, DTensor):
p.grad._local_tensor = pseudogradient
else:
p.grad = pseudogradient

self._average_grads()

for work in self._allreduce_futures:
work.wait()

if self._stream is not None:
self._stop_event = torch.cuda.Event()
self._stop_event.record()

@torch.profiler.record_function("torchft::local_sgd::perform_sync")
def perform_sync(self) -> bool:
"""
@@ -322,6 +349,7 @@ def perform_sync(self) -> bool:

if should_commit:
# Use the outer optimizer to update the model parameters
self._set_grads()
self._outer_optimizer.step()
self.save_parameters()
self._outer_optimizer.zero_grad()
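
`perform_sync` now installs the staged, allreduced pseudogradients via `_set_grads()` immediately before stepping the outer optimizer. A simplified, CPU-only sketch of that staging flow follows; the model, names, and the in-place update are hypothetical, and there is no process group or manager here.

```python
from typing import Dict

import torch

model = torch.nn.Linear(4, 4)
original: Dict[str, torch.Tensor] = {
    name: p.detach().clone() for name, p in model.named_parameters()
}
outer_opt = torch.optim.SGD(model.parameters(), lr=1.0)

# Local steps mutate the parameters (placeholder for inner training).
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)

# prepare_sync: stage pseudogradients; torchft would allreduce these asynchronously.
grads: Dict[str, torch.Tensor] = {
    name: p.detach() - original[name] for name, p in model.named_parameters()
}

# perform_sync, after should_commit: install staged grads, then step and clear.
for name, p in model.named_parameters():
    p.grad = grads[name]
outer_opt.step()
outer_opt.zero_grad()
```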
@@ -341,16 +369,16 @@ def _average_grads(self) -> None:

def _allreduce_per_param(self) -> None:
"""Performs allreduce on each gradient tensor separately (original method)."""
for p in self._model_fragment.parameters():
for name, p in self._model_fragment.named_parameters():
# Perform allreduce on the pseudogradients
assert p.grad is not None
if isinstance(p, DTensor):
work = self._manager.allreduce(
p.grad._local_tensor, should_quantize=self.should_quantize
self._grads[name], should_quantize=self.should_quantize
)
else:
work = self._manager.allreduce(
p.grad, should_quantize=self.should_quantize
self._grads[name], should_quantize=self.should_quantize
)
self._allreduce_futures.append(work)

@@ -609,7 +637,6 @@ def _step_post_hook(
#
# Both of them will fail because Node A didn't send fragment 2
# and Node B didn't send fragment 1.
step = 0
with self._lock:
self._local_step += 1
step = self._local_step
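
The diff drops a redundant `step = 0` initializer; the step is read inside the same lock that increments it. A tiny standalone sketch of that counter-under-lock pattern, with hypothetical names:

```python
import threading

_lock = threading.Lock()
_local_step = 0


def step_post_hook() -> int:
    """Increment and read the step under one lock so hooks running on other
    threads never observe a half-updated value."""
    global _local_step
    with _lock:
        _local_step += 1
        step = _local_step
    return step


print(step_post_hook())  # 1
```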
63 changes: 37 additions & 26 deletions torchft/manager.py
Expand Up @@ -332,9 +332,9 @@ def allreduce(
# Run the allreduce async and save the work object so we can wait on
# it later.
if should_quantize and IS_TRITON_AVAILABLE:
assert False, "allreduce_quantized is not supported yet"
# TODO: Support `allreduce_quantized`
# fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
fut = allreduce_quantized(
[tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
)
else:
work = self._pg.allreduce([tensor], ReduceOp.SUM)
fut = work.get_future()
@@ -345,22 +345,22 @@

# schedule grad normalization as a continuation
# on the Future
@torch.profiler.record_function("torchft::manager::allreduce::callback")
def callback(
fut: torch.futures.Future[List[torch.Tensor]],
) -> torch.Tensor:
nonlocal tensor, stream

# check for exceptions
fut.value()

tensor /= self.num_participants()
# change the stream to avoid making the callback stream
# dependent on process group stream running the allreduce
with torch.cuda.stream(stream) if stream is not None else nullcontext():
fut.value()
tensor /= self.num_participants()

if stream is not None:
stream.wait_stream(torch.cuda.current_stream())

return tensor
return tensor

fut = fut.then(callback)

fut = self.wrap_future(fut, tensor)
return fut
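
For reference, the allreduce-then-normalize continuation can be exercised with a plain `torch.distributed` gloo group; this standalone sketch uses a hypothetical single-process setup rather than torchft's `ProcessGroup`, but shows the same `work.get_future().then(...)` flow the manager relies on:

```python
import os

import torch
import torch.distributed as dist

# Hypothetical single-process setup; torchft uses its own ProcessGroup wrapper.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29511")
dist.init_process_group("gloo", rank=0, world_size=1)

t = torch.full((4,), 6.0)
work = dist.all_reduce(t, op=dist.ReduceOp.SUM, async_op=True)
fut = work.get_future()

# Average as a continuation so callers only ever see the normalized tensor.
fut = fut.then(lambda f: f.value()[0] / dist.get_world_size())
print(fut.wait())  # tensor([6., 6., 6., 6.]) since world_size == 1

dist.destroy_process_group()
```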

@@ -412,23 +412,27 @@ def wrap_future(
timeout: the timeout for the Future, if None, the manager's timeout will be used
"""

# add a timeout to the future
fut = future_timeout(fut, timeout or self._timeout)

stream: Optional[torch.cuda.Stream] = (
torch.cuda.current_stream() if torch.cuda.is_available() else None
)

# schedule error handling as a continuation on the Future
def callback(
fut: torch.futures.Future[T],
) -> T:
nonlocal default
nonlocal default, stream

try:
return fut.value()
except Exception as e:
self._logger.exception(
f"got exception in future -- skipping remaining: {e}"
)
self.report_error(e)
return default
with torch.cuda.stream(stream) if stream is not None else nullcontext():
try:
return fut.value()
except Exception as e:
self._logger.exception(
f"got exception in future -- skipping remaining: {e}"
)
self.report_error(e)
return default

fut = fut.then(callback)
return fut
@@ -488,6 +492,7 @@ def start_quorum(
# and don't need to zero_grad
self._healing = False

@torch.profiler.record_function("torchft::manager::wait_quorum")
def wait_quorum(self) -> None:
"""
Wait for the quorum to complete.
@@ -696,11 +701,17 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
"""
# make sure recovery is complete before committing
if self._recovery_stream is not None:
self._recovery_stream.synchronize()

if torch.cuda.is_available():
torch.cuda.current_stream().synchronize()
with torch.profiler.record_function(
"torchft::manager::should_commmit::recovery_stream::synchronize"
):
if self._recovery_stream is not None:
self._recovery_stream.synchronize()

with torch.profiler.record_function(
"torchft::manager::should_commit::current_stream::synchronize"
):
if torch.cuda.is_available():
torch.cuda.current_stream().synchronize()

if err := self._pg.errored():
self.report_error(err)
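
Profiling ranges like the ones added above make each blocking synchronize show up as a named span in profiler traces. A standalone sketch of the same `torch.profiler` usage, with illustrative labels rather than torchft's:

```python
import torch
from torch.profiler import ProfilerActivity, profile, record_function

x = torch.randn(256, 256)

with profile(activities=[ProfilerActivity.CPU]) as prof:
    with record_function("torchft::example::matmul"):
        y = x @ x
    with record_function("torchft::example::synchronize"):
        # The named range makes blocking time easy to spot in the trace.
        if torch.cuda.is_available():
            torch.cuda.current_stream().synchronize()

print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
```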