diff --git a/Cargo.toml b/Cargo.toml index 0c6ae6e..8968731 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,12 +16,16 @@ prost = "0.13.3" prost-types = "0.13.3" pyo3 = {version = "0.24", features = ["extension-module"]} rand = "0.8.5" +serde = { version = "1.0", features = ["derive"] } +serde_json = "1.0" slog = "2.7.0" slog-stdlog = "4.1.1" stderrlog = "0.6.0" structopt = "0.3.26" tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] } +tokio-stream = {version = "0.1.14", features = ["sync"]} tonic = "0.12.2" +futures-core = "0.3" [build-dependencies] tonic-build = "0.12.2" diff --git a/README.md b/README.md index cb07b47..ff3cfea 100644 --- a/README.md +++ b/README.md @@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast By observing the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery. +### Proactive Failure Recovery Mode (Experimental) + +You can experiment with proactive failure recovery mode by setting: + +```sh +export TORCHFT_PROACTIVE_RECOVERY=1 +``` + +With this enabled, the manager will listen to the Lighthouse server for heartbeat failures of other replica groups and break out of a hanging allreduce. + +You can test this out by running `train_ddp_proactive.py`. + +On shell 1 (one replica group starts initial training): +```sh +export REPLICA_GROUP_ID=0 +export NUM_REPLICA_GROUPS=2 +export TORCHFT_PROACTIVE_RECOVERY=1 + +CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py +``` + +On shell 2 (a second replica group joins): +```sh +export REPLICA_GROUP_ID=1 +export NUM_REPLICA_GROUPS=2 +export TORCHFT_PROACTIVE_RECOVERY=1 + +CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py +``` + +You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is run after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, the process with replica group id 1 will instead hang for dozens of seconds before continuing.
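The shell workflow above maps onto a small Python surface added in this diff. As an illustration only (not part of the patch), and assuming the extension built from this PR is importable, the failure-notification primitive behind proactive recovery can be exercised in-process like this, mirroring the new lighthouse tests:

```python
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

# Start an in-process lighthouse and subscribe to its failure stream.
server = LighthouseServer(bind="[::]:0", min_replicas=1)
client = LighthouseClient(addr=server.address(), connect_timeout=timedelta(seconds=1))
stream = client.subscribe_failures(timeout=timedelta(seconds=5))

# Inject a synthetic failure (test-only helper added in this PR) and read the notification.
server.inject_failure("replica_0")
note = next(stream)
assert note.replica_id == "replica_0"

server.shutdown()
```

`inject_failure` is the testing helper introduced below for driving `subscribe_failures` without waiting for a real heartbeat timeout.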
+ ### Example Parameter Server torchft has a fault tolerant parameter server implementation built on it's diff --git a/proto/torchft.proto b/proto/torchft.proto index 7c086eb..91cc5fb 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -67,9 +67,24 @@ message LighthouseHeartbeatRequest { message LighthouseHeartbeatResponse {} +message SubscribeFailuresRequest {} + +message FailureNotification { + string replica_id = 1; + string error_message = 2; +} + +message LighthouseConfigRequest {} + +message LighthouseConfigResponse { + map<string, string> config_data = 1; +} + service LighthouseService { rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse); rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse); + rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification); + rpc GetConfig (LighthouseConfigRequest) returns (LighthouseConfigResponse); } message ManagerQuorumRequest { @@ -126,3 +141,9 @@ service ManagerService { rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); rpc Kill(KillRequest) returns (KillResponse); } + +message LighthouseClientRequest { + string replica_id = 1; +} + + diff --git a/src/lib.rs b/src/lib.rs index 32a7a37..0d9d547 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,7 +13,7 @@ mod timeout; use anyhow::Result; use atty::Stream; use core::time::Duration; -use pyo3::exceptions::{PyRuntimeError, PyTimeoutError}; +use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError}; use std::cmp; use std::env; use std::sync::Arc; @@ -21,6 +21,7 @@ use std::thread::available_parallelism; use structopt::StructOpt; use tokio::runtime::Runtime; use tokio::task::JoinHandle; +use tokio_stream::StreamExt; use tonic::transport::Channel; use tonic::Status; @@ -35,11 +36,13 @@ pub mod torchftpb { use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{ - CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest, - ManagerQuorumRequest, ShouldCommitRequest, + CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification, + LighthouseConfigRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest, + ManagerQuorumRequest, ShouldCommitRequest, SubscribeFailuresRequest, }; use pyo3::prelude::*; use pyo3::types::{PyDict, PyString}; +use pyo3::{PyRef, PyRefMut}; // Get the number of threads to use for the tokio runtime fn num_threads() -> usize { @@ -290,6 +293,45 @@ struct QuorumResult { heal: bool, } +#[pyclass(unsendable)] +struct FailureStream { + runtime: Arc<Runtime>, + stream: tonic::Streaming<ProtoFailureNotification>, + timeout: Duration, +} + +#[pymethods] +impl FailureStream { + fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> { + let runtime = slf.runtime.clone(); + let timeout = slf.timeout; + // borrow stream mutably for the whole async block + let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await }; + + match runtime.block_on(fut) { + Ok(Some(Ok(note))) => Ok(FailureNotification { + replica_id: note.replica_id, + error_message: note.error_message, + }), + Ok(Some(Err(status))) => Err(StatusError(status).into()), + Ok(None) => Err(PyStopIteration::new_err(())), + Err(_) => Err(PyTimeoutError::new_err( + "Timeout waiting for failure notification", + )), + } + } +} + +#[pyclass(get_all, set_all)] +#[derive(Clone)] +struct FailureNotification { + replica_id: String, + error_message: String, +} +
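For context on how the `__next__` branches above surface in Python: `Ok(None)` becomes `StopIteration`, a timed-out poll becomes the built-in `TimeoutError`, and a successful item is a `FailureNotification`. A minimal consumption sketch (illustrative only; it assumes a lighthouse is reachable at the README's example address):

```python
from datetime import timedelta

from torchft._torchft import LighthouseClient

client = LighthouseClient(
    addr="http://localhost:29510",  # example address from the README; adjust as needed
    connect_timeout=timedelta(seconds=1),
)
stream = client.subscribe_failures(timeout=timedelta(milliseconds=100))

try:
    note = next(stream)  # Ok(Some(Ok(note))) above
    print(note.replica_id, note.error_message)
except TimeoutError:
    # Err(_) above: no notification arrived within the per-call timeout
    pass
except StopIteration:
    # Ok(None) above: the server closed the stream
    pass
```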
#[pymethods] impl QuorumResult { #[new] @@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult { #[pyclass] struct LighthouseClient { client: LighthouseServiceClient, - runtime: Runtime, + runtime: Arc, } #[pymethods] @@ -487,11 +529,13 @@ impl LighthouseClient { #[new] fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult { py.allow_threads(move || { - let runtime = tokio::runtime::Builder::new_multi_thread() - .worker_threads(num_threads()) - .thread_name("torchft-lhclnt") - .enable_all() - .build()?; + let runtime = Arc::new( + tokio::runtime::Builder::new_multi_thread() + .worker_threads(num_threads()) + .thread_name("torchft-lhclnt") + .enable_all() + .build()?, + ); let client = runtime .block_on(manager::lighthouse_client_new(addr, connect_timeout)) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -586,6 +630,44 @@ impl LighthouseClient { Ok(()) }) } + + #[pyo3(signature = (timeout = Duration::from_secs(5)))] + fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult { + py.allow_threads(move || { + let req = tonic::Request::new(SubscribeFailuresRequest {}); + let response = self + .runtime + .block_on(self.client.clone().subscribe_failures(req)) + .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; + Ok(FailureStream { + runtime: self.runtime.clone(), + stream: response.into_inner(), + timeout: timeout, + }) + }) + } + + /// Get configuration data from the lighthouse server. + /// + /// Returns: + /// dict[str, str]: The configuration data as a dictionary. + /// + /// Args: + /// timeout (timedelta, optional): Per-RPC deadline. Default = 5 s. + #[pyo3(signature = (timeout = Duration::from_secs(5)))] + fn get_config( + &self, + py: Python<'_>, + timeout: Duration, + ) -> Result, StatusError> { + py.allow_threads(move || { + let mut req = tonic::Request::new(LighthouseConfigRequest {}); + req.set_timeout(timeout); + let response = self.runtime.block_on(self.client.clone().get_config(req))?; + let config_data = response.into_inner().config_data; + Ok(config_data) + }) + } } /// LighthouseServer is a GRPC server for the lighthouse service. @@ -601,6 +683,7 @@ impl LighthouseClient { /// join_timeout_ms (int): The timeout for joining the quorum. /// quorum_tick_ms (int): The interval at which the quorum is checked. /// heartbeat_timeout_ms (int): The timeout for heartbeats. +/// lighthouse_config (str, optional): Path to configuration file (JSON format). 
#[pyclass] struct LighthouseServer { lighthouse: Arc, @@ -610,7 +693,7 @@ struct LighthouseServer { #[pymethods] impl LighthouseServer { - #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))] + #[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None, lighthouse_config=None))] #[new] fn new( py: Python<'_>, @@ -619,10 +702,13 @@ impl LighthouseServer { join_timeout_ms: Option, quorum_tick_ms: Option, heartbeat_timeout_ms: Option, + failure_tick_ms: Option, + lighthouse_config: Option, ) -> PyResult { let join_timeout_ms = join_timeout_ms.unwrap_or(100); let quorum_tick_ms = quorum_tick_ms.unwrap_or(100); let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000); + let failure_tick_ms = failure_tick_ms.unwrap_or(1000); py.allow_threads(move || { let rt = tokio::runtime::Builder::new_multi_thread() @@ -638,6 +724,8 @@ impl LighthouseServer { join_timeout_ms: join_timeout_ms, quorum_tick_ms: quorum_tick_ms, heartbeat_timeout_ms: heartbeat_timeout_ms, + failure_tick_ms: failure_tick_ms, + lighthouse_config: lighthouse_config, })) .map_err(|e| PyRuntimeError::new_err(e.to_string()))?; @@ -663,6 +751,22 @@ impl LighthouseServer { self.handle.abort(); }) } + + /// inject_failure broadcasts a failure notification for the given replica. + /// + /// This helper is intended for testing `subscribe_failures` from Python. + #[pyo3(signature = (replica_id))] + fn inject_failure(&self, py: Python<'_>, replica_id: String) { + let lighthouse = self.lighthouse.clone(); + let runtime = &self._runtime; + py.allow_threads(move || { + let _ = runtime.block_on(async { + if let Err(e) = lighthouse.inject_failure(replica_id).await { + eprintln!("Failed to inject failure: {}", e); + } + }); + }); + } } struct StatusError(Status); @@ -750,6 +854,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?; Ok(()) diff --git a/src/lighthouse.rs b/src/lighthouse.rs index a576003..4c00513 100644 --- a/src/lighthouse.rs +++ b/src/lighthouse.rs @@ -12,6 +12,7 @@ use std::sync::Arc; use std::time::Duration; use std::time::{Instant, SystemTime}; +use crate::torchftpb::FailureNotification; use anyhow::{anyhow, Result}; use askama::Template; use axum::{ @@ -21,25 +22,38 @@ use axum::{ routing::{get, post}, Router, }; +use chrono; use gethostname::gethostname; -use log::{error, info}; +use log::{error, info, warn}; +use serde_json; +use std::fs; use structopt::StructOpt; use tokio::sync::broadcast; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::interval; +use tokio_stream::wrappers::{ + errors::BroadcastStreamRecvError as TokioStreamBroadcastStreamRecvError, BroadcastStream, +}; +use tokio_stream::StreamExt; use tonic::service::Routes; use tonic::transport::server::TcpIncoming; use tonic::transport::Server; use tonic::{Request, Response, Status}; +use futures_core::Stream; +use std::pin::Pin; + use crate::manager::manager_client_new; use crate::torchftpb::{ lighthouse_service_server::{LighthouseService, LighthouseServiceServer}, - KillRequest, LighthouseHeartbeatRequest, LighthouseHeartbeatResponse, LighthouseQuorumRequest, - LighthouseQuorumResponse, Quorum, QuorumMember, + KillRequest, LighthouseConfigRequest, LighthouseConfigResponse, LighthouseHeartbeatRequest, + LighthouseHeartbeatResponse, LighthouseQuorumRequest, 
LighthouseQuorumResponse, Quorum, + QuorumMember, SubscribeFailuresRequest, }; +use serde::Deserialize; + #[derive(Clone)] struct QuorumMemberDetails { joined: Instant, @@ -47,14 +61,31 @@ struct QuorumMemberDetails { } struct State { - channel: broadcast::Sender, + quorum_channel: broadcast::Sender, + // Tracks currently active participants in the process of forming a quorum. + // Replicas are added upon receiving a `LighthouseQuorumRequest`. + // Replicas are cleared after a quorum is successfully formed OR + // removed by `_failure_tick` if their heartbeat expires. participants: HashMap, prev_quorum: Option, quorum_id: i64, - // heartbeat information - // replica_id -> last heartbeat + // Stores the last heartbeat time for each replica ID. + // Replicas are added/updated upon receiving `LighthouseHeartbeatRequest` or `LighthouseQuorumRequest`. + // Replicas are removed by `_failure_tick` if their heartbeat expires and a failure notification is sent. heartbeats: HashMap, + + // Stores the timestamp of when a replica was first detected as failed (heartbeat expired). + // This is used to ensure only one `FailureNotification` is sent per failure event. + // Replicas are added by `_failure_tick` upon detecting a new failure. + // Replicas are removed by `_failure_tick` if a subsequent heartbeat is received (signifying recovery). + failures: HashMap, + + // Broadcast channel for sending failure notifications to subscribers. + pub failure_channel: broadcast::Sender, + + // Configuration data as serde_json::Map (loaded from config file if provided) + config_data: serde_json::Map, } pub struct Lighthouse { @@ -83,7 +114,7 @@ impl ChangeLogger { } } -#[derive(StructOpt, Debug)] +#[derive(StructOpt, Debug, Clone)] #[structopt()] pub struct LighthouseOpt { // bind is the address to bind the server to. @@ -120,6 +151,19 @@ pub struct LighthouseOpt { help = "How long to wait for a heartbeat before considering a replica dead." )] pub heartbeat_timeout_ms: u64, + + #[structopt( + long = "failure_tick_ms", + default_value = "1000", + help = "How frequently to check for failures." + )] + pub failure_tick_ms: u64, + + #[structopt( + long = "lighthouse_config", + help = "Path to configuration file (JSON format)" + )] + pub lighthouse_config: Option, } fn quorum_changed(a: &Vec, b: &Vec) -> bool { @@ -260,19 +304,89 @@ fn quorum_compute( ) } +fn load_config(config_path: &Option) -> serde_json::Map { + match config_path { + Some(path) => { + match fs::read_to_string(path) { + Ok(content) => { + // Parse JSON into Map + match serde_json::from_str::>( + &content, + ) { + Ok(json_map) => { + info!("Successfully loaded config from {}", path); + json_map + } + Err(e) => { + warn!( + "Invalid JSON in config file {}: {}. Using empty config.", + path, e + ); + serde_json::Map::new() + } + } + } + Err(e) => { + warn!( + "Failed to read config file {}: {}. 
Using empty config.", + path, e + ); + serde_json::Map::new() + } + } + } + None => serde_json::Map::new(), + } +} + impl Lighthouse { pub async fn new(opt: LighthouseOpt) -> Result> { let listener = tokio::net::TcpListener::bind(&opt.bind).await?; + // Load configuration data + let config_data = load_config(&opt.lighthouse_config); + let (tx, _) = broadcast::channel(16); + let (failure_tx, failure_rx) = broadcast::channel::(16); + + // Create a task to monitor the failure channel + let mut failure_rx_cloned: broadcast::Receiver = + failure_rx.resubscribe(); + tokio::spawn(async move { + use tokio::time::{sleep, Duration}; + info!("Starting permanent failure channel subscriber"); + loop { + match failure_rx_cloned.recv().await { + Ok(note) => { + info!( + "Healthy replicas received failure notification for {} with error message: {}", + note.replica_id, + note.error_message + ); + } + Err(e) => { + error!("Healthy replicas error: {}", e); + // If the channel is closed, break the loop + if matches!(e, tokio::sync::broadcast::error::RecvError::Closed) { + break; + } + } + } + sleep(Duration::from_millis(100)).await; // Prevent thrashing if there are continuous errors + } + info!("Permanent failure channel subscriber exiting"); + }); Ok(Arc::new(Self { state: Mutex::new(State { participants: HashMap::new(), - channel: tx, + quorum_channel: tx, prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: failure_tx, + config_data: config_data, }), opt: opt, local_addr: listener.local_addr()?, @@ -326,7 +440,7 @@ impl Lighthouse { state.prev_quorum = Some(quorum.clone()); state.participants.clear(); - match state.channel.send(quorum) { + match state.quorum_channel.send(quorum) { Ok(_) => (), Err(e) => error!("failed to send quorum {}", e), } @@ -371,6 +485,19 @@ impl Lighthouse { move || async { self_clone.get_status().await } }), ) + .route( + "/config", + get({ + let self_clone = self.clone(); + move || async { self_clone.get_config_page().await } + }) + .put({ + let self_clone = self.clone(); + move |form_data: axum::extract::Form| async move { + self_clone.update_config(form_data).await + } + }), + ) .route( "/replica/:replica_id/kill", post({ @@ -391,6 +518,76 @@ impl Lighthouse { .map_err(|e| e.into()) } + async fn _run_failure_tick(self: Arc) -> Result<()> { + let mut interval = interval(Duration::from_millis(self.opt.failure_tick_ms)); + loop { + interval.tick().await; // Wait for the next tick + let mut state = self.state.lock().await; + self.clone()._failure_tick(&mut state)?; + } + } + + fn _failure_tick(self: Arc, state: &mut State) -> Result<()> { + let now = Instant::now(); + let timeout = Duration::from_millis(self.opt.heartbeat_timeout_ms); + + // Use a temporary list to collect replica IDs to remove from heartbeats + // to avoid modifying the map while iterating over it. 
+ let mut failed_replica_ids_to_remove_from_heartbeats = Vec::new(); + let mut failure_detected = false; + + for (replica_id, last_heartbeat) in state.heartbeats.iter() { + if now.duration_since(*last_heartbeat) > timeout { + if !state.failures.contains_key(replica_id) { + info!( + "Replica {} timed out (last heartbeat: {:?}), sending failure notification.", + replica_id, + last_heartbeat + ); + if let Err(e) = state.failure_channel.send(FailureNotification { + replica_id: replica_id.clone(), + error_message: "heartbeat timeout".to_string(), + }) { + error!( + "Failed to send failure notification for {}: {} (receiver count: {})", + replica_id, + e, + state.failure_channel.receiver_count() + ); + } else { + failure_detected = true; // Set flag if notification sent successfully + } + // Record failure information + state.failures.insert(replica_id.clone(), now); + state.participants.remove(replica_id); + failed_replica_ids_to_remove_from_heartbeats.push(replica_id.clone()); + } + } else { + // If the participant sends heartbeat again, remove it from failures. + if state.failures.remove(replica_id).is_some() { + info!("Replica {} recovered from failure.", replica_id); + } + } + } + + // Remove failed replicas from heartbeats + for replica_id in failed_replica_ids_to_remove_from_heartbeats { + state.heartbeats.remove(&replica_id); + info!( + "Removed replica {} from heartbeats and participants due to timeout.", + replica_id + ); + } + + // If a new failure was detected and broadcasted, reset participants to restart quorum formation + if failure_detected { + info!("New failure detected, resetting all participants for quorum formation."); + state.participants.clear(); + } + + Ok(()) + } + pub async fn run(self: Arc) -> Result<()> { let mut set = JoinSet::new(); @@ -398,6 +595,8 @@ impl Lighthouse { set.spawn(self.clone()._run_grpc()); + set.spawn(self.clone()._run_failure_tick()); + while let Some(res) = set.join_next().await { res??; } @@ -469,6 +668,98 @@ impl Lighthouse { Ok(()) } + + async fn get_config_page(self: Arc) -> Html { + self.get_config_page_with_message("".to_string()).await + } + + async fn get_config_page_with_message( + self: Arc, + success_message: String, + ) -> Html { + let config_data = { + let state = self.state.lock().await; + // Serialize Map to JSON string for the web interface + match serde_json::to_string_pretty(&state.config_data) { + Ok(json_str) => json_str, + Err(_) => "{}".to_string(), + } + }; + + let timestamp = chrono::Utc::now() + .format("%Y-%m-%d %H:%M:%S UTC") + .to_string(); + + let template = ConfigTemplate { + config_data: config_data, + timestamp: timestamp, + success_message: if success_message.is_empty() { + None + } else { + Some(success_message) + }, + }; + Html(template.render().unwrap()) + } + + async fn update_config( + self: Arc, + axum::extract::Form(form_data): axum::extract::Form, + ) -> Result, AppError> { + let new_config_json = form_data.config; + + info!("Update config called with: {}", new_config_json); + + // Parse and validate the JSON into a Map + let new_config_map = match serde_json::from_str::>( + &new_config_json, + ) { + Ok(json_map) => json_map, + Err(e) => { + warn!("Invalid JSON provided via web interface: {}", e); + return Err(AppError(anyhow!("Invalid JSON: {}", e))); + } + }; + + // Update the config in the lighthouse state + { + let mut state = self.state.lock().await; + state.config_data = new_config_map.clone(); + } + + // Log the updated configuration content + match serde_json::to_string_pretty(&new_config_map) { + 
Ok(pretty_json) => { + info!( + "Config updated successfully via web interface. New configuration:\n{}", + pretty_json + ); + } + Err(_) => { + info!( + "Config updated successfully via web interface. New configuration: {:?}", + new_config_map + ); + } + } + + // Return the updated config page with success message + Ok(self + .get_config_page_with_message("Configuration updated successfully!".to_string()) + .await) + } + + pub async fn inject_failure(self: Arc, replica_id: String) -> Result<()> { + let state = self.state.lock().await; + state + .failure_channel + .send(FailureNotification { + replica_id, + error_message: "injected failure".to_string(), + }) + .map_err(|e| anyhow!("Failed to send failure notification: {}", e))?; + Ok(()) + } } #[tonic::async_trait] @@ -502,7 +793,7 @@ impl LighthouseService for Arc { member: requester.clone(), }, ); - let rx = state.channel.subscribe(); + let rx = state.quorum_channel.subscribe(); // proactively run quorum tick self.clone() @@ -556,6 +847,59 @@ impl LighthouseService for Arc { let reply = LighthouseHeartbeatResponse {}; Ok(Response::new(reply)) } + + type SubscribeFailuresStream = + Pin> + Send + 'static>>; + + async fn subscribe_failures( + &self, + _req: Request, + ) -> Result, Status> { + // clone a receiver + let rx = { + let state = self.state.lock().await; + let receiver_count = state.failure_channel.receiver_count(); + info!( + "subscribe_failures: Creating new subscriber (current count: {})", + receiver_count + ); + state.failure_channel.subscribe() + }; + + // Wrap the receiver; map its *internal* error into `tonic::Status` + let stream = BroadcastStream::new(rx).filter_map(|res| match res { + Ok(note) => Some(Ok(note)), + Err(TokioStreamBroadcastStreamRecvError::Lagged(n)) => Some(Err( + Status::resource_exhausted(format!("client lagged {n} messages")), + )), + }); + + Ok(Response::new(Box::pin(stream))) + } + + async fn get_config( + &self, + _request: Request, + ) -> Result, Status> { + let config_data = { + let state = self.state.lock().await; + // Convert serde_json::Map to HashMap for protobuf + state + .config_data + .iter() + .map(|(k, v)| { + let value_str = match v { + serde_json::Value::String(s) => s.clone(), + _ => v.to_string(), + }; + (k.clone(), value_str) + }) + .collect() + }; + + let reply = LighthouseConfigResponse { config_data }; + Ok(Response::new(reply)) + } } #[derive(Template)] @@ -576,6 +920,14 @@ struct StatusTemplate { old_age_threshold: Instant, } +#[derive(Template)] +#[template(path = "config.html")] +struct ConfigTemplate { + config_data: String, + timestamp: String, + success_message: Option, +} + // Make our own error that wraps `anyhow::Error`. 
struct AppError(anyhow::Error); @@ -601,10 +953,17 @@ where } } +#[derive(Deserialize)] +struct ConfigUpdateForm { + config: String, +} + #[cfg(test)] mod tests { use super::*; use std::ops::Sub; + use tokio::sync::broadcast::error::RecvError as TokioBroadcastRecvError; + use tokio::time::timeout as tokio_timeout; use tonic::transport::Channel; @@ -624,14 +983,19 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, + config_data: serde_json::Map::new(), }; let now = Instant::now(); @@ -703,14 +1067,19 @@ mod tests { join_timeout_ms: 0, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, + config_data: serde_json::Map::new(), }; let now = Instant::now(); @@ -789,14 +1158,19 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, + config_data: serde_json::Map::new(), }; let now = Instant::now(); @@ -879,14 +1253,19 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, + config_data: serde_json::Map::new(), }; let now = Instant::now(); @@ -974,6 +1353,8 @@ mod tests { join_timeout_ms: 1, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; let lighthouse = Lighthouse::new(opt).await?; @@ -1020,14 +1401,19 @@ mod tests { join_timeout_ms: 60 * 60 * 1000, // 1hr quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; let mut state = State { - channel: broadcast::channel(16).0, + quorum_channel: broadcast::channel(16).0, participants: HashMap::new(), prev_quorum: None, quorum_id: 0, heartbeats: HashMap::new(), + failures: HashMap::new(), + failure_channel: broadcast::channel(16).0, + config_data: serde_json::Map::new(), }; let now = Instant::now(); @@ -1103,6 +1489,186 @@ mod tests { assert!(quorum_changed(&a, &c)); } + // Helper to create a default QuorumMember for tests + fn test_quorum_member(replica_id: &str) -> QuorumMember { + QuorumMember { + replica_id: replica_id.to_string(), + address: format!("addr_{}", replica_id), + store_address: format!("store_{}", replica_id), + step: 1, + world_size: 2, // Assuming 2 for this test context + shrink_only: false, + data: String::new(), + 
commit_failures: 0, + } + } + + /// Test that `_failure_tick` correctly identifies timed-out replicas, + /// broadcasts a failure notification exactly once per failure, and + /// cleans up the replica from `heartbeats` and `participants` while + /// adding it to `failures`. Subsequent ticks should not re-notify + /// or change the state for an already failed replica. + #[tokio::test] + async fn test_failure_tick_single_notification_and_cleanup() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 0, // Not relevant for this test + quorum_tick_ms: 10, // Not directly relevant but keep it small + heartbeat_timeout_ms: 100, // Reasonably short for testing + failure_tick_ms: 50, // How often _failure_tick would be called + lighthouse_config: None, + }; + let lighthouse = Lighthouse::new(opt.clone()).await?; + + let mut failure_rx = { + let state_guard = lighthouse.state.lock().await; + state_guard.failure_channel.subscribe() + }; + + let replica_id_failing = "failing_one"; + + let now = Instant::now(); + // Ensure expired_time is definitively older than heartbeat_timeout_ms + let expired_time = now - Duration::from_millis(opt.heartbeat_timeout_ms * 2); + + // Setup initial state: one about to fail + { + let mut state_guard = lighthouse.state.lock().await; + let state = &mut *state_guard; + + // Failing replica + state.participants.insert( + replica_id_failing.to_string(), + QuorumMemberDetails { + joined: now, // Joined time doesn't prevent failure due to heartbeat + member: test_quorum_member(replica_id_failing), + }, + ); + state + .heartbeats + .insert(replica_id_failing.to_string(), expired_time); + } + + // --- First call to _failure_tick --- + // This call should detect the failure, send a notification, and update state. + { + let mut state_guard = lighthouse.state.lock().await; + lighthouse.clone()._failure_tick(&mut *state_guard)?; + } + + // Assertions after first tick + // 1. Check notification for failing_replica + match tokio_timeout( + Duration::from_millis(opt.failure_tick_ms * 2), + failure_rx.recv(), + ) + .await + { + Ok(Ok(notification)) => { + assert_eq!( + notification.replica_id, replica_id_failing, + "Notification should be for the failing replica" + ); + } + Ok(Err(TokioBroadcastRecvError::Lagged(n))) => { + panic!( + "Broadcast channel lagged by {} messages, missed the failure notification", + n + ); + } + Ok(Err(TokioBroadcastRecvError::Closed)) => { + panic!("Broadcast channel closed unexpectedly after first tick"); + } + Err(_) => panic!( + "Did not receive failure notification for {} in time", + replica_id_failing + ), + } + + // 2. Verify state changes + { + let state_guard = lighthouse.state.lock().await; + let state = &*state_guard; + + // Failing replica assertions + assert!( + state.failures.contains_key(replica_id_failing), + "{} should be in failures map", + replica_id_failing + ); + assert!( + !state.heartbeats.contains_key(replica_id_failing), + "{} should be removed from heartbeats", + replica_id_failing + ); + assert!( + !state.participants.contains_key(replica_id_failing), + "{} should be removed from participants", + replica_id_failing + ); + } + + // --- Second call to _failure_tick --- + // This call should *not* detect a *new* failure for the same replica + // and should not send another notification. + { + let mut state_guard = lighthouse.state.lock().await; + lighthouse.clone()._failure_tick(&mut *state_guard)?; + } + + // Assertions after second tick + // 1. 
No new notification for failing_replica + match tokio_timeout( + Duration::from_millis(opt.failure_tick_ms * 2), + failure_rx.recv(), + ) + .await + { + Ok(Ok(notification)) => { + panic!( + "Received unexpected second failure notification for {}", + notification.replica_id + ); + } + Ok(Err(TokioBroadcastRecvError::Lagged(n))) => { + // This might happen if the test environment is slow and ticks are processed faster than receives. + // For this specific assertion (no *new* message), lagging is an acceptable outcome. + info!("Broadcast channel lagged by {} messages on second check, implies no new distinct message.", n); + } + Ok(Err(TokioBroadcastRecvError::Closed)) => { + // Channel might close if sender is dropped, implies no new message. + info!("Broadcast channel closed on second check, implies no new distinct message."); + } + Err(_) => { + // Expected: Timeout, meaning no new message was received for failing_replica. + } + } + + // 2. Verify state remains consistent for failing_replica + { + let state_guard = lighthouse.state.lock().await; + let state = &*state_guard; + + assert!( + state.failures.contains_key(replica_id_failing), + "{} should remain in failures map", + replica_id_failing + ); + assert!( + !state.heartbeats.contains_key(replica_id_failing), + "{} should remain removed from heartbeats", + replica_id_failing + ); + assert!( + !state.participants.contains_key(replica_id_failing), + "{} should remain removed from participants", + replica_id_failing + ); + } + Ok(()) + } + #[tokio::test] async fn test_lighthouse_join_during_shrink() -> Result<()> { fn create_member(id: &str, addr_num: &str, step: i64, shrink_only: bool) -> QuorumMember { @@ -1130,6 +1696,8 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; // Start the lighthouse service @@ -1237,6 +1805,8 @@ mod tests { join_timeout_ms: 1000, quorum_tick_ms: 10, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }; // Start the lighthouse service @@ -1281,4 +1851,180 @@ mod tests { lighthouse_task.abort(); Ok(()) } + + #[tokio::test] + async fn test_lighthouse_subscribe_failures_basic() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, // 1hr + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, + }; + + let lighthouse = Lighthouse::new(opt).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + let mut client = lighthouse_client_new(lighthouse.address()).await?; + let request = tonic::Request::new(SubscribeFailuresRequest {}); + client.subscribe_failures(request).await?; + + lighthouse_task.abort(); + Ok(()) + } + + #[tokio::test] + async fn test_subscribe_failures_delivers_notifications() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60 * 60 * 1000, + quorum_tick_ms: 10, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, + }; + let lighthouse = Lighthouse::new(opt).await?; + let mut client = lighthouse_client_new(lighthouse.address()).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + // 1. Subscribe with a deadline + let mut req = tonic::Request::new(SubscribeFailuresRequest {}); + req.set_timeout(Duration::from_secs(5)); + let mut stream = client.subscribe_failures(req).await?.into_inner(); + + // 2. 
Trigger a failure notification + { + let state = lighthouse.state.lock().await; + state + .failure_channel + .send(FailureNotification { + replica_id: "replica_id_X".into(), + error_message: "injected failure".to_string(), + }) + .unwrap(); + } + + // 3. Ensure we receive it + match stream.next().await { + Some(Ok(note)) => { + assert_eq!(note.replica_id, "replica_id_X"); + assert_eq!(note.error_message, "injected failure"); + } + other => panic!("Expected notification, got {:?}", other), + } + + lighthouse_task.abort(); + Ok(()) + } + + #[tokio::test] + async fn test_config_broadcasting() -> Result<()> { + // Create a test config file + let test_config = + r#"{"learning_rate": "0.001", "batch_size": "32", "model_type": "transformer"}"#; + std::fs::write("test_config_temp.json", test_config).unwrap(); + + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60000, + quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: Some("test_config_temp.json".to_string()), + }; + + let lighthouse = Lighthouse::new(opt).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + let mut client = lighthouse_client_new(lighthouse.address()).await?; + + let request = tonic::Request::new(LighthouseConfigRequest {}); + let response = client.get_config(request).await?; + let config_data = response.into_inner().config_data; + + // Verify the config was loaded and parsed correctly + assert_eq!(config_data.get("learning_rate"), Some(&"0.001".to_string())); + assert_eq!(config_data.get("batch_size"), Some(&"32".to_string())); + assert_eq!( + config_data.get("model_type"), + Some(&"transformer".to_string()) + ); + assert_eq!(config_data.len(), 3); + + lighthouse_task.abort(); + + // Clean up test file + std::fs::remove_file("test_config_temp.json").unwrap(); + + Ok(()) + } + + #[tokio::test] + async fn test_config_broadcasting_no_config() -> Result<()> { + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60000, + quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, + }; + + let lighthouse = Lighthouse::new(opt).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + let mut client = lighthouse_client_new(lighthouse.address()).await?; + + let request = tonic::Request::new(LighthouseConfigRequest {}); + let response = client.get_config(request).await?; + let config_data = response.into_inner().config_data; + + // When no config is provided, should return empty map + assert_eq!(config_data.len(), 0); + + lighthouse_task.abort(); + + Ok(()) + } + + #[tokio::test] + async fn test_config_broadcasting_invalid_json() -> Result<()> { + // Create an invalid JSON file + let invalid_json = r#"{"learning_rate": "0.001", "batch_size": 32 // invalid comment"#; + std::fs::write("test_invalid_config.json", invalid_json).unwrap(); + + let opt = LighthouseOpt { + min_replicas: 1, + bind: "[::]:0".to_string(), + join_timeout_ms: 60000, + quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: Some("test_invalid_config.json".to_string()), + }; + + let lighthouse = Lighthouse::new(opt).await?; + let lighthouse_task = tokio::spawn(lighthouse.clone().run()); + + let mut client = lighthouse_client_new(lighthouse.address()).await?; + + let request = tonic::Request::new(LighthouseConfigRequest {}); + let response = client.get_config(request).await?; + let config_data = 
response.into_inner().config_data; + + // When invalid JSON is provided, should return empty map + assert_eq!(config_data.len(), 0); + + lighthouse_task.abort(); + + // Clean up test file + std::fs::remove_file("test_invalid_config.json").unwrap(); + + Ok(()) + } } diff --git a/src/manager.rs b/src/manager.rs index e28cbeb..3463554 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -15,6 +15,7 @@ use tokio::sync::broadcast; use tokio::sync::Mutex; use tokio::task::JoinSet; use tokio::time::sleep; +use tokio::time::timeout as tokio_timeout; use tonic::transport::server::TcpIncoming; use tonic::transport::Channel; use tonic::transport::Server; @@ -54,7 +55,7 @@ macro_rules! info_with_replica { struct ManagerState { checkpoint_metadata: HashMap, - channel: broadcast::Sender, + quorum_channel: broadcast::Sender, participants: HashMap, should_commit_channel: broadcast::Sender, @@ -126,7 +127,7 @@ impl Manager { heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { checkpoint_metadata: HashMap::new(), - channel: tx, + quorum_channel: tx, participants: HashMap::new(), should_commit_channel: should_commit_tx, @@ -204,7 +205,7 @@ impl Manager { }); lighthouse_request.set_timeout(timeout); - let response = tokio::time::timeout(timeout, client.quorum(lighthouse_request)) + let response = tokio_timeout(timeout, client.quorum(lighthouse_request)) .await .unwrap_or_else(|e| { Err(Status::cancelled(format!( @@ -217,7 +218,7 @@ impl Manager { info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp); state - .channel + .quorum_channel .send( resp.quorum .ok_or_else(|| Status::internal("missing quorum"))?, @@ -273,7 +274,7 @@ impl ManagerService for Arc { }; // TODO check step state.participants.insert(group_rank, member.clone()); - let rx = state.channel.subscribe(); + let rx = state.quorum_channel.subscribe(); self._run_quorum(&mut state, member, timeout).await?; @@ -550,6 +551,8 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -597,6 +600,8 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -652,6 +657,8 @@ mod tests { min_replicas: 2, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); @@ -724,6 +731,8 @@ mod tests { min_replicas: 1, quorum_tick_ms: 100, heartbeat_timeout_ms: 5000, + failure_tick_ms: 1000, + lighthouse_config: None, }) .await?; let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); diff --git a/templates/config.html b/templates/config.html new file mode 100644 index 0000000..b316427 --- /dev/null +++ b/templates/config.html @@ -0,0 +1,293 @@ + + + + Configuration - Lighthouse Dashboard + + + + + + +
[templates/config.html is a new 293-line Askama template whose HTML/CSS/JS markup was stripped during extraction and is not reconstructed here. The recoverable structure: a "Configuration - Lighthouse Dashboard" page with a "Current Configuration" panel ("Last updated: {{timestamp}}" plus the pretty-printed {{config_data}}), an optional success banner rendered via {% match success_message %} ("✓ Success: {{message}}"), and an "Update Configuration" form ("Enter your configuration as JSON ...") with a note that changes apply immediately to the lighthouse state and that JSON syntax is validated as you type.]
diff --git a/templates/index.html b/templates/index.html index 000e909..672b4a4 100644 --- a/templates/index.html +++ b/templates/index.html @@ -52,7 +52,12 @@
[templates/index.html hunk around the "Lighthouse Dashboard - torchft" heading was likewise stripped of markup; it replaces one line with several, apparently adding a link to the new Configuration page.]
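The template above and the `GetConfig` RPC expose two views of the same `config_data`. A sketch of the RPC path from Python (mirroring `test_get_config_rpc` below; per the Rust `get_config` handler, non-string JSON values are stringified on the way out — the file path and values here are examples):

```python
import json
import tempfile
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

# Write an example config file for the lighthouse to broadcast.
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"learning_rate": "0.001", "batch_size": "32"}, f)
    config_path = f.name

server = LighthouseServer(bind="[::]:0", min_replicas=1, lighthouse_config=config_path)
client = LighthouseClient(addr=server.address(), connect_timeout=timedelta(seconds=1))

config = client.get_config(timeout=timedelta(seconds=1))  # dict[str, str]
assert config == {"learning_rate": "0.001", "batch_size": "32"}

server.shutdown()
```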
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 9614d1b..4a83852 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,8 +11,8 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - commit_failures: int, init_sync: bool = True, + commit_failures: int = 0, ) -> QuorumResult: ... def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... def should_commit( @@ -60,9 +60,12 @@ class LighthouseServer: join_timeout_ms: Optional[int] = None, quorum_tick_ms: Optional[int] = None, heartbeat_timeout_ms: Optional[int] = None, + failure_tick_ms: Optional[int] = None, ) -> None: ... def address(self) -> str: ... def shutdown(self) -> None: ... + def inject_failure(self, replica_id: str) -> None: ... + def get_config(self, timeout: timedelta) -> dict[str, str]: ... @dataclass class QuorumMember: @@ -85,6 +88,14 @@ class Quorum: participants: List[QuorumMember] created: Timestamp +@dataclass +class FailureNotification: + replica_id: str + +class FailureStream: + def __iter__(self) -> "FailureStream": ... + def __next__(self) -> FailureNotification: ... + @dataclass class LighthouseClient: addr: str @@ -106,3 +117,7 @@ class LighthouseClient: replica_id: str, timeout: timedelta = timedelta(seconds=5), ) -> None: ... + def subscribe_failures( + self, + timeout: timedelta = timedelta(seconds=5), + ) -> FailureStream: ... diff --git a/torchft/data.py b/torchft/data.py index 02e5b3b..77ec1de 100644 --- a/torchft/data.py +++ b/torchft/data.py @@ -38,15 +38,15 @@ class DistributedSampler(data.distributed.DistributedSampler): This will shard the input dataset into ``num_replicas*num_replica_group`` number of shards. - Each shard rank is calculated via: ``rank + num_replicas*replica_rank`` + Each shard rank is calculated via: ``rank + num_replicas*replica_group_id`` - num_replicas and replica_rank must be the same on all workers. + num_replicas and replica_group_id must be the same on all workers. """ def __init__( self, dataset: data.Dataset, - replica_rank: int, + replica_group_id: int, num_replica_groups: int, group_rank: Optional[int] = None, num_replicas: Optional[int] = None, @@ -55,7 +55,7 @@ def __init__( """ Args: data: the dataset to use - replica_rank: the group ID (0-num_replica_groups) to use for this shard of data. + replica_group_id: the group ID (0-num_replica_groups) to use for this shard of data. 
num_replica_groups: the max number of global replica groups rank: the local group rank num_replicas: the local group world size @@ -65,7 +65,7 @@ def __init__( if num_replicas is None: num_replicas = dist.get_world_size() - self.global_rank: int = group_rank + num_replicas * replica_rank + self.global_rank: int = group_rank + num_replicas * replica_group_id self.global_world_size: int = num_replicas * num_replica_groups super().__init__( diff --git a/torchft/data_test.py b/torchft/data_test.py index 8dae190..5b7c6b6 100644 --- a/torchft/data_test.py +++ b/torchft/data_test.py @@ -27,7 +27,7 @@ def test_distributed_sampler(self) -> None: dataset = DummyDataset(1000) sampler = DistributedSampler( dataset, - replica_rank=1, + replica_group_id=1, num_replica_groups=2, group_rank=3, num_replicas=4, diff --git a/torchft/lighthouse_test.py b/torchft/lighthouse_test.py index 067a622..94f8df9 100644 --- a/torchft/lighthouse_test.py +++ b/torchft/lighthouse_test.py @@ -155,3 +155,147 @@ def test_heartbeat_round_trip(self) -> None: finally: lighthouse.shutdown() + + def test_subscribe_failures(self) -> None: + """Test that subscribe_failures can be called without raising an exception.""" + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + try: + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + stream = client.subscribe_failures(timeout=timedelta(milliseconds=100)) + finally: + lighthouse.shutdown() + + def test_subscribe_failures_notification(self) -> None: + """Test that failure notifications are delivered to subscribers.""" + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + try: + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + stream = client.subscribe_failures(timeout=timedelta(seconds=1)) + lighthouse.inject_failure("nodeX") + note = next(stream) + assert note.replica_id == "nodeX" + finally: + lighthouse.shutdown() + + def test_inject_failure(self) -> None: + """Test that inject failure delivers a failure notification to subscribers""" + # Start a lighthouse server + server = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + print(f"Server address: {server.address()}") + + # Create a client to subscribe to failures + client = LighthouseClient(server.address(), timedelta(seconds=5)) + failure_stream = client.subscribe_failures(timedelta(seconds=5)) + + # Inject a failure + replica_id = "test_replica" + print(f"Injecting failure for replica: {replica_id}") + server.inject_failure(replica_id) + + # Wait a bit for the notification to be processed + time.sleep(1) + + # Try to get the failure notification + try: + notification = next(failure_stream) + print( + f"Received failure notification for replica: {notification.replica_id}" + ) + assert notification.replica_id == replica_id, "Received wrong replica_id" + print("Test passed!") + except Exception as e: + print(f"Error: {e}") + + # Clean up + server.shutdown() + + def test_get_config_rpc(self) -> None: + """Test that get_config RPC returns configuration data.""" + import json + import os + import tempfile + + # Create a temporary config file with test data + test_config = { + "learning_rate": "0.001", + "batch_size": "64", + "model_type": "transformer", + "epochs": "100", + } + + with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as f: + json.dump(test_config, f) + config_file_path = f.name + + try: + # Test without config file first + 
lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + ) + + client = LighthouseClient( + addr=lighthouse.address(), + connect_timeout=timedelta(seconds=1), + ) + + # Test that get_config method exists and can be called + config_dict = client.get_config(timeout=timedelta(seconds=1)) + + # Verify the returned type is a dictionary + assert isinstance( + config_dict, dict + ), f"Expected dict, got {type(config_dict)}" + + # With no config file provided to LighthouseServer, should return empty dict + assert len(config_dict) == 0, f"Expected empty config, got {config_dict}" + + lighthouse.shutdown() + + # Test with config file using the now exposed lighthouse_config parameter + lighthouse_with_config = LighthouseServer( + bind="[::]:0", min_replicas=1, lighthouse_config=config_file_path + ) + + client_with_config = LighthouseClient( + addr=lighthouse_with_config.address(), + connect_timeout=timedelta(seconds=1), + ) + + # Get config from lighthouse that has a config file + config_dict_with_file = client_with_config.get_config( + timeout=timedelta(seconds=1) + ) + + # Verify the config was loaded correctly + assert isinstance( + config_dict_with_file, dict + ), f"Expected dict, got {type(config_dict_with_file)}" + assert config_dict_with_file["learning_rate"] == "0.001" + assert config_dict_with_file["batch_size"] == "64" + assert config_dict_with_file["model_type"] == "transformer" + assert config_dict_with_file["epochs"] == "100" + assert len(config_dict_with_file) == 4 + + lighthouse_with_config.shutdown() + + finally: + # Clean up the temporary config file + if os.path.exists(config_file_path): + os.unlink(config_file_path) diff --git a/torchft/manager.py b/torchft/manager.py index 2c1c640..ae48cfd 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -24,25 +24,29 @@ and Hybrid FSDP. 
""" - import concurrent.futures import logging +import multiprocessing import os import socket +import threading +import time import traceback import uuid from concurrent.futures import ThreadPoolExecutor from contextlib import nullcontext from datetime import timedelta from enum import Enum +from multiprocessing.connection import Connection from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast import torch from torch.distributed import ReduceOp, TCPStore -from torchft._torchft import ManagerClient, ManagerServer +from torchft._torchft import LighthouseClient, ManagerClient, ManagerServer from torchft.checkpointing import CheckpointTransport, HTTPTransport from torchft.futures import future_timeout +from torchft.multiprocessing import _MonitoredPipe if TYPE_CHECKING: from torchft.process_group import ProcessGroup @@ -103,6 +107,7 @@ def __init__( timeout: timedelta = timedelta(seconds=60), quorum_timeout: timedelta = timedelta(seconds=60), connect_timeout: timedelta = timedelta(seconds=60), + proactive_recovery_subscribe_timeout: timedelta = timedelta(milliseconds=100), rank: Optional[int] = None, world_size: Optional[int] = None, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, @@ -116,6 +121,7 @@ def __init__( checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, init_sync: bool = True, max_retries: Optional[int] = None, + proactive_recovery: bool = False, ) -> None: """ Args: @@ -166,6 +172,9 @@ def __init__( self._timeout = timeout self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout + self._proactive_recovery_subscribe_timeout = ( + proactive_recovery_subscribe_timeout + ) self._replica_world_size_mode = world_size_mode self._init_sync = init_sync self._max_retries = max_retries @@ -187,9 +196,7 @@ def __init__( self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = ( checkpoint_transport ) - self._executor = ThreadPoolExecutor( - max_workers=1, thread_name_prefix="async_quorum" - ) + self._executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="") self._quorum_future: Optional[concurrent.futures.Future] = None self._store = TCPStore( @@ -205,12 +212,57 @@ def __init__( torch.cuda.Stream() if torch.cuda.is_available() else None ) + lighthouse_addr: Optional[str] = lighthouse_addr + if os.environ.get("TORCHFT_LIGHTHOUSE") is not None: + lighthouse_addr = ( + lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"] + ) # Else error in tests, since TORCHFT_LIGHTHOUSE may not be set + + self._proactive_recovery = proactive_recovery or int( + os.environ.get("TORCHFT_PROACTIVE_RECOVERY", 0) + ) + + if lighthouse_addr is not None and self._proactive_recovery: + ctx = multiprocessing.get_context("spawn") + error_local, error_remote = ctx.Pipe() + self._error_pipe = _MonitoredPipe(error_local) + self._error_remote = _MonitoredPipe(error_remote) + self._failure_listener_stop_event = ctx.Event() + + self._failure_listener_process = ctx.Process( + target=_failure_listener_process_main, + args=( + lighthouse_addr, + self._connect_timeout, + self._failure_listener_stop_event, + error_remote, + self._proactive_recovery_subscribe_timeout, + ), + daemon=True, + ) + self._failure_listener_process.start() + else: + self._failure_listener_process = None + self._error_pipe = None + self._failure_listener_stop_event = None + + # Initialize and start the error processing thread if the listener process is active + self._error_processor_thread: Optional[threading.Thread] = None + self._error_processor_stop_event: 
Optional[threading.Event] = None + if self._failure_listener_process is not None: + self._error_processor_stop_event = threading.Event() + self._error_processor_thread = threading.Thread( + target=self._error_processor_loop, + name="TorchFTErrorProcessor", + daemon=True, + ) + self._error_processor_thread.start() + if self._group_rank == 0: if port is None: port = int(os.environ.get(MANAGER_PORT_ENV, 0)) bind = f"[::]:{port}" - lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"] # We need a unique identifier in the case that a worker restarts quickly and # replaces the previous worker with the same ID. @@ -219,6 +271,7 @@ def __init__( replica_id = new_uuid else: replica_id = f"{replica_id}:{new_uuid}" + self._manager = ManagerServer( replica_id=replica_id, lighthouse_addr=lighthouse_addr, @@ -229,13 +282,11 @@ def __init__( heartbeat_interval=heartbeat_interval, connect_timeout=connect_timeout, ) - self._store.set(MANAGER_ADDR_KEY, self._manager.address()) self._store.set(REPLICA_ID_KEY, replica_id) addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8") self._client = ManagerClient(addr, connect_timeout=connect_timeout) - replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8") self._logger = _ManagerLogger( manager=self, replica_id=replica_id or "", group_rank=group_rank @@ -258,13 +309,96 @@ def set_state_dict_fns( self._load_state_dict = load_state_dict self._user_state_dict = state_dict + def _error_handler(self, err): + self._logger.info(f"Received error: {err}") + self.report_error(err) + self._pg.abort() + + def _error_processor_loop(self) -> None: + """Continuously checks the error pipe from the listener process and reports errors.""" + assert ( + self._error_pipe is not None + ), "Error pipe must be initialized for error processor loop." + assert ( + self._error_processor_stop_event is not None + ), "Stop event must be initialized for error processor loop." + + try: + while not self._error_processor_stop_event.is_set(): + try: + item = self._error_pipe.recv(0.1) + except TimeoutError: + continue + except OSError: + break + except Exception as e: + self._error_handler(e) + finally: + pass + def shutdown(self, wait: bool = True) -> None: """ Shutdown the manager and checkpoint server. """ - self._checkpoint_transport.shutdown(wait=wait) if self._manager is not None: self._manager.shutdown() + + # Stop the error processor thread first + if ( + self._error_processor_thread is not None + and self._error_processor_stop_event is not None + ): + self._logger.info("Setting error processor thread stop event") + self._error_processor_stop_event.set() + if wait: + self._logger.info("Waiting for error processor thread to complete") + try: + self._error_processor_thread.join(timeout=5) # Short timeout + if self._error_processor_thread.is_alive(): + self._logger.warn( + "Error processor thread did not terminate in time." 
+ ) + else: + self._logger.info("Error processor thread shutdown completed.") + except Exception as e: + self._logger.warn(f"Error waiting for error processor thread: {e}") + + # Stop the failure listener process if it exists + if ( + hasattr(self, "_failure_listener_process") + and self._failure_listener_process is not None + ): + self._logger.info("Setting failure listener stop event for process") + if ( + hasattr(self, "_failure_listener_stop_event") + and self._failure_listener_stop_event is not None + ): + self._failure_listener_stop_event.set() + + if wait: + self._logger.info("Waiting for failure listener process to complete") + try: + self._failure_listener_process.join(timeout=10) # Process join + if self._failure_listener_process.is_alive(): + self._logger.warn( + "Failure listener process did not terminate, attempting to terminate." + ) + self._failure_listener_process.terminate() # Force terminate if join times out + self._failure_listener_process.join( + timeout=1 + ) # Wait for terminate + else: + self._logger.info("Failure listener process shutdown completed") + except Exception as e: + self._logger.warn( + f"Error waiting for/terminating failure listener process: {e}" + ) + + # Clean up pipe + if hasattr(self, "_error_pipe") and self._error_pipe is not None: + self._error_pipe.close() + + self._checkpoint_transport.shutdown(wait=wait) self._executor.shutdown(wait=wait) def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]: @@ -824,3 +958,60 @@ def warn(self, msg: str) -> None: def exception(self, msg: str) -> None: self._logger.exception(f"{self.prefix()} {msg}") + + +def _failure_listener_process_main( + lighthouse_addr_str: Optional[str], + connect_timeout: timedelta, + stop_event: multiprocessing.Event, + error_pipe: Connection, + subscribe_timeout: timedelta = timedelta(milliseconds=100), +): + """ + Background process that monitors lighthouse for failures through gRPC stream (with an iterator interface) and reports them via error_pipe. + """ + if not lighthouse_addr_str: + return + + while not stop_event.is_set(): + try: + lighthouse_client = LighthouseClient( + lighthouse_addr_str, connect_timeout=connect_timeout + ) + stream = lighthouse_client.subscribe_failures(timeout=subscribe_timeout) + while not stop_event.is_set(): + try: + note = next( + stream + ) # This will block until a new item or timeout if stream supports it + if note: + if stop_event.is_set(): + break + error = Exception( + f"Peer failure detected in listener process: replica {note.replica_id} has failed" + ) + error_pipe.send(ExceptionWithTraceback(error)) + except StopIteration: + # Stream has ended, break out to outer loop to reconnect + if not stop_event.is_set(): + logging.warning( + "Failure Listener: Stream ended unexpectedly, attempting to reconnect..." + ) + break # Break the inner loop to reconnect + else: + break + except Exception as e_stream: + if not stop_event.is_set(): + continue # Break due to subscribe_timeout. Allows the process to check stop_event again. + else: + break + if stop_event.is_set(): + break + time.sleep(0.01) # Prevent CPU thrashing + except Exception as e_outer: + if not stop_event.is_set(): + logging.warning( + f"Failure Listener: Connection error: {e_outer}, retrying in 1 second..." + ) + time.sleep(1) + pass diff --git a/torchft/manager_test.py b/torchft/manager_test.py index bb058e4..2fb0373 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -5,18 +5,34 @@ # LICENSE file in the root directory of this source tree. 
import concurrent +import multiprocessing +import time +from dataclasses import dataclass from datetime import timedelta from typing import Optional from unittest import TestCase from unittest.mock import MagicMock, create_autospec, patch import torch +import torch.distributed as dist from torch.distributed import TCPStore -from torchft._torchft import QuorumResult +from torchft._torchft import ( + FailureStream, + LighthouseClient, + LighthouseServer, + QuorumResult, +) from torchft.checkpointing.transport import CheckpointTransport -from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode -from torchft.process_group import ProcessGroup, _DummyWork +from torchft.manager import ( + MANAGER_ADDR_KEY, + REPLICA_ID_KEY, + ExceptionWithTraceback, + Manager, + WorldSizeMode, + _failure_listener_process_main, +) +from torchft.process_group import ProcessGroup, ProcessGroupGloo, _DummyWork def mock_should_commit( @@ -43,6 +59,7 @@ def _create_manager( timeout: timedelta = timedelta(seconds=10), init_sync: bool = True, max_retries: Optional[int] = None, + proactive_recovery: bool = False, ) -> Manager: pg = create_autospec(ProcessGroup) pg.errored.return_value = None @@ -72,6 +89,7 @@ def _create_manager( timeout=timeout, init_sync=init_sync, max_retries=max_retries, + proactive_recovery=proactive_recovery, ) self.manager = manager return manager @@ -773,3 +791,127 @@ def test_max_retries(self, client_mock: MagicMock) -> None: # This should succeed and reset the counter self.assertTrue(manager.should_commit()) self.assertEqual(manager._commit_failures, 0) + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_manager_error_handler(self, client_mock: MagicMock) -> None: + """Test that the Manager correctly processes exceptions sent from the failure_listener_process.""" + # Create a manager + manager = self._create_manager() + + # Create an exception simulating what would be sent from _failure_listener_process_main + error = Exception("Peer failure detected: replica failed_replica has failed") + exception = ExceptionWithTraceback(error) + + # Directly test the error handling mechanism + manager._error_handler(error) + + # Verify the error was properly processed + captured_error = manager.errored() + self.assertIsNotNone(captured_error) + self.assertEqual(str(captured_error.original_exception), str(error)) + + def test_direct_error_pipe(self) -> None: + """Test sending an exception to the Manager's _error_pipe.""" + # Create a manager with proactive_recovery=True to ensure it has an error pipe + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=1, + join_timeout_ms=100, + ) + + # Create a manager that tries to join + store = dist.TCPStore( + host_name="localhost", + port=0, + is_master=True, + wait_for_workers=False, + ) + pg = ProcessGroupGloo() + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=lambda x: None, + state_dict=lambda: None, + replica_id=f"lighthouse_test", + store_addr="localhost", + store_port=store.port, + rank=0, + world_size=1, + use_async_quorum=False, + lighthouse_addr=lighthouse.address(), + proactive_recovery=True, + ) + + # Make sure the error pipe is created + self.assertIsNotNone(manager._error_pipe, "Manager should have an error pipe") + time.sleep(1) + # Create a mock error message + mock_error_msg = "Test failure detected from direct pipe test" + test_exception = Exception(mock_error_msg) + + # Create an ExceptionWithTraceback and send it through the pipe + exc_with_tb = 
ExceptionWithTraceback(test_exception)
+        manager._error_remote.send(exc_with_tb)
+
+        # Wait a short time for the error processor thread to process the message
+        time.sleep(1)
+
+        # Verify that the error was properly processed by the Manager
+        error_obj = manager.errored()
+        self.assertIsNotNone(
+            error_obj, "Error should have been captured by the Manager"
+        )
+
+        # Clean up
+        manager.shutdown(wait=True)
+
+    def test_manager_failure_e2e(self) -> None:
+        """Test that the Manager correctly handles errors from the failure_listener_process."""
+        # Create a manager with proactive_recovery=True to ensure it has an error pipe
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=1,
+            join_timeout_ms=100,
+        )
+
+        # Create a manager that tries to join
+        store = dist.TCPStore(
+            host_name="localhost",
+            port=0,
+            is_master=True,
+            wait_for_workers=False,
+        )
+        pg = ProcessGroupGloo()
+        manager = Manager(
+            pg=pg,
+            min_replica_size=1,
+            load_state_dict=lambda x: None,
+            state_dict=lambda: None,
+            replica_id=f"lighthouse_test",
+            store_addr="localhost",
+            store_port=store.port,
+            rank=0,
+            world_size=1,
+            use_async_quorum=False,
+            lighthouse_addr=lighthouse.address(),
+            proactive_recovery=True,
+        )
+
+        time.sleep(1.5)
+
+        failed_replica_id = "failed_replica"
+        lighthouse.inject_failure(failed_replica_id)
+
+        time.sleep(1.5)  # Prevent flakiness
+        error_obj = manager.errored()
+
+        # Verify that the manager received the error notification
+        self.assertIsNotNone(error_obj, "Manager should have captured the failure")
+        self.assertIn(
+            failed_replica_id,
+            str(error_obj.original_exception),
+            f"Error should mention the failed replica: {error_obj.original_exception}",
+        )
+
+        # Clean up resources
+        manager.shutdown(wait=True)
diff --git a/train_ddp.py b/train_ddp.py
index fd79b8a..96c2c13 100644
--- a/train_ddp.py
+++ b/train_ddp.py
@@ -51,7 +51,7 @@ def main() -> None:
     # majority of groups will be available so few batches will be dropped.
     sampler = DistributedSampler(
         trainset,
-        replica_group=REPLICA_GROUP_ID,
+        replica_group_id=REPLICA_GROUP_ID,
         num_replica_groups=NUM_REPLICA_GROUPS,
         group_rank=0,
         # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more.
diff --git a/train_ddp_proactive.py b/train_ddp_proactive.py
new file mode 100644
index 0000000..3d0002c
--- /dev/null
+++ b/train_ddp_proactive.py
@@ -0,0 +1,218 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+ +import logging +import os +import sys +import time +from datetime import timedelta + +REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) +os.environ["CUDA_VISIBLE_DEVICES"] = str(REPLICA_GROUP_ID % 4) +os.environ["NCCL_HOSTID"] = str(REPLICA_GROUP_ID) + +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch import nn, optim +from torch.distributed.elastic.multiprocessing.errors import record +from torchdata.stateful_dataloader import StatefulDataLoader + +from torchft import ( + DistributedDataParallel, + DistributedSampler, + Manager, + Optimizer, + ProcessGroupGloo, + ProcessGroupNCCL, +) +from torchft.checkpointing.pg_transport import PGTransport + +logging.basicConfig(level=logging.INFO) + + +@record +def main() -> None: + REPLICA_GROUP_ID = int(os.environ.get("REPLICA_GROUP_ID", 0)) + NUM_REPLICA_GROUPS = int(os.environ.get("NUM_REPLICA_GROUPS", 2)) + + transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))] + ) + trainset = torchvision.datasets.CIFAR10( + root="./cifar", train=True, download=True, transform=transform + ) + + # This shards the training set across all ranks and replica groups. We manage + # the dataloaders on a per replica group basis with the assumption that the + # majority of groups will be available so few batches will be dropped. + sampler = DistributedSampler( + trainset, + replica_group_id=REPLICA_GROUP_ID, + num_replica_groups=NUM_REPLICA_GROUPS, + group_rank=0, + # for DDP we can use replica groups of size 1, FSDP/PP/CP would need more. + num_replicas=1, + shuffle=True, + ) + + # This uses the torchdata StatefulDataLoader to be able to checkpoint and + # restore the per worker dataloader position. + trainloader = StatefulDataLoader( + trainset, batch_size=64, num_workers=2, sampler=sampler + ) + + def load_state_dict(state_dict): + m.load_state_dict(state_dict["model"]) + optimizer.load_state_dict(state_dict["optim"]) + + def state_dict(): + return { + "model": m.state_dict(), + "optim": optimizer.state_dict(), + } + + device = "cuda" if torch.cuda.is_available() else "cpu" + pg = ( + ProcessGroupNCCL( + timeout=timedelta(seconds=30), + ) + if torch.cuda.is_available() + else ProcessGroupGloo(timeout=timedelta(seconds=5)) + ) + + transport = PGTransport( + pg, + timeout=timedelta(seconds=10), + device=("cuda" if torch.cuda.is_available() else "cpu"), + ) + + manager = Manager( + pg=pg, + min_replica_size=1, + load_state_dict=load_state_dict, + state_dict=state_dict, + replica_id=f"train_ddp_{REPLICA_GROUP_ID}", + timeout=timedelta(seconds=30), + checkpoint_transport=transport, + ) + + class Net(nn.Module): + def __init__(self): + super().__init__() + self.cnn = nn.Sequential( + nn.Conv2d(3, 6, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + nn.Conv2d(6, 16, 5), + nn.ReLU(), + nn.MaxPool2d(2, 2), + ) + + final_dim = 10 + # We add a useless 1GB intermediate layer so we spend more time in dist + # communication so injected failures are more likely to cause issues + # if they exist. 
+            target_size = 1_000_000_000
+            self.useless = nn.Embedding(target_size // final_dim // 4, final_dim)
+
+            self.classifier = nn.Sequential(
+                nn.Linear(16 * 5 * 5, 120),
+                nn.ReLU(),
+                nn.Linear(120, 84),
+                nn.ReLU(),
+                nn.Linear(84, final_dim),
+            )
+
+        def forward(self, x):
+            x = self.cnn(x)
+            x = torch.flatten(x, 1)  # flatten all dimensions except batch
+            x = self.classifier(x)
+            x += self.useless.weight[0]
+            return x
+
+    m = Net().to(device)
+    m = DistributedDataParallel(manager, m)
+    optimizer = Optimizer(manager, optim.AdamW(m.parameters()))
+
+    print(m)
+    num_params = sum(p.numel() for p in m.parameters())
+    print(f"Total number of parameters: {num_params}")
+
+    sort_by_keyword = "self_" + device + "_time_total"
+
+    def trace_handler(p):
+        output = p.key_averages().table(
+            sort_by=sort_by_keyword,
+            row_limit=100,
+        )
+        print(output)
+        p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")
+
+    # You can use an epoch based training but with faults it's easier to use step
+    # based training.
+    prof = torch.profiler.profile(
+        schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
+        on_trace_ready=trace_handler,
+        record_shapes=True,
+        profile_memory=True,
+    )
+
+    prof.start()
+    while True:
+        for i, (inputs, labels) in enumerate(trainloader):
+            prof.step()
+
+            time.sleep(0.5)  # Else each iteration runs too quickly
+
+            inputs = inputs.to(device)
+            labels = labels.to(device)
+
+            # must be called at the beginning of each train loop
+            # Quorum computation is triggered here but only needed in the backwards pass.
+            optimizer.zero_grad()
+
+            out = m(inputs)
+            criterion = nn.CrossEntropyLoss()
+            loss = criterion(out, labels)
+
+            # Gradient allreduce overlaps with the backwards pass.
+            loss.backward()
+            if manager.current_step() == 3:
+                if REPLICA_GROUP_ID == 1:
+                    manager.shutdown()
+                    exit(0)
+                # With proactive recovery, the surviving process reconfigures quickly.
+                # Without it, the surviving process waits until the allreduce times out.
+
+                test_tensor = torch.tensor([1.0]).to(device)
+                manager.allreduce(test_tensor)
+
+            # must be called at the end of the train loop
+            # This may not actually step the optimizer if an error occurred during grad allreduce.
+            optimizer.step()
+
+            if manager.current_step() % 100 == 0:
+                print(f"[{manager.current_step()}] loss = {loss.item()}")
+
+            # TODO (by the user): periodically checkpoint model, optim, manager and dataloader
+
+            # You typically want to checkpoint dataloader frequently (every step?) to
+            # avoid repeated batches as it's replica group specific.
+
+            # Model, optim and manager checkpoints can be done more infrequently as
+            # they're shared across all groups and will load from existing replicas as
+            # long as not every worker goes down.
+
+            if manager.current_step() >= 10000:
+                # complete training
+                prof.stop()
+                exit()
+
+
+if __name__ == "__main__":
+    main()
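
The failure listener process and the end-to-end test above drive the new failure-subscription surface through the `Manager`. The same surface can also be exercised directly; the sketch below is not part of the patch and only uses bindings that appear in this diff (`LighthouseServer`, `LighthouseClient.subscribe_failures`, `LighthouseServer.inject_failure`, and `FailureNotification` with `replica_id` / `error_message`) — the timeout values and the sleep are illustrative assumptions.

```python
# Minimal sketch (not part of this patch): consume the failure stream without a
# Manager. Assumes the bindings used above; timeouts and sleeps are illustrative.
import time
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1, join_timeout_ms=100)
client = LighthouseClient(lighthouse.address(), connect_timeout=timedelta(seconds=1))

# subscribe_failures returns a blocking iterator: each next() either yields a
# FailureNotification or raises TimeoutError once `timeout` elapses.
stream = client.subscribe_failures(timeout=timedelta(seconds=5))
time.sleep(1.0)  # give the subscription time to register, as the tests do

lighthouse.inject_failure("replica_0")

note = next(stream)
print(note.replica_id, note.error_message)
```

This is essentially what `_failure_listener_process_main` does in a loop, minus the reconnect, stop-event, and error-pipe plumbing.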