Added proactive heartbeat timeout failure propagation (#164) (#188) #196


Open
wants to merge 1 commit into base: main

Conversation

WarrenZhu050413
Contributor

Overview

This PR improves torchFT's failure detection speed through proactive failure recovery. The Manager now listens for Lighthouse failure notifications and aborts hanging collectives immediately instead of waiting for NCCL/Gloo timeouts.

Basic demonstration

You can experiment with the proactive failure recovery mode by setting:

export TORCHFT_PROACTIVE_RECOVERY=1

With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break from a hanging allreduce.
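
Proactive recovery can also be enabled programmatically via the Manager constructor (see Implementation Details below). Here is a minimal sketch, assuming the keyword argument is named proactive_recovery; the other constructor arguments shown are illustrative placeholders, not a prescription:

from torchft.manager import Manager
from torchft.process_group import ProcessGroupGloo

# Illustrative only: enable proactive recovery without the environment variable.
manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=lambda state_dict: None,  # placeholder callbacks
    state_dict=lambda: {},
    min_replica_size=1,
    proactive_recovery=True,  # assumed keyword, equivalent to TORCHFT_PROACTIVE_RECOVERY=1
)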

You can test this out by running train_ddp_proactive.py

On shell 1 (one replica group starts initial training):

export REPLICA_GROUP_ID=0
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1

CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py

On shell 2 (a second replica group joins):

export REPLICA_GROUP_ID=1
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1

CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py

You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is run after setting export TORCHFT_PROACTIVE_RECOVERY=0, you should observe that the process with replica group id 0 hangs for dozens of seconds before continuing.

INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Setting error processor thread stop event
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Waiting for error processor thread to complete
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Error processor thread shutdown completed.
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Setting failure listener stop event for process
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Waiting for failure listener process to complete
INFO:torchft.manager:[train_ddp_0:81a52ce4-d803-4f22-a0c3-54f3b4a88c89/0 - step 10] Failure listener process shutdown completed

And in the Lighthouse you will observe:

2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Replica train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b timed out (last heartbeat: Instant { tv_sec: 5200692, tv_nsec: 955240591 }), sending failure notification.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Removed replica train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b from heartbeats and participants due to timeout.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - New failure detected, resetting all participants for quorum formation.
2025-05-20T22:29:30.029 [INFO] [torchft::lighthouse] - Healthy replicas received failure notification for train_ddp_1:a581dae2-1ebc-4f93-b882-6477832fef6b

Implementation Details

The proactive failure recovery mechanism involves changes in both the Rust backend and the Python Manager:

Rust:

  • src/lighthouse.rs:
    • The Lighthouse server now includes a failure_channel (a Tokio broadcast channel).
    • When _failure_tick detects a timed-out replica, it broadcasts a FailureNotification on this channel.
    • A new gRPC method, subscribe_failures, is added to LighthouseService. Clients can call this to receive a stream of FailureNotifications.
    • The inject_failure method has been added to the LighthouseServer (Python-exposed) and Lighthouse (Rust struct) to facilitate testing by manually triggering failure notifications.
  • src/lib.rs:
    • A FailureStream class is introduced, wrapping the tonic::Streaming<ProtoFailureNotification>. Its __next__ method allows Python to iterate over failure notifications. This method uses py.allow_threads around a blocking runtime.block_on(fut) call to fetch the next notification, allowing the GIL to be released.
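
To show how this stream looks from the Python side, here is a minimal consumption sketch. The import path, the connect_timeout and timeout keywords, and the notification's replica_id field are assumptions for illustration, not the PR's exact API:

from datetime import timedelta
from torchft._torchft import LighthouseClient  # bindings import path may differ

client = LighthouseClient("http://localhost:29510", connect_timeout=timedelta(seconds=5))
stream = client.subscribe_failures(timeout=timedelta(seconds=5))  # parameter name assumed
for note in stream:  # each __next__ blocks in Rust with the GIL released
    print(f"failure notification for replica {note.replica_id}")  # field name assumed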

Python (Manager):

  • torchft/manager.py:
    • When proactive_recovery is enabled (via constructor argument or TORCHFT_PROACTIVE_RECOVERY=1 environment variable), the Manager spawns a separate daemon process (_failure_listener_process_main).
    • Subprocess-based subscription: This process creates a LighthouseClient and calls subscribe_failures, then iterates over the received failure notifications.
    • Inter-Process Communication (IPC): _ManagedPipe is used by the listener process to send errors received from the Lighthouse (via the stream returned by subscribe_failures) back to the main Manager process. This mimics the IPC implementation in BabyProcessGroup.
    • Error Listening: A new thread within the main Manager process continuously polls the _error_pipe.
    • Error Response: If an exception is received, it calls self.report_error() and aborts the underlying process group (self._pg.abort()).
    • Error Reporting: self.report_error() is now also used to flag the manager as errored when a proactive failure is detected.
    • Shutdown: Manager.shutdown() is enhanced to gracefully stop the _error_processor_thread and the _failure_listener_process.
    • The subscribe_timeout parameter for subscribe_failures in _failure_listener_process_main allows the listener process to be interruptible for clean shutdown.
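
Taken together, a simplified sketch of this listener-process pattern follows. This is illustrative rather than the PR's exact code: helper names such as _listener_main and start_failure_listener, the LighthouseClient constructor arguments, and the subscribe_failures timeout keyword are stand-ins:

import multiprocessing as mp
import threading
from datetime import timedelta

def _listener_main(lighthouse_addr, error_pipe, stop_event):
    # Runs in a separate daemon process so the blocking stream wait cannot stall the Manager.
    from torchft._torchft import LighthouseClient  # bindings import path may differ
    client = LighthouseClient(lighthouse_addr, connect_timeout=timedelta(seconds=5))
    while not stop_event.is_set():
        try:
            # A short subscribe timeout keeps the loop interruptible for clean shutdown.
            for note in client.subscribe_failures(timeout=timedelta(seconds=5)):
                error_pipe.send(RuntimeError(f"failure reported by Lighthouse: {note}"))
        except Exception:
            continue  # e.g. stream timeout; re-subscribe and keep listening

def start_failure_listener(manager, lighthouse_addr):
    parent_conn, child_conn = mp.Pipe()
    stop_event = mp.Event()
    listener = mp.Process(
        target=_listener_main, args=(lighthouse_addr, child_conn, stop_event), daemon=True
    )
    listener.start()

    def _error_processor():
        # Main-process thread: poll the pipe and surface failures to the Manager.
        while not stop_event.is_set():
            if parent_conn.poll(0.1):
                exc = parent_conn.recv()
                manager.report_error(exc)  # flag the manager as errored
                manager._pg.abort()        # break out of the hanging collective

    threading.Thread(target=_error_processor, daemon=True).start()
    return listener, stop_event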

Design Rationale

I decided to use a separate process to subscribe to the failure notifications because waiting on the failure stream is a blocking call. Because of the GIL, waiting in a Python thread would block the main thread from functioning.

As I was implementing it, I considered three ways to implement this:

  1. GIL Release in Rust Stream Iteration: Decouple the Python logic from the tokio streaming logic so that the GIL can be released in lib.rs.
  2. Asyncio: Use pyo3-asyncio to create an async iterator from tokio-stream.
  3. Multiprocessing: Use a separate process to subscribe to the failure notification.

Approaches 1 and 2 are more elegant and should be more efficient, as they do not involve spawning a separate process. However, I am limited by my understanding of Rust and was unable to implement them.

Tests

I introduced the following tests:

  • Rust:
    • src/lighthouse.rs:
      • test_subscribe_failures_delivers_notifications: Verifies that inject_failure correctly sends a notification that is received by a subscriber.
      • test_failure_tick_single_notification_and_cleanup: Ensures _failure_tick correctly identifies timeouts, broadcasts notifications once, and cleans up state.
  • Python:
    • torchft/lighthouse_test.py:
      • test_subscribe_failures_notification: Python-level test ensuring LighthouseClient.subscribe_failures receives notifications triggered by LighthouseServer.inject_failure.
      • test_inject_failure: Confirms that server.inject_failure() leads to a notification being received by client.subscribe_failures().
    • torchft/manager_test.py:
      • test_manager_error_handler: Tests that the Manager processes exceptions passed to its internal error handler.
      • test_direct_error_pipe: Verifies that an exception sent directly via the IPC pipe is correctly picked up by the Manager.
      • test_manager_failure_e2e: An end-to-end test where LighthouseServer.inject_failure triggers a notification that propagates through the listener process, IPC pipe, and results in the Manager capturing the error.
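
To make the notification-level tests concrete, here is a minimal sketch in the spirit of test_inject_failure. The constructor arguments, the address() and shutdown() calls, and the string form of the notification are assumptions, not the actual test code:

from datetime import timedelta
from torchft._torchft import LighthouseServer, LighthouseClient  # bindings import path may differ

def test_inject_failure_sketch() -> None:
    server = LighthouseServer(bind="[::]:0", min_replicas=1)  # arguments illustrative
    client = LighthouseClient(server.address(), connect_timeout=timedelta(seconds=5))
    stream = client.subscribe_failures(timeout=timedelta(seconds=5))
    server.inject_failure("replica_0")  # manually trigger a failure notification
    note = next(stream)                 # first notification from the stream
    assert "replica_0" in str(note)
    server.shutdown()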

Linter

I am still getting the following error after running lintrunner -a, but I couldn’t debug it:


  Advice (pyre) command-failed
    Failed due to JSONDecodeError:
    Expecting value: line 1 column 1 (char 0)
Successfully applied all patches.

Other minor changes

Note: In order to test the code using train_ddp.py, I fixed an error introduced by commit 652a009 and changed the API of DistributedSampler to use replica_group_id.

@facebook-github-bot added the CLA Signed label (managed by the Meta Open Source bot) on May 20, 2025
@WarrenZhu050413 force-pushed the main branch 2 times, most recently from ebb3953 to 2a7bac7 on May 21, 2025 01:31
Member

@d4l3k left a comment

Looks pretty good! I need to do another fine-grained pass in case I missed something + small style stuff

fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
    slf
}
fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
Member

nit: just use self

Contributor Author

PyRef<'_, FailureStream> cannot be used as the type of self without the arbitrary_self_types feature
see issue #44874 rust-lang/rust#44874 for more information
consider changing to self, &self, &mut self, or a type implementing Receiver such as self: Box<Self>, self: Rc<Self>, or self: Arc<Self>

I get this error when I use self.

default_value = "1000",
help = "How frequently to check for failures."
)]
pub failure_tick_ms: u64,
Member

any reason to use separate failure_tick_ms instead of just the quorum_tick?

Contributor Author

I was thinking of exposing this as a knob, but I don't have any particular reason. Should we only have a single tick_ms variable that is shared between quorum_tick and failure_tick?


while not stop_event.is_set():
    try:
        lighthouse_client = LighthouseClient(
Member

I've generally avoided hitting lighthouse directly from each worker and instead use a call tree through the Manager. It's probably for the best to do the same here instead of requiring all workers to hit the Lighthouse

Contributor Author

To confirm my understanding:

  • Each ManagerClient would send a heartbeat to Lighthouse via its ManagerServer.
  • If any client heartbeat times out, the ManagerServer stops forwarding heartbeats.
  • Lighthouse then streams a failure notification back to the ManagerServer, which streams it to clients.

I prototyped this design at first, but in practice it duplicated the streaming logic—once in ManagerClient and again in ManagerServer—and the two paths diverged enough that the code was no longer reusable.

Instead, I adopted the simpler model where each worker talks directly to LighthouseServer, mainly because this would be easier to maintain and reason about.

Another consideration that I had is that this makes torchFT easier to extend beyond HSDP-specific scenarios (e.g., dynamically reconfiguring a pipeline after a failure in other deployment environments).

On the other hand, I agree that having the workers directly hitting the Lighthouse goes against the soft invariants maintained by the code. I would be glad to integrate the call-tree approach if you feel it adds value.

@WarrenZhu050413 force-pushed the main branch 4 times, most recently from 17f44f4 to daa3adf on May 21, 2025 03:12