diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index d8a944f..201d372 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -356,14 +356,47 @@ def perform_sync(self) -> bool:
         Overrides the sync method to wait for the scheduled allreduce to finish and
         steps using the outer optimizer.
         """
-        if len(self._allreduce_futures) == 0:
-            return True
+        # Waiting for an allreduce before it has been sent is currently not supported.
+        # Make sure not to do this, to avoid running into inconsistencies.
+        #
+        # This can happen when using large values of `fragment_sync_delay`.
+        # The node might not have participated in syncing of this fragment.
+        #
+        # The allreduce for other nodes that did participate might
+        # actually succeed, and in that case we shouldn't allow
+        # recovery from this node.
+        #
+        # We do need to increase the `max_step` here so we
+        # don't end up in an infinite loop of needing to recover,
+        # but we can't let other nodes recover from this node
+        # because it doesn't have the latest state.
+        #
+        # We can add an `is_catching_up` flag to the state_dict
+        # to disallow recoveries from this node. Such nodes can
+        # be excluded from the `max_step` calculation unless all
+        # nodes are catching up. This approach makes the replica state
+        # of global parameters diverge, though. So we could add recovery
+        # for a particular fragment from a peer node as a part of
+        # `should_commit` or the next `quorum` when a node is catching up.
+        assert len(self._allreduce_futures) > 0
 
         self.wait()
 
         # Restore the parameters back to the previous state
         self.restore_parameters()
 
+        # For large values of `fragment_sync_delay`, this call can be
+        # a problem.
+        #
+        # It can return success even if the allreduce failed, because
+        # the process group could have been reconfigured while the
+        # allreduce was in flight. The in-flight allreduce may or may
+        # not have been aborted.
+        #
+        # We can track errors per allreduce to
+        # let the commit fail here. But this has the downside of
+        # reconfiguring the pg too many times, resulting in
+        # more aborts and more commit failures.
         should_commit = self._manager.should_commit()
 
         if should_commit:
@@ -575,6 +608,13 @@ def __init__(
             for i, model_fragment in enumerate(model_fragments)
         ]
 
+        # This is to make sure we adhere to the assumptions made by
+        # `_StreamingDiLoCoFragment` about the fragment sync schedule.
+        assert fragment_sync_delay < sync_every // len(model_fragments)
+
+        # Used to ensure that we only try to sync a fragment after we've sent a prepare for it
+        self._first_prepare_sent: set[int] = set()
+
         # Need to copy the parameters to the host to be safe if we are on the first step.
         self._save_parameters()
 
@@ -618,6 +658,8 @@ def _wait(self) -> None:
         for fragment in self._fragments:
             fragment.wait()
 
+        self._first_prepare_sent.clear()
+
     def _quorum_loop(self) -> None:
         """
         Performs infinite retries until quorum is successfull
@@ -660,12 +702,18 @@ def _step_post_hook(
 
             logger.debug(f"preparing fragment {i} at step {step}")
 
+            self._first_prepare_sent.add(i)
             fragment.prepare_sync()
 
         for i, fragment in enumerate(self._fragments):
             if not fragment.should_sync_fragment(step):
                 continue
 
+            # We need to have sent an allreduce before we can sync
+            # a fragment.
+            if i not in self._first_prepare_sent:
+                continue
+
             logger.debug(f"syncing fragment {i} at step {step}")
 
             if not fragment.perform_sync():
@@ -708,6 +756,16 @@ def _step_post_hook(
         # waste after recovery
         self._quorum_loop()
 
+        # TODO: Since we do quorum after commit, there might be a big gap until
+        # the next allreduce. This increases the chances of nodes failing
+        # and thus of the allreduce failing.
+        # - We could maybe do a quorum again right before preparing for a fragment
+        #   using `shrink_only`. This might make it tricky for new nodes to join,
+        #   though.
+        # - Maintain a sequence number in the state dict that gets bumped at every
+        #   quorum call. Then we can do a quorum right before the allreduce and avoid
+        #   doing quorums after commit.
+
         # We need to set make sure `_local_step` is still
         # the same across all replicas if `quorum_id` changed.
         #
diff --git a/train_diloco.py b/train_diloco.py
index c221558..1de4ade 100644
--- a/train_diloco.py
+++ b/train_diloco.py
@@ -201,7 +201,7 @@ def trace_handler(p):
                 outer_optimizer,
                 backup_device=device,
                 sync_every=20 if USE_STREAMING else 20,
-                fragment_sync_delay=10 if USE_STREAMING else 0,
+                fragment_sync_delay=5 if USE_STREAMING else 0,
                 should_quantize=False,
             ) as diloco:
                 while True:
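To illustrate the new `fragment_sync_delay < sync_every // len(model_fragments)` assertion, here is a minimal standalone sketch (not torchft code). It assumes fragments are prepared at evenly staggered steps within each `sync_every` window and synced `fragment_sync_delay` steps after their prepare; the `schedule` helper and the two-fragment example are hypothetical, since the actual fragment count and stagger logic are not part of this diff.

```python
# Minimal sketch, not torchft code. Assumes each of `num_fragments`
# fragments is prepared once per `sync_every` window, staggered by
# `sync_every // num_fragments` steps, and synced `fragment_sync_delay`
# steps after its prepare.
def schedule(sync_every: int, num_fragments: int, fragment_sync_delay: int):
    stride = sync_every // num_fragments
    # Mirrors the new assertion: if the delay reaches the stride, a
    # fragment's sync would collide with the next fragment's prepare and
    # the last sync would spill past the `sync_every` boundary.
    assert fragment_sync_delay < stride
    for i in range(num_fragments):
        prepare_step = (i + 1) * stride
        sync_step = prepare_step + fragment_sync_delay
        yield i, prepare_step, sync_step


if __name__ == "__main__":
    # With two fragments (a hypothetical count) and sync_every=20, the
    # stride is 10, so the old fragment_sync_delay=10 in train_diloco.py
    # would have tripped the assertion; the new value of 5 passes.
    for i, prepare_step, sync_step in schedule(20, 2, 5):
        print(f"fragment {i}: prepare at step {prepare_step}, sync at step {sync_step}")
```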