diff --git a/Cargo.toml b/Cargo.toml
index e8d6b8a..c6fd390 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -6,12 +6,15 @@ edition = "2021"
 [dependencies]
 anyhow = "1.0.89"
 askama = "0.12.1"
+atty = "0.2.14"
 axum = "0.7.7"
+chrono = "0.4.40"
+fern = {version = "0.7.1", features = ["colored"]}
 gethostname = "0.5.0"
 log = "0.4.22"
 prost = "0.13.3"
 prost-types = "0.13.3"
-pyo3 = {version="0.22.3", features = ["extension-module"]}
+pyo3 = {version = "0.22.3", features = ["extension-module"]}
 rand = "0.8.5"
 slog = "2.7.0"
 slog-stdlog = "4.1.1"
diff --git a/src/lib.rs b/src/lib.rs
index 59fad6a..2d4de57 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -10,18 +10,22 @@ mod net;
 mod retry;
 mod timeout;
 
+use anyhow::Result;
+use atty::Stream;
 use core::time::Duration;
+use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use std::env;
 use std::sync::Arc;
-
-use anyhow::Result;
-use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
 use structopt::StructOpt;
 use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tonic::transport::Channel;
 use tonic::Status;
 
+use chrono::Local;
+use fern::colors::{Color, ColoredLevelConfig};
+use log::LevelFilter;
+
 pub mod torchftpb {
     tonic::include_proto!("torchft");
 }
@@ -397,20 +401,50 @@ impl From<Status> for StatusError {
     }
 }
 
+fn setup_logging() -> Result<(), Box<dyn std::error::Error>> {
+    // Check if stderr is a terminal
+    let is_terminal = atty::is(Stream::Stderr);
+    let colors = ColoredLevelConfig::new()
+        .error(Color::Red)
+        .warn(Color::Yellow)
+        .info(Color::Green)
+        .debug(Color::Blue)
+        .trace(Color::Magenta);
+    let level_filter = match env::var("RUST_LOG").as_deref() {
+        Ok("error") => LevelFilter::Error,
+        Ok("warn") => LevelFilter::Warn,
+        Ok("info") => LevelFilter::Info,
+        Ok("debug") => LevelFilter::Debug,
+        Ok("trace") => LevelFilter::Trace,
+        _ => LevelFilter::Info,
+    };
+    fern::Dispatch::new()
+        .format(move |out, message, record| {
+            let module_path = record.module_path().unwrap_or("");
+            // If stderr is a terminal, use colors when printing log level, otherwise use plain text
+            let level = if is_terminal {
+                colors.color(record.level()).to_string()
+            } else {
+                record.level().to_string()
+            };
+            out.finish(format_args!(
+                "{} [{}] [{}] - {}",
+                Local::now().format("%Y-%m-%dT%H:%M:%S%.3f"),
+                level,
+                module_path,
+                message
+            ))
+        })
+        .level(level_filter)
+        .chain(std::io::stderr())
+        .apply()?;
+    Ok(())
+}
+
 #[pymodule]
 fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
     // setup logging on import
-    let mut log = stderrlog::new();
-    log.verbosity(2)
-        .show_module_names(true)
-        .timestamp(stderrlog::Timestamp::Millisecond);
-
-    if env::var("CLICOLOR_FORCE").is_ok() {
-        log.color(stderrlog::ColorChoice::AlwaysAnsi);
-    }
-
-    log.init()
-        .map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
+    setup_logging().map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
 
     m.add_class::<Manager>()?;
     m.add_class::<ManagerClient>()?;
diff --git a/src/lighthouse.rs b/src/lighthouse.rs
index 05053a1..63c339a 100644
--- a/src/lighthouse.rs
+++ b/src/lighthouse.rs
@@ -8,6 +8,7 @@ use core::net::SocketAddr;
 use std::collections::HashMap;
 use std::collections::HashSet;
 use std::sync::Arc;
+
 use std::time::Duration;
 use std::time::{Instant, SystemTime};
 
@@ -61,6 +62,25 @@ pub struct Lighthouse {
     opt: LighthouseOpt,
     listener: Mutex<Option<tokio::net::TcpListener>>,
     local_addr: SocketAddr,
+    change_logger: ChangeLogger,
+}
+
+struct ChangeLogger {
+    last_reason: std::sync::Mutex<Option<String>>,
+}
+impl ChangeLogger {
+    fn new() -> Self {
+        ChangeLogger {
+            last_reason: std::sync::Mutex::new(None),
+        }
+    }
+    fn log_if_changed(&self, reason: &str) {
+        let mut last_reason = self.last_reason.lock().unwrap();
+        if last_reason.as_deref() != Some(reason) {
+            info!("Quorum status: {}", reason);
+            *last_reason = Some(reason.to_string());
+        }
+    }
 }
 
 #[derive(StructOpt, Debug)]
