fix infinite recovery #217

Merged
1 commit merged on Jun 19, 2025
62 changes: 60 additions & 2 deletions torchft/local_sgd.py
@@ -356,14 +356,47 @@ def perform_sync(self) -> bool:
Overrides the sync method to wait for the scheduled allreduce to finish and
steps using the outer optimizer.
"""
if len(self._allreduce_futures) == 0:
return True
# Waiting for an allreduce before it has been sent is currently not supported.
# Make sure not to do this, to avoid running into inconsistencies.
#
# This can happen when using large values of `fragment_sync_delay`:
# the node might not have participated in syncing this fragment.
#
# The allreduce for the other nodes that did participate might
# actually succeed, and in that case we shouldn't allow recovery
# from this node.
#
# We do need to increase `max_step` here so we don't end up in an
# infinite loop of needing to recover, but we can't let other nodes
# recover from this node because it doesn't have the latest state.
#
# We could add an `is_catching_up` flag to the state_dict to
# disallow recoveries from this node. Such nodes can be excluded
# from the `max_step` calculation unless all nodes are catching up.
# This approach makes the replica state of the global parameters
# diverge, though. So we could instead add recovery of a particular
# fragment from a peer node as part of `should_commit` or the next
# `quorum` when a node is catching up.
assert len(self._allreduce_futures) > 0

self.wait()

# Restore the parameters back to the previous state
self.restore_parameters()

# For large values of `fragment_sync_delay`, this call can be
# a problem.
#
# It can return success even if the allreduce failed, because
# the process group could have been reconfigured while the
# allreduce was in flight. The in-flight allreduce may or may
# not have been aborted.
#
# We could track errors per allreduce to let the commit fail
# here, but that has the downside of reconfiguring the pg too
# many times, resulting in more aborts and more commit failures.
should_commit = self._manager.should_commit()

if should_commit:
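The `is_catching_up` idea mentioned in the comment inside `perform_sync` above is only sketched in prose in this PR. Below is a minimal sketch of how it could work, assuming a hypothetical `replica_states` list of per-replica state dicts; none of these names exist in torchft.

# Hypothetical sketch, not part of this PR or the torchft API.
# A replica that is catching up is excluded from the `max_step`
# calculation so other nodes never try to recover from it, unless
# every replica is catching up.
def compute_max_step(replica_states: list[dict]) -> int:
    eligible = [s for s in replica_states if not s.get("is_catching_up", False)]
    if not eligible:
        # If every replica is catching up, fall back to all of them so
        # the group can still make progress.
        eligible = replica_states
    return max(s["step"] for s in eligible)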
@@ -575,6 +608,13 @@ def __init__(
for i, model_fragment in enumerate(model_fragments)
]

# This is to make sure we adhere to the assumptions made by
# `_StreamingDiLoCoFragment` about the fragment sync schedule.
assert fragment_sync_delay < sync_every // len(model_fragments)

# Used to ensure that we only try to sync a fragment after we've sent a prepare for it.
self._first_prepare_sent: set[int] = set()

# Need to copy the parameters to the host to be safe if we are on the first step.
self._save_parameters()
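A quick worked check of the schedule assumption enforced by the assert above, using the values from train_diloco.py further down; the fragment count of 2 is an assumption for illustration.

# Illustrative arithmetic only; num_fragments = 2 is assumed, not taken
# from the PR.
sync_every = 20
num_fragments = 2
fragment_sync_delay = 5  # streaming value train_diloco.py uses after this PR

# Each fragment gets a window of sync_every // num_fragments = 10 steps,
# so its sync (prepare + fragment_sync_delay) has to land inside that
# window: 5 < 10 holds, while the old value of 10 would have tripped the
# assert in __init__.
assert fragment_sync_delay < sync_every // num_fragments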

Expand Down Expand Up @@ -618,6 +658,8 @@ def _wait(self) -> None:
for fragment in self._fragments:
fragment.wait()

self._first_prepare_sent.clear()

def _quorum_loop(self) -> None:
"""
Performs infinite retries until quorum is successful
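The body of `_quorum_loop` is collapsed in this diff view. For context, the pattern the docstring describes is roughly the sketch below; this is illustrative only, assumes the manager exposes `start_quorum()`, and may differ from the actual torchft implementation.

# Sketch of a retry-until-quorum loop; the real body is hidden in this
# diff and may differ.
def _quorum_loop(self) -> None:
    while True:
        try:
            self._manager.start_quorum()
            return
        except Exception:
            logger.exception("quorum failed, retrying")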
Expand Down Expand Up @@ -660,12 +702,18 @@ def _step_post_hook(

logger.debug(f"preparing fragment {i} at step {step}")

self._first_prepare_sent.add(i)
fragment.prepare_sync()

for i, fragment in enumerate(self._fragments):
if not fragment.should_sync_fragment(step):
continue

# We need to have sent an allreduce before we can sync
# a fragment.
if i not in self._first_prepare_sent:
Member:
Can you add a comment explaining when we would get into this case? It looks like we set it in the loop before, so I'm confused.

Contributor Author:
The loop before could run for a different fragment (not the one we're syncing) depending on the sync schedule.
continue

logger.debug(f"syncing fragment {i} at step {step}")

if not fragment.perform_sync():
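To make the review exchange above concrete, here is a toy walk-through of how the prepare loop and the sync loop can land on different fragments at the same step, which is when the `i not in self._first_prepare_sent` guard fires. The schedule below is an assumption for illustration (two fragments, sync_every=20, fragment_sync_delay=5), not the actual torchft schedule.

# Toy illustration only: prepare for fragment i fires at step % 20 == 10 * i,
# and its sync fires fragment_sync_delay = 5 steps later.
def should_prepare(step: int, i: int) -> bool:
    return step % 20 == 10 * i

def should_sync(step: int, i: int) -> bool:
    return step % 20 == 10 * i + 5

first_prepare_sent: set[int] = set()
for step in range(1, 21):
    for i in range(2):
        if should_prepare(step, i):
            first_prepare_sent.add(i)
    for i in range(2):
        if should_sync(step, i) and i not in first_prepare_sent:
            # e.g. right after the set was cleared: fragment 0's sync at
            # step 5 arrives before any prepare was sent for it.
            print(f"step {step}: skipping sync of fragment {i}")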
@@ -708,6 +756,16 @@ def _step_post_hook(
# waste after recovery
self._quorum_loop()

# TODO: Since we do quorum after commit, there might be a big gap until
# the next allreduce. This increases the chances of nodes failing,
# and therefore of the allreduce failing.
# - We could maybe do a quorum again right before preparing a fragment,
# using `shrink_only`. This might make it tricky for new nodes to join,
# though.
# - Maintain a sequence number in the state dict that gets bumped at every
# quorum call. Then we can do a quorum right before the allreduce and
# avoid doing quorums after commit.

# We need to make sure `_local_step` is still
# the same across all replicas if `quorum_id` changed.
#
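The second bullet of the TODO above is only described in prose. Below is a minimal sketch of what a quorum sequence number in the state dict could look like; all of these names are hypothetical and do not exist in torchft.

# Hypothetical sketch of the "sequence number in the state dict" idea.
class QuorumSequence:
    def __init__(self) -> None:
        self.seq = 0

    def on_quorum(self) -> None:
        # Bumped every time a quorum call completes.
        self.seq += 1

    def state_dict(self) -> dict:
        return {"quorum_seq": self.seq}

    def matches(self, peer_state: dict) -> bool:
        # Checked right before an allreduce: if a peer saw a different
        # quorum, re-run quorum now instead of discovering the mismatch
        # only after commit.
        return peer_state.get("quorum_seq") == self.seq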
2 changes: 1 addition & 1 deletion train_diloco.py
@@ -201,7 +201,7 @@ def trace_handler(p):
outer_optimizer,
backup_device=device,
sync_every=20 if USE_STREAMING else 20,
fragment_sync_delay=10 if USE_STREAMING else 0,
fragment_sync_delay=5 if USE_STREAMING else 0,
should_quantize=False,
) as diloco:
while True: