Improve log readability #122

Merged · 3 commits · Mar 7, 2025
5 changes: 4 additions & 1 deletion Cargo.toml
@@ -6,12 +6,15 @@ edition = "2021"
[dependencies]
anyhow = "1.0.89"
askama = "0.12.1"
atty = "0.2.14"
axum = "0.7.7"
chrono = "0.4.40"
fern = {version = "0.7.1", features = ["colored"]}
gethostname = "0.5.0"
log = "0.4.22"
prost = "0.13.3"
prost-types = "0.13.3"
pyo3 = {version="0.22.3", features = ["extension-module"]}
pyo3 = {version = "0.22.3", features = ["extension-module"]}
rand = "0.8.5"
slog = "2.7.0"
slog-stdlog = "4.1.1"
62 changes: 48 additions & 14 deletions src/lib.rs
@@ -10,18 +10,22 @@ mod net;
mod retry;
mod timeout;

use anyhow::Result;
use atty::Stream;
use core::time::Duration;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use std::env;
use std::sync::Arc;

use anyhow::Result;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tonic::transport::Channel;
use tonic::Status;

use chrono::Local;
use fern::colors::{Color, ColoredLevelConfig};
use log::LevelFilter;

pub mod torchftpb {
tonic::include_proto!("torchft");
}
@@ -397,20 +401,50 @@ impl From<Status> for StatusError {
}
}

fn setup_logging() -> Result<(), Box<dyn std::error::Error>> {
// Check if stderr is a terminal
let is_terminal = atty::is(Stream::Stderr);
let colors = ColoredLevelConfig::new()
.error(Color::Red)
.warn(Color::Yellow)
.info(Color::Green)
.debug(Color::Blue)
.trace(Color::Magenta);
let level_filter = match env::var("RUST_LOG").as_deref() {
Ok("error") => LevelFilter::Error,
Ok("warn") => LevelFilter::Warn,
Ok("info") => LevelFilter::Info,
Ok("debug") => LevelFilter::Debug,
Ok("trace") => LevelFilter::Trace,
_ => LevelFilter::Info,
};
fern::Dispatch::new()
.format(move |out, message, record| {
let module_path = record.module_path().unwrap_or("<unknown>");
// If stderr is a terminal, use colors when printing log level, otherwise use plain text
let level = if is_terminal {
colors.color(record.level()).to_string()
} else {
record.level().to_string()
};
out.finish(format_args!(
"{} [{}] [{}] - {}",
Local::now().format("%Y-%m-%dT%H:%M:%S%.3f"),
level,
module_path,
message
))
})
.level(level_filter)
.chain(std::io::stderr())
.apply()?;
Ok(())
}

#[pymodule]
fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
let mut log = stderrlog::new();
log.verbosity(2)
.show_module_names(true)
.timestamp(stderrlog::Timestamp::Millisecond);

if env::var("CLICOLOR_FORCE").is_ok() {
log.color(stderrlog::ColorChoice::AlwaysAnsi);
}

log.init()
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
setup_logging().map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

m.add_class::<ManagerServer>()?;
m.add_class::<ManagerClient>()?;
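
Because setup_logging() is invoked from the #[pymodule] initializer, RUST_LOG has to be in the environment before torchft is first imported; colors are applied only when atty reports that stderr is a terminal. A minimal standalone sketch of the level selection, mirroring the match in setup_logging (the helper name and the main() harness are illustrative, not part of this diff):

use log::LevelFilter;
use std::env;

// Recognized values are the lowercase level names; anything else (including an
// unset RUST_LOG) falls back to Info, matching setup_logging above.
fn level_from_env() -> LevelFilter {
    match env::var("RUST_LOG").as_deref() {
        Ok("error") => LevelFilter::Error,
        Ok("warn") => LevelFilter::Warn,
        Ok("info") => LevelFilter::Info,
        Ok("debug") => LevelFilter::Debug,
        Ok("trace") => LevelFilter::Trace,
        _ => LevelFilter::Info,
    }
}

fn main() {
    env::set_var("RUST_LOG", "debug");
    assert_eq!(level_from_env(), LevelFilter::Debug);
    env::remove_var("RUST_LOG");
    assert_eq!(level_from_env(), LevelFilter::Info); // default when unset or unrecognized
}
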
28 changes: 26 additions & 2 deletions src/lighthouse.rs
@@ -8,6 +8,7 @@ use core::net::SocketAddr;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;

use std::time::Duration;
use std::time::{Instant, SystemTime};

@@ -61,6 +62,25 @@ pub struct Lighthouse {
opt: LighthouseOpt,
listener: Mutex<Option<tokio::net::TcpListener>>,
local_addr: SocketAddr,
change_logger: ChangeLogger,
}

struct ChangeLogger {
last_reason: std::sync::Mutex<Option<String>>,
}
impl ChangeLogger {
fn new() -> Self {
ChangeLogger {
last_reason: std::sync::Mutex::new(None),
}
}
fn log_if_changed(&self, reason: &str) {
let mut last_reason = self.last_reason.lock().unwrap();
if last_reason.as_deref() != Some(reason) {
info!("Quorum status: {}", reason);
*last_reason = Some(reason.to_string());
}
}
}

#[derive(StructOpt, Debug)]
@@ -257,12 +277,13 @@ impl Lighthouse {
opt: opt,
local_addr: listener.local_addr()?,
listener: Mutex::new(Some(listener)),
change_logger: ChangeLogger::new(),
}))
}

fn _quorum_tick(self: Arc<Self>, state: &mut State) -> Result<()> {
let (quorum_met, reason) = quorum_compute(Instant::now(), state, &self.opt);
info!("Next quorum status: {}", reason);
self.change_logger.log_if_changed(&reason);

if quorum_met.is_some() {
let participants = quorum_met.unwrap();
@@ -448,7 +469,10 @@ impl LighthouseService for Arc<Lighthouse> {
.requester
.ok_or_else(|| return Status::invalid_argument("missing requester"))?;

info!("got quorum request for replica {}", &requester.replica_id);
info!(
"Received quorum request for replica {}",
&requester.replica_id
);

let mut rx = {
let mut state = self.state.lock().await;
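
With ChangeLogger, a quorum status line is emitted only when the reason string changes between quorum ticks rather than on every tick, which avoids repeating the same status line over and over. A standalone sketch of that behavior (println! stands in for info!, and the reason strings are made up):

use std::sync::Mutex;

struct ChangeLogger {
    last_reason: Mutex<Option<String>>,
}

impl ChangeLogger {
    fn new() -> Self {
        ChangeLogger {
            last_reason: Mutex::new(None),
        }
    }

    // Log the reason only when it differs from the previously logged one.
    fn log_if_changed(&self, reason: &str) {
        let mut last_reason = self.last_reason.lock().unwrap();
        if last_reason.as_deref() != Some(reason) {
            println!("Quorum status: {}", reason);
            *last_reason = Some(reason.to_string());
        }
    }
}

fn main() {
    let logger = ChangeLogger::new();
    logger.log_if_changed("waiting for more participants"); // printed
    logger.log_if_changed("waiting for more participants"); // suppressed: unchanged
    logger.log_if_changed("quorum met");                     // printed
}
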
54 changes: 39 additions & 15 deletions src/manager.rs
@@ -37,6 +37,21 @@ use log::{info, warn};
#[cfg(test)]
use std::{println as info, println as warn};

// The replica_id string is of the form {replica_name}:{uuid} or just {uuid} (see torchft/manager.py)
// We can parse the replica_id if it exists, otherwise we just use the uuid
macro_rules! info_with_replica {
($replica_id:expr, $($arg:tt)*) => {{
let parts: Vec<&str> = $replica_id.splitn(2, ':').collect();
let formatted_message = if parts.len() == 2 {
// If there are two parts, use the replica name
info!("[Replica {}] {}", parts[0], format!($($arg)*))
} else {
// Otherwise, just use the UUID
info!("[Replica {}] {}", $replica_id, format!($($arg)*))
};
}};
}

struct ManagerState {
checkpoint_metadata: HashMap<i64, String>,
channel: broadcast::Sender<Quorum>,
@@ -63,7 +78,10 @@ pub async fn manager_client_new(
addr: String,
connect_timeout: Duration,
) -> Result<ManagerServiceClient<Channel>> {
info!("ManagerClient: establishing connection to {}", &addr);
info!(
"Creating ManagerClient: establishing connection to {}",
&addr
);
let conn = connect(addr, connect_timeout).await?;
Ok(ManagerServiceClient::new(conn))
}
@@ -72,7 +90,10 @@ pub async fn lighthouse_client_new(
addr: String,
connect_timeout: Duration,
) -> Result<LighthouseServiceClient<Channel>> {
info!("LighthouseClient: establishing connection to {}", &addr);
info!(
"Creating LighthouseClient: establishing connection to {}",
&addr
);
let conn = connect(addr, connect_timeout).await?;
Ok(LighthouseServiceClient::new(conn))
}
@@ -135,11 +156,7 @@ impl Manager {
}

async fn _run_grpc(self: Arc<Self>) -> Result<()> {
info!(
"Manager {} listening on {}",
self.replica_id,
self.address()
);
info_with_replica!(self.replica_id, "Manager listening on {}", self.address());

let listener = self.listener.lock().await.take().unwrap();
let incoming =
@@ -176,7 +193,7 @@ impl Manager {
}

state.participants.clear();
info!("all workers joined -- starting quorum");
info_with_replica!(self.replica_id, "All workers joined - starting quorum");

// TODO: don't hold the lock during quorum

@@ -197,7 +214,7 @@
})?;
let resp = response.into_inner();

info!("got lighthouse quorum {:?}", resp);
info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);

state
.channel
@@ -220,7 +237,7 @@ impl ManagerService for Arc<Manager> {
let req = request.get_ref();
let rank = req.rank;

info!("got quorum request for rank {}", rank);
info_with_replica!(self.replica_id, "Start quorum for rank {}", rank);

let timeout = try_parse_grpc_timeout(&request.metadata())
.map_err(|e| {
@@ -266,7 +283,7 @@ impl ManagerService for Arc<Manager> {
.await
.map_err(|e| Status::internal(e.to_string()))?;

info!("returning quorum for rank {}", rank);
info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);

let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;

@@ -299,9 +316,11 @@ impl ManagerService for Arc<Manager> {
let req = request.into_inner();
let rank = req.rank;

info!(
info_with_replica!(
self.replica_id,
"should_commit request from {} should_commit={}",
rank, req.should_commit
rank,
req.should_commit
);

// TODO: check step count
@@ -318,7 +337,11 @@

if state.should_commit_count.len() == self.world_size as usize {
let decision = state.should_commit_failures.len() == 0;
info!("should_commit completed should_commit={}", decision);
info_with_replica!(
self.replica_id,
"should_commit completed should_commit={}",
decision
);

state
.should_commit_channel
@@ -448,7 +471,8 @@ fn compute_quorum_results(

let heal = recover_src_rank.is_some();
if heal {
info!(
info_with_replica!(
replica_id,
"healing is required step={}, max_step={}, recover_src_rank={}",
step,
max_step,
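
The info_with_replica! macro derives its log prefix by splitting replica_id on the first ':' and keeping the human-readable name when one is present; the torchft/manager.py change below is what produces IDs of the form {replica_name}:{uuid}. A standalone sketch of that extraction (the function name and the placeholder UUID/address are illustrative, not part of this diff):

// Mirrors the splitn(2, ':') logic inside info_with_replica!: a "{name}:{uuid}"
// ID is reduced to its name, while a bare UUID is used as-is.
fn replica_prefix(replica_id: &str) -> &str {
    let parts: Vec<&str> = replica_id.splitn(2, ':').collect();
    if parts.len() == 2 {
        parts[0]
    } else {
        replica_id
    }
}

fn main() {
    assert_eq!(replica_prefix("replica_0:1234-abcd"), "replica_0");
    assert_eq!(replica_prefix("1234-abcd"), "1234-abcd");

    // Roughly what the macro-produced lines look like:
    println!(
        "[Replica {}] Manager listening on {}",
        replica_prefix("replica_0:1234-abcd"),
        "[::]:29500"
    );
}
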
10 changes: 7 additions & 3 deletions torchft/manager.py
@@ -189,9 +189,13 @@ def __init__(
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

if replica_id is None:
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
Member Author:
Wasn't sure why the UUID was added to the replica_id; is this needed?

Contributor:
Yes, I was wondering the same thing. TorchTitan has to keep the original replica_id because it changes every time due to uuid4.

Member:
We do this as a simplification for the quorum algorithm: if a worker restarts quickly with the same name, we may not detect that the quorum has changed, and thus no reconfiguration of the PGs will occur. Adding a random UUID to the worker means we will always detect process restarts and correctly trigger a reconfiguration.

@H-Huang can you add this back and, if it's convenient, also add a comment explaining why this is the case? It would also be nice to add a ':' to it so replica_01234-1234 shows up as replica_0:1234-1234. I've been meaning to fix this but haven't had a chance.

# We need a unique identifier in the case that a worker restarts quickly and
# replaces the previous worker with the same ID.
new_uuid = str(uuid.uuid4())
if replica_id is None or replica_id == "":
replica_id = new_uuid
else:
replica_id = f"{replica_id}:{new_uuid}"
self._manager = ManagerServer(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,