@@ -257,12 +277,13 @@ impl Lighthouse {
             opt: opt,
             local_addr: listener.local_addr()?,
             listener: Mutex::new(Some(listener)),
+            change_logger: ChangeLogger::new(),
         }))
     }
 
     fn _quorum_tick(self: Arc<Self>, state: &mut State) -> Result<()> {
         let (quorum_met, reason) = quorum_compute(Instant::now(), state, &self.opt);
-        info!("Next quorum status: {}", reason);
+        self.change_logger.log_if_changed(&reason);
 
         if quorum_met.is_some() {
             let participants = quorum_met.unwrap();
@@ -448,7 +469,10 @@ impl LighthouseService for Arc<Lighthouse> {
             .requester
             .ok_or_else(|| return Status::invalid_argument("missing requester"))?;
 
-        info!("got quorum request for replica {}", &requester.replica_id);
+        info!(
+            "Received quorum request for replica {}",
+            &requester.replica_id
+        );
 
         let mut rx = {
             let mut state = self.state.lock().await;
diff --git a/src/manager.rs b/src/manager.rs
index 931e995..bd14783 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -37,6 +37,21 @@ use log::{info, warn};
 #[cfg(test)]
 use std::{println as info, println as warn};
 
+// The replica_id string is of the form {replica_name}:{uuid} or just {uuid} (see torchft/manager.py)
+// We can parse the replica_id if it exists, otherwise we just use the uuid
+macro_rules! info_with_replica {
+    ($replica_id:expr, $($arg:tt)*) => {{
+        let parts: Vec<&str> = $replica_id.splitn(2, ':').collect();
+        let formatted_message = if parts.len() == 2 {
+            // If there are two parts, use the replica name
+            info!("[Replica {}] {}", parts[0], format!($($arg)*))
+        } else {
+            // Otherwise, just use the UUID
+            info!("[Replica {}] {}", $replica_id, format!($($arg)*))
+        };
+    }};
+}
+
 struct ManagerState {
     checkpoint_metadata: HashMap<i64, String>,
     channel: broadcast::Sender<Quorum>,
@@ -63,7 +78,10 @@ pub async fn manager_client_new(
     addr: String,
     connect_timeout: Duration,
 ) -> Result<ManagerServiceClient<Channel>> {
-    info!("ManagerClient: establishing connection to {}", &addr);
+    info!(
+        "Creating ManagerClient: establishing connection to {}",
+        &addr
+    );
     let conn = connect(addr, connect_timeout).await?;
     Ok(ManagerServiceClient::new(conn))
 }
@@ -72,7 +90,10 @@ pub async fn lighthouse_client_new(
     addr: String,
     connect_timeout: Duration,
 ) -> Result<LighthouseServiceClient<Channel>> {
-    info!("LighthouseClient: establishing connection to {}", &addr);
+    info!(
+        "Creating LighthouseClient: establishing connection to {}",
+        &addr
+    );
     let conn = connect(addr, connect_timeout).await?;
     Ok(LighthouseServiceClient::new(conn))
 }
@@ -135,11 +156,7 @@ impl Manager {
     }
 
     async fn _run_grpc(self: Arc<Self>) -> Result<()> {
-        info!(
-            "Manager {} listening on {}",
-            self.replica_id,
-            self.address()
-        );
+        info_with_replica!(self.replica_id, "Manager listening on {}", self.address());
 
         let listener = self.listener.lock().await.take().unwrap();
         let incoming =
@@ -176,7 +193,7 @@ impl Manager {
             }
             state.participants.clear();
 
-            info!("all workers joined -- starting quorum");
+            info_with_replica!(self.replica_id, "All workers joined - starting quorum");
 
             // TODO: don't hold the lock during quorum
 
@@ -197,7 +214,7 @@ impl Manager {
            })?;
             let resp = response.into_inner();
 
-            info!("got lighthouse quorum {:?}", resp);
+            info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);
 
             state
                 .channel
@@ -220,7 +237,7 @@ impl ManagerService for Arc<Manager> {
        let req = request.get_ref();
        let rank = req.rank;
 
-        info!("got quorum request for rank {}", rank);
+        info_with_replica!(self.replica_id, "Start quorum for rank {}", rank);
 
         let timeout = try_parse_grpc_timeout(&request.metadata())
             .map_err(|e| {
@@ -266,7 +283,7 @@ impl ManagerService for Arc<Manager> {
             .await
             .map_err(|e| Status::internal(e.to_string()))?;
 
-        info!("returning quorum for rank {}", rank);
+        info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);
 
         let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
 
@@ -299,9 +316,11 @@ impl ManagerService for Arc<Manager> {
         let req = request.into_inner();
         let rank = req.rank;
 
-        info!(
+        info_with_replica!(
+            self.replica_id,
             "should_commit request from {} should_commit={}",
-            rank, req.should_commit
+            rank,
+            req.should_commit
         );
 
         // TODO: check step count
@@ -318,7 +337,11 @@ impl ManagerService for Arc<Manager> {
 
         if state.should_commit_count.len() == self.world_size as usize {
             let decision = state.should_commit_failures.len() == 0;
-            info!("should_commit completed should_commit={}", decision);
+            info_with_replica!(
+                self.replica_id,
+                "should_commit completed should_commit={}",
+                decision
+            );
 
             state
                 .should_commit_channel
@@ -448,7 +471,8 @@ fn compute_quorum_results(
 
     let heal = recover_src_rank.is_some();
     if heal {
-        info!(
+        info_with_replica!(
+            replica_id,
             "healing is required step={}, max_step={}, recover_src_rank={}",
             step,
             max_step,
diff --git a/torchft/manager.py b/torchft/manager.py
index 668189c..0da48d0 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -189,9 +189,13 @@ def __init__(
             bind = f"[::]:{port}"
             lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
 
-            if replica_id is None:
-                replica_id = ""
-            replica_id = replica_id + str(uuid.uuid4())
+            # We need a unique identifier in the case that a worker restarts quickly and
+            # replaces the previous worker with the same ID.
+            new_uuid = str(uuid.uuid4())
+            if replica_id is None or replica_id == "":
+                replica_id = new_uuid
+            else:
+                replica_id = f"{replica_id}:{new_uuid}"
             self._manager = ManagerServer(
                 replica_id=replica_id,
                 lighthouse_addr=lighthouse_addr,