Added proactive heartbeat timeout failure propagation (#164) (#188) #196
Conversation
Force-pushed from ebb3953 to 2a7bac7
Looks pretty good! I need to do another fine-grained pass in case I missed something, plus small style stuff.
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
    slf
}

fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
nit: just use self
I get this error when I use self:

PyRef<'_, FailureStream> cannot be used as the type of self without the arbitrary_self_types feature
(see issue #44874, rust-lang/rust#44874, for more information)
consider changing to self, &self, &mut self, or a type implementing Receiver, such as self: Box<Self>, self: Rc<Self>, or self: Arc<Self>
default_value = "1000", | ||
help = "How frequently to check for failures." | ||
)] | ||
pub failure_tick_ms: u64, |
Any reason to use a separate failure_tick_ms instead of just the quorum_tick?
I was thinking of exposing this as a knob, but I don't have a particular reason. Should we have only a single tick_ms variable that is shared between quorum_tick and failure_tick?
while not stop_event.is_set():
    try:
        lighthouse_client = LighthouseClient(
I've generally avoided hitting the Lighthouse directly from each worker and instead use a call tree through the Manager. It's probably for the best to do the same here instead of requiring all workers to hit the Lighthouse.
To confirm my understanding:
- Each ManagerClient would send a heartbeat to Lighthouse via its ManagerServer.
- If any client heartbeat times out, the ManagerServer stops forwarding heartbeats.
- Lighthouse then streams a failure notification back to the ManagerServer, which streams it to clients.
I prototyped this design at first, but in practice it duplicated the streaming logic—once in ManagerClient and again in ManagerServer—and the two paths diverged enough that the code was no longer reusable.
Instead, I adopted the simpler model where each worker talks directly to LighthouseServer, mainly because this would be easier to maintain and reason about.
Another consideration is that this makes torchFT easier to extend beyond HSDP-specific scenarios (e.g., dynamically reconfiguring a pipeline after a failure in other deployment environments).
On the other hand, I agree that having the workers hit the Lighthouse directly goes against the soft invariants maintained by the code. I would be glad to integrate the call-tree approach if you feel it adds value.
Force-pushed from 17f44f4 to daa3adf
Force-pushed from 3641dca to 7b550aa
Overview
This PR improves the failure detection speed of torchFT through proactive failure recovery. The Manager now listens for Lighthouse failure notifications and aborts hanging collectives immediately instead of waiting for NCCL/Gloo timeouts.
Basic demonstration
You can experiment with proactive failure recovery mode by:
export TORCHFT_PROACTIVE_RECOVERY=1
With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break from a hanging allreduce.
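The same toggle can also be set from Python before the Manager is constructed. A minimal sketch follows; the proactive_recovery constructor argument comes from the implementation notes below, and every other argument shown is an illustrative placeholder rather than the verified API:

```python
# Minimal sketch: two ways to enable proactive recovery, per this PR's description.
# Only the proactive_recovery argument name is taken from the PR text; the rest of
# the Manager constructor call is an illustrative placeholder.
import os

# Option 1: environment variable, read by the Manager when it is constructed.
os.environ["TORCHFT_PROACTIVE_RECOVERY"] = "1"

# Option 2: pass the flag explicitly (sketch only).
# manager = Manager(pg=..., load_state_dict=..., state_dict=..., proactive_recovery=True)
```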
You can test this out by running train_ddp_proactive.py.
On shell 1 (one replica group starts initial training):
On shell 2 (a second replica group joins):
You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is run after setting
export TORCHFT_PROACTIVE_RECOVERY=0
you should observe that the process with replica group id 1 hangs for dozens of seconds before continuing. And in the Lighthouse you will observe:
Implementation
Implementation Details
The proactive failure recovery mechanism involves changes in both the Rust backend and the Python Manager:

Rust:
- src/lighthouse.rs:
  - The Lighthouse server now includes a failure_channel (a Tokio broadcast channel).
  - When _failure_tick detects a timed-out replica, it broadcasts a FailureNotification on this channel.
  - A new RPC, subscribe_failures, is added to LighthouseService. Clients can call this to receive a stream of FailureNotifications.
  - An inject_failure method has been added to the LighthouseServer (Python-exposed) and Lighthouse (Rust struct) to facilitate testing by manually triggering failure notifications.
- src/lib.rs:
  - A FailureStream class is introduced, wrapping the tonic::Streaming<ProtoFailureNotification>. Its __next__ method allows Python to iterate over failure notifications. This method uses py.allow_threads around a blocking runtime.block_on(fut) call to fetch the next notification, allowing the GIL to be released.

Python (Manager):
- torchft/manager.py:
  - When proactive_recovery is enabled (via constructor argument or the TORCHFT_PROACTIVE_RECOVERY=1 environment variable), the Manager spawns a separate daemon process (_failure_listener_process_main).
  - Subprocess-based subscription: this process creates a LighthouseClient and calls subscribe_failures, then iterates over the received failure notifications.
  - Inter-process communication (IPC): a _ManagedPipe is used by the listener process to send errors it receives from the Lighthouse (through the stream returned by subscribe_failures) back to the main Manager process. This mimics the IPC implementation in BabyProcessGroup (a minimal sketch of the pattern follows this list).
  - The main Manager process continuously polls the _error_pipe. When an error arrives, it calls self.report_error() and aborts the underlying process group (self._pg.abort()).
  - self.report_error() is now also used to flag the manager as errored when a proactive failure is detected.
  - Manager.shutdown() is enhanced to gracefully stop the _error_processor_thread and the _failure_listener_process.
  - A subscribe_timeout parameter for subscribe_failures in _failure_listener_process_main allows the listener process to be interruptible for clean shutdown.
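To make the subprocess-plus-pipe flow concrete, here is a minimal, self-contained sketch of the pattern using only the Python standard library. The names (listener_main, poll_for_failure) and the loop structure are illustrative stand-ins for _failure_listener_process_main, _ManagedPipe, and the Lighthouse failure stream, not the actual torchft code:

```python
# Illustrative sketch of the listener-process + pipe pattern described above.
# Names and payloads are placeholders, not the actual torchft implementation.
import multiprocessing as mp
import time


def poll_for_failure(timeout: float):
    """Hypothetical stand-in for blocking on the Lighthouse failure stream
    (LighthouseClient.subscribe_failures) with a subscribe_timeout."""
    time.sleep(timeout)
    return None  # a real implementation would return a FailureNotification or None


def listener_main(pipe, stop_event) -> None:
    """Stand-in for _failure_listener_process_main: forward any notification
    received from the stream to the parent process over the pipe."""
    while not stop_event.is_set():
        notification = poll_for_failure(timeout=1.0)
        if notification is not None:
            pipe.send(notification)


if __name__ == "__main__":
    parent_end, child_end = mp.Pipe()
    stop_event = mp.Event()
    proc = mp.Process(target=listener_main, args=(child_end, stop_event), daemon=True)
    proc.start()

    # Main process: poll the pipe without blocking, mirroring how the Manager
    # polls its _error_pipe and, on a hit, calls report_error() and aborts the
    # process group (self._pg.abort()) per the description above.
    for _ in range(3):
        if parent_end.poll(0.1):
            error = parent_end.recv()
            print("failure notification received:", error)
        time.sleep(0.5)

    # Clean shutdown, analogous to Manager.shutdown() stopping the listener.
    stop_event.set()
    proc.join(timeout=2.0)
```

Using a daemon process keeps the blocking stream wait off the main interpreter entirely, so a hung or slow stream can never stall the training loop; the pipe poll in the parent is cheap and non-blocking.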
Design Rationale
I decided to use a separate process to subscribe to the failure notifications because waiting on the failure stream is a blocking call. Because of the GIL, if one waits using a Python thread, it will block the main thread from functioning.
As I was implementing it, I considered three ways to implement this:
1. Handling the blocking wait on the stream inside lib.rs.
2. Using pyo3-asyncio to create an async iterator from tokio-stream.
3. Subscribing from a separate process (the approach used here).
Approaches 1 and 2 are more elegant and should be more efficient, as they do not involve spawning a separate process. However, I am limited by my Rust language understanding and was unable to implement them.
Tests
I introduced the following tests:

- src/lighthouse.rs:
  - test_subscribe_failures_delivers_notifications: verifies that inject_failure correctly sends a notification that is received by a subscriber.
  - test_failure_tick_single_notification_and_cleanup: ensures _failure_tick correctly identifies timeouts, broadcasts notifications once, and cleans up state.
- torchft/lighthouse_test.py:
  - test_subscribe_failures_notification: Python-level test ensuring LighthouseClient.subscribe_failures receives notifications triggered by LighthouseServer.inject_failure (a rough sketch of this test shape follows the list).
  - test_inject_failure: confirms that server.inject_failure() leads to a notification being received by client.subscribe_failures().
- torchft/manager_test.py:
  - test_manager_error_handler: tests that the Manager processes exceptions passed to its internal error handler.
  - test_direct_error_pipe: verifies that an exception sent directly via the IPC pipe is correctly picked up by the Manager.
  - test_manager_failure_e2e: an end-to-end test where LighthouseServer.inject_failure triggers a notification that propagates through the listener process and IPC pipe and results in the Manager capturing the error.
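As a rough illustration of the Python-level Lighthouse test shape: the import path, constructor arguments, field names, and the subscribe_failures/inject_failure signatures below are assumptions rather than the verified API; the real tests live in torchft/lighthouse_test.py.

```python
# Rough sketch only: import path, constructor arguments, field names, and the
# subscribe_failures/inject_failure signatures are assumptions, not the real API.
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer  # path assumed


def test_subscribe_failures_notification_sketch() -> None:
    server = LighthouseServer(bind="[::]:0", min_replicas=1)  # arguments assumed
    try:
        client = LighthouseClient(server.address(), timedelta(seconds=5))
        # Subscribe first so the injected failure is not missed; a timeout keeps
        # the iterator interruptible (mirroring subscribe_timeout in the PR).
        stream = client.subscribe_failures(timeout=timedelta(seconds=5))  # kwarg assumed
        server.inject_failure("replica_0")  # manually trigger a FailureNotification
        note = next(stream)  # FailureStream implements __iter__/__next__
        assert note.replica_id == "replica_0"  # field name assumed
    finally:
        server.shutdown()
```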
Linter
I am still getting the following error after running lintrunner -a, but I couldn't debug it:

Other minor changes
Note: in order to test the code using train_ddp.py, I fixed an error introduced by commit 652a009 and changed the API of DistributedSampler to use replica_group_id.