diff --git a/torchft/collectives.py b/torchft/collectives.py
index 582db8c..837fbcd 100644
--- a/torchft/collectives.py
+++ b/torchft/collectives.py
@@ -135,21 +135,24 @@ def allocate_reduce_scatter_output(
     return tensor, padded_sizes


-class _QuantizedOpFuture(Future[None]):
+class _QuantizedOpFuture(Future[list[torch.Tensor]]):
     def __init__(
         self,
         sync_stream: cuda.Stream,
         keep_alive_tensors: list[torch.Tensor],
+        return_tensors: list[torch.Tensor],
     ) -> None:
         super().__init__()
         self._sync_stream = sync_stream
         self._keep_alive_tensors = keep_alive_tensors
+        self._return_tensors = return_tensors

-    def wait(self) -> None:
+    def wait(self) -> list[torch.Tensor]:
         # Wait for the synchronization to complete.
         cuda.current_stream().wait_stream(self._sync_stream)
         # Clean up intermediate buffers.
         del self._keep_alive_tensors
+        return self._return_tensors


 def reduce_scatter_quantized(
@@ -276,6 +279,7 @@ def reduce_scatter_quantized(
             quantized_inputs,
             quantized_inputs_out,
         ],
+        [output],
     )


@@ -284,7 +288,7 @@ def allreduce_quantized(
     opts: AllreduceOptions | ReduceOp,
     process_group: "ProcessGroup",
     sync_stream: cuda.Stream | None = None,
-) -> Future[None]:
+) -> Future[list[torch.Tensor]]:
     """
     Performs a quantized all-reduce operation on a list of tensors.

@@ -334,7 +338,7 @@ def allreduce_quantized(
     )

     rank = process_group.rank()
-    world_size = process_group.size()
+    world_size: int = process_group.size()

     if sync_stream is None:
         sync_stream = cuda.Stream()
@@ -346,7 +350,7 @@
     with cuda.stream(sync_stream):
         # Quantize tensoers and compute their scales, all inlined in the
         # output tensor.
-        quantized_tensors = fused_quantize_into_fp8(tensors, world_size)
+        quantized_tensors: torch.Tensor = fused_quantize_into_fp8(tensors, world_size)

         # Allocate output tensor where all-reduce results will be stored
         quantized_tensors_out = torch.zeros_like(quantized_tensors)
@@ -370,20 +374,22 @@
         )

         # Collect reduced chunks from other ranks.
-        process_group.allgather_into_tensor_coalesced(
+        work = process_group.allgather_into_tensor_coalesced(
             [quantized_tensors.view(world_size, -1)],
             [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
             _to_allgather_options(allreduce_opts),
-        ).wait()
+        )
+        work.wait()
+        fut = work.get_future()

-        # Dequantize and copy to output buffer.
-        fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+        def callback(fut: Future[list[torch.Tensor]]) -> list[torch.Tensor]:
+            # Dequantize and copy to output buffer.
+            nonlocal tensors, quantized_tensors, world_size, sync_stream

-    # pyre-ignore[29]
-    return _QuantizedOpFuture(
-        sync_stream,
-        [
-            quantized_tensors,
-            quantized_tensors_out,
-        ],
-    )
+            with torch.cuda.stream(sync_stream):
+                # Dequantize the result back to the original precision
+                fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+                return tensors
+
+        fut = fut.then(callback)
+        return fut
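To make the new contract concrete, here is a minimal, hypothetical consumer of the Future-returning `allreduce_quantized` (a sketch only: it assumes CUDA plus the Triton kernels, an already-initialized torchft `ProcessGroup` named `pg`, and an illustrative helper name).

```python
# Sketch only: `pg` and `grads` are assumed to exist; error handling omitted.
import torch
from torch.distributed import ReduceOp

from torchft.collectives import allreduce_quantized


def sum_then_average(pg, grads: list[torch.Tensor]) -> list[torch.Tensor]:
    # Launch the quantized allreduce; the returned future resolves to the
    # dequantized tensors once the chained callback has run on the sync stream.
    fut = allreduce_quantized(grads, ReduceOp.SUM, pg, torch.cuda.current_stream())
    reduced = fut.wait()
    # The reduction above is a SUM, so divide by world size to get the mean.
    return [t / pg.size() for t in reduced]
```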
diff --git a/torchft/futures.py b/torchft/futures.py
index a78a74b..52bb96e 100644
--- a/torchft/futures.py
+++ b/torchft/futures.py
@@ -3,7 +3,7 @@
 import sys
 import threading
 import time
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 from datetime import timedelta
 from typing import Callable, Generator, Optional, TypeVar
 from unittest.mock import Mock
@@ -162,17 +162,22 @@ def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
             handle,
         )

+        stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
         def callback(fut: Future[T]) -> None:
-            handle.cancel()
-            try:
-                timed_fut.set_result(fut.wait())
-            except Exception as e:
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                handle.cancel()
                 try:
-                    # this can throw if the future is already done
-                    # pyre-fixme[6]: e is not T
-                    timed_fut.set_exception(e)
-                except Exception:
-                    pass
+                    timed_fut.set_result(fut.wait())
+                except Exception as e:
+                    try:
+                        # this can throw if the future is already done
+                        # pyre-fixme[6]: e is not T
+                        timed_fut.set_exception(e)
+                    except Exception:
+                        pass

         fut.add_done_callback(callback)
         return timed_fut
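The timeout callback above now re-enters the CUDA stream that was current when the future was registered. The same capture-and-restore idiom, stripped down to a generic helper (a sketch, not part of torchft's API):

```python
# Generic sketch of the capture-and-restore stream idiom.
from contextlib import nullcontext
from typing import Callable, Optional

import torch


def add_stream_aware_callback(
    fut: torch.futures.Future, fn: Callable[[object], None]
) -> None:
    # Capture the stream that is current *now*; the done-callback may later
    # run on a different thread where a different stream (or none) is current.
    stream: Optional[torch.cuda.Stream] = (
        torch.cuda.current_stream() if torch.cuda.is_available() else None
    )

    def callback(f: torch.futures.Future) -> None:
        with torch.cuda.stream(stream) if stream is not None else nullcontext():
            fn(f.value())

    fut.add_done_callback(callback)
```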
""" + # Set the .grad field of each parameter to its pseudogradient + for name, p in self._model_fragment.named_parameters(): + local_param = extract_local_tensor(p.data) + pseudogradient = local_param - self.original_parameters[name].to(p.device) + if isinstance(p, DTensor): + self._grads[name] = pseudogradient + else: + self._grads[name] = pseudogradient + + # Make sure tensors are available to `_stream` + if self._stream is not None: + self._stream.wait_stream(torch.cuda.current_stream()) + with ( torch.cuda.stream(self._stream) if self._stream is not None else nullcontext() ): - # Set the .grad field of each parameter to its pseudogradient - for name, p in self._model_fragment.named_parameters(): - local_param = extract_local_tensor(p.data) - pseudogradient = local_param - self.original_parameters[name].to( - p.device - ) - if isinstance(p, DTensor): - p.grad._local_tensor = pseudogradient - else: - p.grad = pseudogradient - self._average_grads() + for work in self._allreduce_futures: + work.wait() + + if self._stream is not None: + self._stop_event = torch.cuda.Event() + self._stop_event.record() + @torch.profiler.record_function("torchft::local_sgd::perform_sync") def perform_sync(self) -> bool: """ @@ -322,6 +349,7 @@ def perform_sync(self) -> bool: if should_commit: # Use the outer optimizer to update the model parameters + self._set_grads() self._outer_optimizer.step() self.save_parameters() self._outer_optimizer.zero_grad() @@ -341,16 +369,16 @@ def _average_grads(self) -> None: def _allreduce_per_param(self) -> None: """Performs allreduce on each gradient tensor separately (original method).""" - for p in self._model_fragment.parameters(): + for name, p in self._model_fragment.named_parameters(): # Perform allreduce on the pseudogradients assert p.grad is not None if isinstance(p, DTensor): work = self._manager.allreduce( - p.grad._local_tensor, should_quantize=self.should_quantize + self._grads[name], should_quantize=self.should_quantize ) else: work = self._manager.allreduce( - p.grad, should_quantize=self.should_quantize + self._grads[name], should_quantize=self.should_quantize ) self._allreduce_futures.append(work) @@ -609,7 +637,6 @@ def _step_post_hook( # # Both of them will fail because Node A didn't send fragment 2 # and Node B didn't send fragment 1. - step = 0 with self._lock: self._local_step += 1 step = self._local_step diff --git a/torchft/manager.py b/torchft/manager.py index 489cbcc..27fdecd 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -332,9 +332,9 @@ def allreduce( # Run the allreduce async and save the work object so we can wait on # it later. 
diff --git a/torchft/manager.py b/torchft/manager.py
index 489cbcc..27fdecd 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -332,9 +332,9 @@ def allreduce(
         # Run the allreduce async and save the work object so we can wait on
         # it later.
         if should_quantize and IS_TRITON_AVAILABLE:
-            assert False, "allreduce_quantized is not supported yet"
-            # TODO: Support `allreduce_quantized`
-            # fut = allreduce_quantized([tensor], ReduceOp.SUM, self._pg)
+            fut = allreduce_quantized(
+                [tensor], ReduceOp.SUM, self._pg, torch.cuda.current_stream()
+            )
         else:
             work = self._pg.allreduce([tensor], ReduceOp.SUM)
             fut = work.get_future()
@@ -345,22 +345,22 @@ def allreduce(

         # schedule grad normalization as a continuation
         # on the Future
+        @torch.profiler.record_function("torchft::manager::allreduce::callback")
         def callback(
             fut: torch.futures.Future[List[torch.Tensor]],
         ) -> torch.Tensor:
             nonlocal tensor, stream

-            # check for exceptions
-            fut.value()
-
-            tensor /= self.num_participants()
+            # change the stream to avoid making the callback stream
+            # dependent on process group stream running the allreduce
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                fut.value()
+                tensor /= self.num_participants()

-            if stream is not None:
-                stream.wait_stream(torch.cuda.current_stream())
-
-            return tensor
+                return tensor

         fut = fut.then(callback)
+        fut = self.wrap_future(fut, tensor)

         return fut
@@ -412,23 +412,27 @@ def wrap_future(
             timeout: the timeout for the Future, if None, the manager's timeout will be used
         """
-
         # add a timeout to the future
         fut = future_timeout(fut, timeout or self._timeout)

+        stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.current_stream() if torch.cuda.is_available() else None
+        )
+
         # schedule error handling as a continuation on the Future
        def callback(
             fut: torch.futures.Future[T],
         ) -> T:
-            nonlocal default
+            nonlocal default, stream

-            try:
-                return fut.value()
-            except Exception as e:
-                self._logger.exception(
-                    f"got exception in future -- skipping remaining: {e}"
-                )
-                self.report_error(e)
-                return default
+            with torch.cuda.stream(stream) if stream is not None else nullcontext():
+                try:
+                    return fut.value()
+                except Exception as e:
+                    self._logger.exception(
+                        f"got exception in future -- skipping remaining: {e}"
+                    )
+                    self.report_error(e)
+                    return default

         fut = fut.then(callback)
         return fut
@@ -488,6 +492,7 @@ def start_quorum(
         # and don't need to zero_grad
         self._healing = False

+    @torch.profiler.record_function("torchft::manager::wait_quorum")
     def wait_quorum(self) -> None:
         """
         Wait for the quorum to complete.
@@ -696,11 +701,17 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
             RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
         """
         # make sure recovery is complete before committing
-        if self._recovery_stream is not None:
-            self._recovery_stream.synchronize()
-
-        if torch.cuda.is_available():
-            torch.cuda.current_stream().synchronize()
+        with torch.profiler.record_function(
+            "torchft::manager::should_commit::recovery_stream::synchronize"
+        ):
+            if self._recovery_stream is not None:
+                self._recovery_stream.synchronize()
+
+        with torch.profiler.record_function(
+            "torchft::manager::should_commit::current_stream::synchronize"
+        ):
+            if torch.cuda.is_available():
+                torch.cuda.current_stream().synchronize()

         if err := self._pg.errored():
             self.report_error(err)
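Seen from the caller, `Manager.allreduce` now hands back a future whose chained callbacks normalize the tensor on the captured stream and, through `wrap_future`, report any error and fall back to the unreduced tensor. A hedged usage sketch (assumes an initialized `manager` on a CUDA host; the quantized path additionally requires Triton):

```python
# Usage sketch only; `manager` is assumed to be a constructed torchft Manager.
import torch

grad = torch.randn(1024, device="cuda")
fut = manager.allreduce(grad, should_quantize=True)

# wait() returns the tensor after division by num_participants(); if the
# collective failed, the error was reported and `grad` comes back unchanged.
averaged = fut.wait()
```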
diff --git a/torchft/multiprocessing_dummy_context.py b/torchft/multiprocessing_dummy_context.py
new file mode 100644
index 0000000..06ec346
--- /dev/null
+++ b/torchft/multiprocessing_dummy_context.py
@@ -0,0 +1,135 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+"""
+Multiprocessing Dummy Context
+=============================
+
+This module provides a context-like interface for multiprocessing.dummy,
+which is a wrapper around the threading module that provides a multiprocessing-like
+interface but uses threads instead of processes.
+
+This allows code that uses multiprocessing.get_context() to work with
+multiprocessing.dummy by providing a compatible interface.
+"""
+
+import multiprocessing.dummy as mp
+import threading
+from typing import Callable, Iterable, Mapping
+
+
+class DummyContext:
+    """
+    A context-like class for multiprocessing.dummy that mimics the interface
+    of a context returned by multiprocessing.get_context().
+    """
+
+    def __init__(self, method: object = None) -> None:
+        """
+        Initialize the dummy context.
+
+        Args:
+            method: Ignored, only for compatibility with multiprocessing.get_context()
+        """
+        pass
+
+    def Process(
+        self,
+        group: object = None,
+        target: Callable[..., object] | None = None,
+        name: str | None = None,
+        args: Iterable[object] = (),
+        kwargs: Mapping[str, object] = {},
+        daemon: bool | None = None,
+    ) -> mp.DummyProcess:
+        """
+        Create a Process using multiprocessing.dummy.Process.
+        """
+        return mp.Process(
+            group=group, target=target, name=name, args=args, kwargs=kwargs
+        )
+
+    def Pipe(
+        self, duplex: bool = True
+    ) -> tuple[mp.connection.Connection, mp.connection.Connection]:
+        """
+        Create a Pipe using multiprocessing.dummy.Pipe.
+        """
+        return mp.Pipe(duplex)
+
+    def Queue(self, maxsize: int = 0) -> mp.Queue:
+        """
+        Create a Queue using multiprocessing.dummy.Queue.
+        """
+        return mp.Queue(maxsize)
+
+    def Event(self) -> threading.Event:
+        """
+        Create an Event using multiprocessing.dummy.Event.
+        """
+        return mp.Event()
+
+    def Lock(self) -> threading.Lock:
+        """
+        Create a Lock using multiprocessing.dummy.Lock.
+        """
+        return mp.Lock()
+
+    def RLock(self) -> threading.RLock:
+        """
+        Create an RLock using multiprocessing.dummy.RLock.
+        """
+        return mp.RLock()
+
+    def Semaphore(self, value: int = 1) -> threading.Semaphore:
+        """
+        Create a Semaphore using multiprocessing.dummy.Semaphore.
+        """
+        return mp.Semaphore(value)
+
+    def BoundedSemaphore(self, value: int = 1) -> threading.BoundedSemaphore:
+        """
+        Create a BoundedSemaphore using multiprocessing.dummy.BoundedSemaphore.
+        """
+        return mp.BoundedSemaphore(value)
+
+    def Condition(
+        self, lock: threading.Lock | threading.RLock | None = None
+    ) -> threading.Condition:
+        """
+        Create a Condition using multiprocessing.dummy.Condition.
+        """
+        return mp.Condition(lock)
+
+    def Manager(self) -> object:
+        """
+        Create a Manager using multiprocessing.dummy.Manager.
+        """
+        return mp.Manager()
+
+
+def get_context(method: object = None) -> DummyContext:
+    """
+    Return a context object for multiprocessing.dummy.
+
+    This function mimics multiprocessing.get_context() but returns a DummyContext
+    that works with multiprocessing.dummy. This can be used to patch
+    multiprocessing.dummy like so:
+
+    ```
+    import multiprocessing.dummy as mp
+    from torchft.multiprocessing_dummy_context import get_context
+    mp.get_context = get_context
+    ```
+
+    Args:
+        method: Ignored, only for compatibility with multiprocessing.get_context()
+
+    Returns:
+        A DummyContext instance
+    """
+    return DummyContext(method)
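A short, hypothetical driver for the new module, exercising the thread-backed `Process` and `Pipe` through the context interface:

```python
# Hypothetical usage; everything below runs in threads of the current process.
from torchft.multiprocessing_dummy_context import get_context

ctx = get_context()
parent_conn, child_conn = ctx.Pipe()


def worker(conn) -> None:
    conn.send("pong")


p = ctx.Process(target=worker, args=(child_conn,))
p.start()
print(parent_conn.recv())  # -> "pong"
p.join()
```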
diff --git a/train_diloco.py b/train_diloco.py
index e53b452..58654f2 100644
--- a/train_diloco.py
+++ b/train_diloco.py
@@ -57,7 +57,7 @@ def state_dict():

 device = "cuda" if torch.cuda.is_available() else "cpu"
 pg = (
-    ProcessGroupBabyNCCL(
+    ProcessGroupNCCL(
         timeout=timedelta(seconds=10),
     )
     if torch.cuda.is_available()
@@ -196,6 +196,7 @@ def trace_handler(p):
         backup_device=device,
         sync_every=20 if USE_STREAMING else 20,
         fragment_sync_delay=10 if USE_STREAMING else 0,
+        should_quantize=False,
     ) as diloco:
         while True:
             for i, (inputs, labels) in enumerate(trainloader):
@@ -216,7 +217,7 @@ def trace_handler(p):
                 if manager.current_step() % 100 == 0:
                     print(f"[{manager.current_step()}] loss = {loss.item()}")

-                if manager.current_step() >= 50:
+                if manager.current_step() >= 15:
                     # complete training
                     prof.stop()
                     exit()
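For the training script, the switch to `ProcessGroupNCCL` keeps NCCL in-process rather than proxying it through the helper subprocess that `ProcessGroupBabyNCCL` spawns. A sketch of the resulting selection logic, with the CPU fallback assumed to be the Gloo group the script already used (that branch is not visible in the hunk above):

```python
# Sketch; the Gloo fallback branch is an assumption, not shown in this diff.
from datetime import timedelta

import torch
from torchft.process_group import ProcessGroupGloo, ProcessGroupNCCL

pg = (
    ProcessGroupNCCL(timeout=timedelta(seconds=10))
    if torch.cuda.is_available()
    else ProcessGroupGloo()
)
```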