Add config sharing from Lighthouse with UI support (#130) #202

Draft: wants to merge 2 commits into main

4 changes: 4 additions & 0 deletions Cargo.toml
@@ -16,12 +16,16 @@ prost = "0.13.3"
prost-types = "0.13.3"
pyo3 = {version = "0.24", features = ["extension-module"]}
rand = "0.8.5"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
slog = "2.7.0"
slog-stdlog = "4.1.1"
stderrlog = "0.6.0"
structopt = "0.3.26"
tokio = {version = "1.40.0", features = ["full", "test-util", "tracing", "macros", "rt-multi-thread"] }
tokio-stream = {version = "0.1.14", features = ["sync"]}
tonic = "0.12.2"
futures-core = "0.3"

[build-dependencies]
tonic-build = "0.12.2"

32 changes: 32 additions & 0 deletions README.md
@@ -246,6 +246,38 @@ CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --mast

By watching the outputs from both shells, you should observe process group reconfiguration and live checkpoint recovery.

### Proactive Failure Recovery Mode (Experimental)

You can experiment with proactive failure recovery mode by:

```sh
export TORCHFT_PROACTIVE_RECOVERY=1
```

With this enabled, the manager will listen to the Lighthouse server for heartbeat failures in other replica groups and break out of a hanging allreduce.

You can test this out by running `train_ddp_proactive.py`:

On shell 1 (first replica group starts initial training):
```sh
export REPLICA_GROUP_ID=0
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1

CUDA_VISIBLE_DEVICES=0 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29600 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
```

On shell 2 (a second replica group joins):
```sh
export REPLICA_GROUP_ID=1
export NUM_REPLICA_GROUPS=2
export TORCHFT_PROACTIVE_RECOVERY=1

CUDA_VISIBLE_DEVICES=1 TORCHFT_LIGHTHOUSE=http://localhost:29510 torchrun --master_port=29601 --nnodes=1 --nproc_per_node=1 -- train_ddp_proactive.py
```

You should observe that the process with replica group id 1 exits early, and the process with replica group id 0 quickly resumes training. If the same script is re-run after setting `export TORCHFT_PROACTIVE_RECOVERY=0`, the process with replica group id 1 will instead hang for dozens of seconds before continuing.
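
Proactive recovery builds on the Lighthouse's new failure-notification stream. As a rough sketch (not one of the shipped examples), a process can also watch that stream directly from a background thread via the `LighthouseClient.subscribe_failures` binding added in this PR; the module path and the reaction to a failure below are assumptions:

```py
# Sketch only: watch the Lighthouse failure stream in a background thread.
# Assumes the bindings from this PR are importable as torchft._torchft and that
# a Lighthouse is reachable at TORCHFT_LIGHTHOUSE.
import os
import threading
from datetime import timedelta

from torchft._torchft import LighthouseClient


def watch_failures(addr: str) -> None:
    client = LighthouseClient(addr, connect_timeout=timedelta(seconds=5))
    stream = client.subscribe_failures(timeout=timedelta(seconds=30))
    while True:
        try:
            note = next(stream)  # blocks until a notification or the timeout
            print(f"replica {note.replica_id} failed: {note.error_message}")
            # react here, e.g. abort an in-flight allreduce
        except TimeoutError:
            continue  # nothing failed within the window; keep listening
        except StopIteration:
            break  # the lighthouse closed the stream


threading.Thread(
    target=watch_failures,
    args=(os.environ.get("TORCHFT_LIGHTHOUSE", "http://localhost:29510"),),
    daemon=True,
).start()
```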

### Example Parameter Server

torchft has a fault tolerant parameter server implementation built on its

21 changes: 21 additions & 0 deletions proto/torchft.proto
@@ -67,9 +67,24 @@ message LighthouseHeartbeatRequest {

message LighthouseHeartbeatResponse {}

message SubscribeFailuresRequest {}

message FailureNotification {
string replica_id = 1;
string error_message = 2;
}

message LighthouseConfigRequest {}

message LighthouseConfigResponse {
map<string, string> config_data = 1;
}

service LighthouseService {
rpc Quorum (LighthouseQuorumRequest) returns (LighthouseQuorumResponse);
rpc Heartbeat (LighthouseHeartbeatRequest) returns (LighthouseHeartbeatResponse);
rpc SubscribeFailures (SubscribeFailuresRequest) returns (stream FailureNotification);
rpc GetConfig (LighthouseConfigRequest) returns (LighthouseConfigResponse);
}
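
`SubscribeFailures` is a server-streaming RPC, while `GetConfig` is a plain unary call returning a `map<string, string>`. A minimal sketch of calling both with Python stubs generated from this proto (the generated module names are assumptions):

```py
# Sketch only: exercising the new RPCs with stubs generated from torchft.proto
# (e.g. via grpcio-tools). The generated module names below are assumptions.
import grpc
import torchft_pb2 as pb
import torchft_pb2_grpc as pb_grpc

with grpc.insecure_channel("localhost:29510") as channel:
    stub = pb_grpc.LighthouseServiceStub(channel)

    # GetConfig: unary request/response carrying a map<string, string>.
    config = stub.GetConfig(pb.LighthouseConfigRequest()).config_data

    # SubscribeFailures: the call returns an iterator of FailureNotification
    # messages that stays open until the stream is cancelled or closed.
    for note in stub.SubscribeFailures(pb.SubscribeFailuresRequest()):
        print(note.replica_id, note.error_message)
```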

message ManagerQuorumRequest {
@@ -126,3 +141,9 @@ service ManagerService {
rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse);
rpc Kill(KillRequest) returns (KillResponse);
}

message LighthouseClientRequest {
string replica_id = 1;
}


126 changes: 116 additions & 10 deletions src/lib.rs
@@ -13,14 +13,15 @@ mod timeout;
use anyhow::Result;
use atty::Stream;
use core::time::Duration;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use pyo3::exceptions::{PyRuntimeError, PyStopIteration, PyTimeoutError};
use std::cmp;
use std::env;
use std::sync::Arc;
use std::thread::available_parallelism;
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tokio_stream::StreamExt;
use tonic::transport::Channel;
use tonic::Status;

@@ -35,11 +36,13 @@ pub mod torchftpb {
use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient;
use crate::torchftpb::manager_service_client::ManagerServiceClient;
use crate::torchftpb::{
CheckpointMetadataRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
ManagerQuorumRequest, ShouldCommitRequest,
CheckpointMetadataRequest, FailureNotification as ProtoFailureNotification,
LighthouseConfigRequest, LighthouseHeartbeatRequest, LighthouseQuorumRequest,
ManagerQuorumRequest, ShouldCommitRequest, SubscribeFailuresRequest,
};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PyString};
use pyo3::{PyRef, PyRefMut};

// Get the number of threads to use for the tokio runtime
fn num_threads() -> usize {
@@ -290,6 +293,45 @@ struct QuorumResult {
heal: bool,
}

#[pyclass(unsendable)]
struct FailureStream {
runtime: Arc<Runtime>,
stream: tonic::Streaming<ProtoFailureNotification>,
timeout: Duration,
}

#[pymethods]
impl FailureStream {
fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
slf
}
fn __next__(mut slf: PyRefMut<'_, Self>) -> PyResult<FailureNotification> {
let runtime = slf.runtime.clone();
let timeout = slf.timeout;
// borrow stream mutably for the whole async block
let fut = async { tokio::time::timeout(timeout, slf.stream.next()).await };

// Map the gRPC stream onto Python's iterator protocol: a message becomes a
// FailureNotification, a transport error raises a gRPC status error,
// end-of-stream raises StopIteration, and an elapsed timeout raises TimeoutError.
match runtime.block_on(fut) {
Ok(Some(Ok(note))) => Ok(FailureNotification {
replica_id: note.replica_id,
error_message: note.error_message,
}),
Ok(Some(Err(status))) => Err(StatusError(status).into()),
Ok(None) => Err(PyStopIteration::new_err(())),
Err(_) => Err(PyTimeoutError::new_err(
"Timeout waiting for failure notification",
)),
}
}
}

#[pyclass(get_all, set_all)]
#[derive(Clone)]
struct FailureNotification {
replica_id: String,
error_message: String,
}

#[pymethods]
impl QuorumResult {
#[new]
@@ -478,7 +520,7 @@ fn convert_quorum(py: Python, q: &torchftpb::Quorum) -> PyResult<Quorum> {
#[pyclass]
struct LighthouseClient {
client: LighthouseServiceClient<Channel>,
runtime: Runtime,
runtime: Arc<Runtime>,
}

#[pymethods]
@@ -487,11 +529,13 @@ impl LighthouseClient {
#[new]
fn new(py: Python<'_>, addr: String, connect_timeout: Duration) -> PyResult<Self> {
py.allow_threads(move || {
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads())
.thread_name("torchft-lhclnt")
.enable_all()
.build()?;
let runtime = Arc::new(
tokio::runtime::Builder::new_multi_thread()
.worker_threads(num_threads())
.thread_name("torchft-lhclnt")
.enable_all()
.build()?,
);
let client = runtime
.block_on(manager::lighthouse_client_new(addr, connect_timeout))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
@@ -586,6 +630,44 @@ impl LighthouseClient {
Ok(())
})
}

/// Subscribe to failure notifications from the lighthouse server.
///
/// Returns:
/// FailureStream: An iterator that yields FailureNotification objects.
///
/// Args:
/// timeout (timedelta, optional): Per-message wait before raising TimeoutError. Default = 5 s.
#[pyo3(signature = (timeout = Duration::from_secs(5)))]
fn subscribe_failures(&self, py: Python<'_>, timeout: Duration) -> PyResult<FailureStream> {
py.allow_threads(move || {
let req = tonic::Request::new(SubscribeFailuresRequest {});
let response = self
.runtime
.block_on(self.client.clone().subscribe_failures(req))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
Ok(FailureStream {
runtime: self.runtime.clone(),
stream: response.into_inner(),
timeout: timeout,
})
})
}

/// Get configuration data from the lighthouse server.
///
/// Returns:
/// dict[str, str]: The configuration data as a dictionary.
///
/// Args:
/// timeout (timedelta, optional): Per-RPC deadline. Default = 5 s.
#[pyo3(signature = (timeout = Duration::from_secs(5)))]
fn get_config(
&self,
py: Python<'_>,
timeout: Duration,
) -> Result<std::collections::HashMap<String, String>, StatusError> {
py.allow_threads(move || {
let mut req = tonic::Request::new(LighthouseConfigRequest {});
req.set_timeout(timeout);
let response = self.runtime.block_on(self.client.clone().get_config(req))?;
let config_data = response.into_inner().config_data;
Ok(config_data)
})
}
}
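
A minimal usage sketch for `get_config`; the module path and the config keys shown are assumptions for illustration:

```py
# Sketch only: fetch shared configuration through the binding above.
from datetime import timedelta

from torchft._torchft import LighthouseClient

client = LighthouseClient("http://localhost:29510", connect_timeout=timedelta(seconds=5))
config: dict[str, str] = client.get_config(timeout=timedelta(seconds=5))

# Values arrive as strings (map<string, string> on the wire), so callers convert
# as needed; "lr" here is a hypothetical key, not something this PR defines.
lr = float(config.get("lr", "0.01"))
```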

/// LighthouseServer is a GRPC server for the lighthouse service.
@@ -601,6 +683,7 @@ impl LighthouseClient {
/// join_timeout_ms (int): The timeout for joining the quorum.
/// quorum_tick_ms (int): The interval at which the quorum is checked.
/// heartbeat_timeout_ms (int): The timeout for heartbeats.
/// failure_tick_ms (int): The interval at which failures are checked and broadcast.
/// lighthouse_config (str, optional): Path to a configuration file (JSON format).
#[pyclass]
struct LighthouseServer {
lighthouse: Arc<lighthouse::Lighthouse>,
@@ -610,7 +693,7 @@ struct LighthouseServer {

#[pymethods]
impl LighthouseServer {
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None))]
#[pyo3(signature = (bind, min_replicas, join_timeout_ms=None, quorum_tick_ms=None, heartbeat_timeout_ms=None, failure_tick_ms=None, lighthouse_config=None))]
#[new]
fn new(
py: Python<'_>,
@@ -619,10 +702,13 @@
join_timeout_ms: Option<u64>,
quorum_tick_ms: Option<u64>,
heartbeat_timeout_ms: Option<u64>,
failure_tick_ms: Option<u64>,
lighthouse_config: Option<String>,
) -> PyResult<Self> {
let join_timeout_ms = join_timeout_ms.unwrap_or(100);
let quorum_tick_ms = quorum_tick_ms.unwrap_or(100);
let heartbeat_timeout_ms = heartbeat_timeout_ms.unwrap_or(5000);
let failure_tick_ms = failure_tick_ms.unwrap_or(1000);

py.allow_threads(move || {
let rt = tokio::runtime::Builder::new_multi_thread()
@@ -638,6 +724,8 @@
join_timeout_ms: join_timeout_ms,
quorum_tick_ms: quorum_tick_ms,
heartbeat_timeout_ms: heartbeat_timeout_ms,
failure_tick_ms: failure_tick_ms,
lighthouse_config: lighthouse_config,
}))
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

@@ -663,6 +751,22 @@ impl LighthouseServer {
self.handle.abort();
})
}

/// inject_failure broadcasts a failure notification for the given replica.
///
/// This helper is intended for testing `subscribe_failures` from Python.
#[pyo3(signature = (replica_id))]
fn inject_failure(&self, py: Python<'_>, replica_id: String) {
let lighthouse = self.lighthouse.clone();
let runtime = &self._runtime;
py.allow_threads(move || {
let _ = runtime.block_on(async {
if let Err(e) = lighthouse.inject_failure(replica_id).await {
eprintln!("Failed to inject failure: {}", e);
}
});
});
}
}
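
A rough test-style sketch tying the server-side additions together: serving a JSON config via `GetConfig` and broadcasting a synthetic failure with `inject_failure`. The bind address, the `address()`/`shutdown()` helpers, and the flat string-to-string JSON layout are assumptions carried over from existing torchft usage:

```py
# Sketch only: start a lighthouse with a config file, read it back, and inject
# a failure so subscribe_failures subscribers receive a FailureNotification.
import json
import tempfile
from datetime import timedelta

from torchft._torchft import LighthouseClient, LighthouseServer

# Write a flat JSON object for the lighthouse to serve via GetConfig (assumed format).
with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump({"lr": "0.01", "sync_every": "2"}, f)
    config_path = f.name

server = LighthouseServer(
    bind="[::]:0",
    min_replicas=1,
    failure_tick_ms=100,
    lighthouse_config=config_path,
)

client = LighthouseClient(server.address(), connect_timeout=timedelta(seconds=5))
print(client.get_config())  # expected: {"lr": "0.01", "sync_every": "2"}

server.inject_failure("replica_0")  # broadcast a synthetic failure for testing
server.shutdown()
```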

struct StatusError(Status);
Expand Down Expand Up @@ -750,6 +854,8 @@ fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<LighthouseServer>()?;
m.add_class::<LighthouseClient>()?;
m.add_class::<QuorumResult>()?;
m.add_class::<FailureNotification>()?;
m.add_class::<FailureStream>()?;
m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?;

Ok(())