Improve log readability (#122)
* Make logging easier to read

* lintrunner

* Add back uuid
H-Huang authored Mar 7, 2025
1 parent 2ab329e commit 8dd6a09
Showing 5 changed files with 124 additions and 35 deletions.
5 changes: 4 additions & 1 deletion Cargo.toml
@@ -6,12 +6,15 @@ edition = "2021"
[dependencies]
anyhow = "1.0.89"
askama = "0.12.1"
atty = "0.2.14"
axum = "0.7.7"
chrono = "0.4.40"
fern = {version = "0.7.1", features = ["colored"]}
gethostname = "0.5.0"
log = "0.4.22"
prost = "0.13.3"
prost-types = "0.13.3"
pyo3 = {version="0.22.3", features = ["extension-module"]}
pyo3 = {version = "0.22.3", features = ["extension-module"]}
rand = "0.8.5"
slog = "2.7.0"
slog-stdlog = "4.1.1"
62 changes: 48 additions & 14 deletions src/lib.rs
@@ -10,18 +10,22 @@ mod net;
mod retry;
mod timeout;

use anyhow::Result;
use atty::Stream;
use core::time::Duration;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use std::env;
use std::sync::Arc;

use anyhow::Result;
use pyo3::exceptions::{PyRuntimeError, PyTimeoutError};
use structopt::StructOpt;
use tokio::runtime::Runtime;
use tokio::task::JoinHandle;
use tonic::transport::Channel;
use tonic::Status;

use chrono::Local;
use fern::colors::{Color, ColoredLevelConfig};
use log::LevelFilter;

pub mod torchftpb {
tonic::include_proto!("torchft");
}
@@ -397,20 +401,50 @@ impl From<Status> for StatusError {
}
}

fn setup_logging() -> Result<(), Box<dyn std::error::Error>> {
// Check if stderr is a terminal
let is_terminal = atty::is(Stream::Stderr);
let colors = ColoredLevelConfig::new()
.error(Color::Red)
.warn(Color::Yellow)
.info(Color::Green)
.debug(Color::Blue)
.trace(Color::Magenta);
let level_filter = match env::var("RUST_LOG").as_deref() {
Ok("error") => LevelFilter::Error,
Ok("warn") => LevelFilter::Warn,
Ok("info") => LevelFilter::Info,
Ok("debug") => LevelFilter::Debug,
Ok("trace") => LevelFilter::Trace,
_ => LevelFilter::Info,
};
fern::Dispatch::new()
.format(move |out, message, record| {
let module_path = record.module_path().unwrap_or("<unknown>");
// If stderr is a terminal, use colors when printing log level, otherwise use plain text
let level = if is_terminal {
colors.color(record.level()).to_string()
} else {
record.level().to_string()
};
out.finish(format_args!(
"{} [{}] [{}] - {}",
Local::now().format("%Y-%m-%dT%H:%M:%S%.3f"),
level,
module_path,
message
))
})
.level(level_filter)
.chain(std::io::stderr())
.apply()?;
Ok(())
}

#[pymodule]
fn _torchft(m: &Bound<'_, PyModule>) -> PyResult<()> {
// setup logging on import
let mut log = stderrlog::new();
log.verbosity(2)
.show_module_names(true)
.timestamp(stderrlog::Timestamp::Millisecond);

if env::var("CLICOLOR_FORCE").is_ok() {
log.color(stderrlog::ColorChoice::AlwaysAnsi);
}

log.init()
.map_err(|e| PyRuntimeError::new_err(e.to_string()))?;
setup_logging().map_err(|e| PyRuntimeError::new_err(e.to_string()))?;

m.add_class::<ManagerServer>()?;
m.add_class::<ManagerClient>()?;
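Note on the new setup_logging above: it reads RUST_LOG once at import, matches the value exactly against error/warn/info/debug/trace, and falls back to Info otherwise; each line renders as "timestamp [LEVEL] [module::path] - message", with the level colored only when stderr is a terminal. A standalone sketch of the level selection, not part of this commit:

use log::LevelFilter;

// Mirrors the RUST_LOG match in setup_logging(): recognized lowercase names
// map to a LevelFilter, and anything else (including an unset variable) is Info.
fn level_from_rust_log(value: Option<&str>) -> LevelFilter {
    match value {
        Some("error") => LevelFilter::Error,
        Some("warn") => LevelFilter::Warn,
        Some("info") => LevelFilter::Info,
        Some("debug") => LevelFilter::Debug,
        Some("trace") => LevelFilter::Trace,
        _ => LevelFilter::Info,
    }
}

fn main() {
    assert_eq!(level_from_rust_log(Some("trace")), LevelFilter::Trace);
    assert_eq!(level_from_rust_log(Some("TRACE")), LevelFilter::Info); // matching is exact, so uppercase falls back
    assert_eq!(level_from_rust_log(None), LevelFilter::Info); // default when RUST_LOG is unset
}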
28 changes: 26 additions & 2 deletions src/lighthouse.rs
@@ -8,6 +8,7 @@ use core::net::SocketAddr;
use std::collections::HashMap;
use std::collections::HashSet;
use std::sync::Arc;

use std::time::Duration;
use std::time::{Instant, SystemTime};

@@ -61,6 +62,25 @@ pub struct Lighthouse {
opt: LighthouseOpt,
listener: Mutex<Option<tokio::net::TcpListener>>,
local_addr: SocketAddr,
change_logger: ChangeLogger,
}

struct ChangeLogger {
last_reason: std::sync::Mutex<Option<String>>,
}
impl ChangeLogger {
fn new() -> Self {
ChangeLogger {
last_reason: std::sync::Mutex::new(None),
}
}
fn log_if_changed(&self, reason: &str) {
let mut last_reason = self.last_reason.lock().unwrap();
if last_reason.as_deref() != Some(reason) {
info!("Quorum status: {}", reason);
*last_reason = Some(reason.to_string());
}
}
}

#[derive(StructOpt, Debug)]
@@ -257,12 +277,13 @@ impl Lighthouse {
opt: opt,
local_addr: listener.local_addr()?,
listener: Mutex::new(Some(listener)),
change_logger: ChangeLogger::new(),
}))
}

fn _quorum_tick(self: Arc<Self>, state: &mut State) -> Result<()> {
let (quorum_met, reason) = quorum_compute(Instant::now(), state, &self.opt);
info!("Next quorum status: {}", reason);
self.change_logger.log_if_changed(&reason);

if quorum_met.is_some() {
let participants = quorum_met.unwrap();
@@ -448,7 +469,10 @@ impl LighthouseService for Arc<Lighthouse> {
.requester
.ok_or_else(|| return Status::invalid_argument("missing requester"))?;

info!("got quorum request for replica {}", &requester.replica_id);
info!(
"Received quorum request for replica {}",
&requester.replica_id
);

let mut rx = {
let mut state = self.state.lock().await;
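The ChangeLogger added above exists because _quorum_tick is called repeatedly; logging the quorum status only when the reason string changes keeps the lighthouse log from repeating the same line. A self-contained sketch of the same pattern, with println! standing in for info!:

use std::sync::Mutex;

// Remember the last reason and only emit a line when it differs.
struct ChangeLogger {
    last_reason: Mutex<Option<String>>,
}

impl ChangeLogger {
    fn new() -> Self {
        ChangeLogger {
            last_reason: Mutex::new(None),
        }
    }

    fn log_if_changed(&self, reason: &str) {
        let mut last_reason = self.last_reason.lock().unwrap();
        if last_reason.as_deref() != Some(reason) {
            println!("Quorum status: {}", reason);
            *last_reason = Some(reason.to_string());
        }
    }
}

fn main() {
    let logger = ChangeLogger::new();
    logger.log_if_changed("waiting for participants"); // printed
    logger.log_if_changed("waiting for participants"); // suppressed: reason unchanged
    logger.log_if_changed("quorum met");               // printed: reason changed
}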
54 changes: 39 additions & 15 deletions src/manager.rs
@@ -37,6 +37,21 @@ use log::{info, warn};
#[cfg(test)]
use std::{println as info, println as warn};

// The replica_id string is of the form {replica_name}:{uuid} or just {uuid} (see torchft/manager.py)
// We can parse the replica_id if it exists, otherwise we just use the uuid
macro_rules! info_with_replica {
($replica_id:expr, $($arg:tt)*) => {{
let parts: Vec<&str> = $replica_id.splitn(2, ':').collect();
let formatted_message = if parts.len() == 2 {
// If there are two parts, use the replica name
info!("[Replica {}] {}", parts[0], format!($($arg)*))
} else {
// Otherwise, just use the UUID
info!("[Replica {}] {}", $replica_id, format!($($arg)*))
};
}};
}

struct ManagerState {
checkpoint_metadata: HashMap<i64, String>,
channel: broadcast::Sender<Quorum>,
@@ -63,7 +78,10 @@ pub async fn manager_client_new(
addr: String,
connect_timeout: Duration,
) -> Result<ManagerServiceClient<Channel>> {
info!("ManagerClient: establishing connection to {}", &addr);
info!(
"Creating ManagerClient: establishing connection to {}",
&addr
);
let conn = connect(addr, connect_timeout).await?;
Ok(ManagerServiceClient::new(conn))
}
Expand All @@ -72,7 +90,10 @@ pub async fn lighthouse_client_new(
addr: String,
connect_timeout: Duration,
) -> Result<LighthouseServiceClient<Channel>> {
info!("LighthouseClient: establishing connection to {}", &addr);
info!(
"Creating LighthouseClient: establishing connection to {}",
&addr
);
let conn = connect(addr, connect_timeout).await?;
Ok(LighthouseServiceClient::new(conn))
}
@@ -135,11 +156,7 @@ impl Manager {
}

async fn _run_grpc(self: Arc<Self>) -> Result<()> {
info!(
"Manager {} listening on {}",
self.replica_id,
self.address()
);
info_with_replica!(self.replica_id, "Manager listening on {}", self.address());

let listener = self.listener.lock().await.take().unwrap();
let incoming =
@@ -176,7 +193,7 @@ impl Manager {
}

state.participants.clear();
info!("all workers joined -- starting quorum");
info_with_replica!(self.replica_id, "All workers joined - starting quorum");

// TODO: don't hold the lock during quorum

@@ -197,7 +214,7 @@
})?;
let resp = response.into_inner();

info!("got lighthouse quorum {:?}", resp);
info_with_replica!(self.replica_id, "got lighthouse quorum {:?}", resp);

state
.channel
@@ -220,7 +237,7 @@ impl ManagerService for Arc<Manager> {
let req = request.get_ref();
let rank = req.rank;

info!("got quorum request for rank {}", rank);
info_with_replica!(self.replica_id, "Start quorum for rank {}", rank);

let timeout = try_parse_grpc_timeout(&request.metadata())
.map_err(|e| {
@@ -266,7 +283,7 @@ impl ManagerService for Arc<Manager> {
.await
.map_err(|e| Status::internal(e.to_string()))?;

info!("returning quorum for rank {}", rank);
info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);

let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;

@@ -299,9 +316,11 @@ impl ManagerService for Arc<Manager> {
let req = request.into_inner();
let rank = req.rank;

info!(
info_with_replica!(
self.replica_id,
"should_commit request from {} should_commit={}",
rank, req.should_commit
rank,
req.should_commit
);

// TODO: check step count
@@ -318,7 +337,11 @@

if state.should_commit_count.len() == self.world_size as usize {
let decision = state.should_commit_failures.len() == 0;
info!("should_commit completed should_commit={}", decision);
info_with_replica!(
self.replica_id,
"should_commit completed should_commit={}",
decision
);

state
.should_commit_channel
@@ -448,7 +471,8 @@ fn compute_quorum_results(

let heal = recover_src_rank.is_some();
if heal {
info!(
info_with_replica!(
replica_id,
"healing is required step={}, max_step={}, recover_src_rank={}",
step,
max_step,
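The info_with_replica! macro above gives every manager log line a short replica prefix: for an ID of the form {replica_name}:{uuid} it keeps only the name, and for a bare UUID it keeps the whole string. A small sketch of that parsing with made-up IDs (the replica names and UUID fragments below are hypothetical):

// Same splitn(2, ':') logic the macro uses to choose the "[Replica ...]" prefix.
fn replica_prefix(replica_id: &str) -> &str {
    let parts: Vec<&str> = replica_id.splitn(2, ':').collect();
    if parts.len() == 2 {
        parts[0] // "{replica_name}:{uuid}" -> keep the readable name
    } else {
        replica_id // bare uuid -> keep it as-is
    }
}

fn main() {
    // Named replica: logged as "[Replica train_ddp] ..."
    assert_eq!(replica_prefix("train_ddp:550e8400-e29b-41d4"), "train_ddp");
    // Bare UUID: logged with the full ID.
    assert_eq!(replica_prefix("550e8400-e29b-41d4"), "550e8400-e29b-41d4");
    println!("[Replica {}] Start quorum for rank 0", replica_prefix("train_ddp:550e8400-e29b-41d4"));
}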
10 changes: 7 additions & 3 deletions torchft/manager.py
@@ -189,9 +189,13 @@ def __init__(
bind = f"[::]:{port}"
lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

if replica_id is None:
replica_id = ""
replica_id = replica_id + str(uuid.uuid4())
# We need a unique identifier in the case that a worker restarts quickly and
# replaces the previous worker with the same ID.
new_uuid = str(uuid.uuid4())
if replica_id is None or replica_id == "":
replica_id = new_uuid
else:
replica_id = f"{replica_id}:{new_uuid}"
self._manager = ManagerServer(
replica_id=replica_id,
lighthouse_addr=lighthouse_addr,
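With this Python-side change, the manager always appends a fresh UUID to the replica ID: passing replica_id="train_ddp" (a hypothetical name) yields something like "train_ddp:550e8400-e29b-41d4-a716-446655440000", while passing None or "" yields just the bare UUID. The info_with_replica! macro on the Rust side then strips the UUID suffix again, so the logs show only the readable name when one was given.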
