manager: Add option to skip initial sync #159

Merged · 6 commits · Apr 9, 2025
1 change: 1 addition & 0 deletions proto/torchft.proto
@@ -76,6 +76,7 @@ message ManagerQuorumRequest {
int64 step = 2;
string checkpoint_metadata = 3;
bool shrink_only = 4;
bool init_sync = 5;
}

message ManagerQuorumResponse {
2 changes: 2 additions & 0 deletions src/lib.rs
@@ -175,6 +175,7 @@ impl ManagerClient {
step: i64,
checkpoint_metadata: String,
shrink_only: bool,
init_sync: bool,
timeout: Duration,
) -> Result<QuorumResult, StatusError> {
py.allow_threads(move || {
@@ -183,6 +184,7 @@ impl ManagerClient {
step: step,
checkpoint_metadata: checkpoint_metadata,
shrink_only: shrink_only,
init_sync: init_sync,
});

// This timeout is processed on the server side so we also enable
74 changes: 63 additions & 11 deletions src/manager.rs
@@ -286,7 +286,7 @@ impl ManagerService for Arc<Manager> {

info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);

let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?;

Ok(Response::new(reply))
}
@@ -382,6 +382,7 @@ fn compute_quorum_results(
replica_id: &str,
rank: i64,
quorum: &Quorum,
init_sync: bool,
) -> Result<ManagerQuorumResponse, Status> {
let mut participants = quorum.participants.clone();
participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
@@ -424,20 +425,23 @@

// Compute recovery assignments

// Nodes are recovering if:
// 1. not at the max step
let force_recover = init_sync && max_step == 0;

// Nodes are recovering if
// 1. not at the max step (init_sync)
// 2. max_step == 0 and not the primary replica
let all_recover_dst_ranks: Vec<usize> = participants
.iter()
.enumerate()
.filter_map(|(i, p)| {
if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
if p.step != max_step || force_recover && primary.replica_id != p.replica_id {
Some(i)
} else {
None
}
})
.collect();

let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
let up_to_date_ranks: Vec<usize> = participants
.iter()
@@ -605,6 +609,7 @@ mod tests {
step: 123,
checkpoint_metadata: "addr".to_string(),
shrink_only: false,
init_sync: true,
});
request.set_timeout(Duration::from_secs(10));
let resp = client.quorum(request).await?.into_inner();
@@ -664,6 +669,7 @@ mod tests {
step: 0,
checkpoint_metadata: "addr".to_string(),
shrink_only: false,
init_sync: true,
});
request.set_timeout(Duration::from_secs(10));

@@ -771,21 +777,21 @@ mod tests {

// rank 0

let results = compute_quorum_results("replica_0", 0, &quorum)?;
let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 0);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![1]);

let results = compute_quorum_results("replica_1", 0, &quorum)?;
let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
assert!(results.heal);
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, Some(0));
assert_eq!(results.recover_dst_ranks, Vec::<i64>::new());

// rank 1 assignments should be offset from rank 0 above and the primary

let results = compute_quorum_results("replica_1", 1, &quorum)?;
let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, None);
@@ -850,34 +856,80 @@ mod tests {

// rank 0

let results = compute_quorum_results("replica_0", 0, &quorum)?;
let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
assert!(results.heal);
assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
assert_eq!(results.replica_rank, 0);
assert_eq!(results.recover_src_rank, Some(1));
assert!(results.recover_dst_ranks.is_empty());

let results = compute_quorum_results("replica_1", 0, &quorum)?;
let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.recover_src_manager_address, "".to_string());
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![0, 4]);

let results = compute_quorum_results("replica_3", 0, &quorum)?;
let results = compute_quorum_results("replica_3", 0, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 3);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![2]);

// rank 1 assignments should be offset from rank 0 above

let results = compute_quorum_results("replica_1", 1, &quorum)?;
let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
assert!(!results.heal);
assert_eq!(results.replica_rank, 1);
assert_eq!(results.recover_src_rank, None);
assert_eq!(results.recover_dst_ranks, vec![2]);

Ok(())
}

#[tokio::test]
async fn test_compute_quorum_results_skip_init_sync() -> Result<()> {
let mut quorum = Quorum {
quorum_id: 1,
participants: vec![
QuorumMember {
replica_id: "replica_0".to_string(),
address: "addr_0".to_string(),
store_address: "store_addr_0".to_string(),
step: 0,
world_size: 1,
shrink_only: false,
data: String::new(),
},
QuorumMember {
replica_id: "replica_1".to_string(),
address: "addr_1".to_string(),
store_address: "store_addr_1".to_string(),
step: 0,
world_size: 1,
shrink_only: false,
data: String::new(),
},
],
created: None,
};

// baseline w/ init_sync=true
let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
assert!(!results.heal);

let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
assert!(results.heal);

// init_sync=false
let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
assert!(!results.heal);

// init_sync=false, step=1
quorum.participants[0].step = 1;
let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
assert!(results.heal);

Ok(())
}
}
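
As a reading aid, here is a simplified Python sketch of the per-participant recovery rule implemented in compute_quorum_results above. It is not the actual implementation (that lives in src/manager.rs); the helper name needs_recovery and the is_primary flag are illustrative only.

# Simplified sketch of the recovery decision in compute_quorum_results.
# needs_recovery and is_primary are illustrative names, not real torchft APIs.
def needs_recovery(step: int, max_step: int, is_primary: bool, init_sync: bool) -> bool:
    # Only force a step-0 sync when init_sync is requested.
    force_recover = init_sync and max_step == 0
    # A participant recovers if it is behind the max step, or if a step-0
    # sync is forced and it is not the primary (the primary acts as source).
    return step != max_step or (force_recover and not is_primary)

# Mirrors the behavior exercised by test_compute_quorum_results_skip_init_sync:
assert needs_recovery(0, 0, is_primary=False, init_sync=True)       # heals at step 0
assert not needs_recovery(0, 0, is_primary=False, init_sync=False)  # init sync skipped
assert needs_recovery(0, 1, is_primary=False, init_sync=False)      # behind max step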
1 change: 1 addition & 0 deletions torchft/_torchft.pyi
@@ -11,6 +11,7 @@ class ManagerClient:
checkpoint_metadata: str,
shrink_only: bool,
timeout: timedelta,
init_sync: bool = True,
) -> QuorumResult: ...
def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def should_commit(
6 changes: 6 additions & 0 deletions torchft/manager.py
@@ -106,6 +106,7 @@ def __init__(
hostname: str = socket.gethostname(),
heartbeat_interval: timedelta = timedelta(milliseconds=100),
checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
init_sync: bool = True,
) -> None:
"""
Args:
@@ -143,6 +144,9 @@ def __init__(
hostname: if rank==0, the hostname to advertise to the lighthouse server
checkpoint_transport: the checkpoint transport to use for
transferring checkpoints to recovering replicas, defaults to HTTPTransport
init_sync: whether to synchronize the model weights on step 0. If
all of the model weights are initialized identically via
``torch.manual_seed``, you should set this to False.
"""
self._load_state_dict = load_state_dict
self._user_state_dict = state_dict
@@ -152,6 +156,7 @@ def __init__(
self._quorum_timeout = quorum_timeout
self._connect_timeout = connect_timeout
self._world_size_mode = world_size_mode
self._init_sync = init_sync

store_addr = store_addr or os.environ["MASTER_ADDR"]
store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -455,6 +460,7 @@ def _async_quorum(
checkpoint_metadata=self._checkpoint_transport.metadata(),
shrink_only=shrink_only,
timeout=quorum_timeout,
init_sync=self._init_sync,
)

quorum_id = quorum.quorum_id
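
For completeness, a minimal usage sketch of the new flag follows. It assumes every replica seeds its RNG identically before building the model; the ProcessGroupGloo import and the exact set of Manager constructor arguments shown are assumptions, abbreviated rather than a complete setup (the usual replica environment such as MASTER_ADDR/MASTER_PORT is also assumed).

# Hedged usage sketch: skip the step-0 weight transfer when every replica
# builds identical weights from the same seed. The process group choice and
# the Manager arguments other than init_sync are assumptions for illustration;
# MASTER_ADDR/MASTER_PORT are expected to be set in the environment.
import torch
import torch.nn as nn
from torchft.manager import Manager
from torchft.process_group import ProcessGroupGloo  # assumed torchft PG wrapper

torch.manual_seed(42)            # identical initialization on every replica
model = nn.Linear(3, 4)

manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=model.load_state_dict,
    state_dict=model.state_dict,
    min_replica_size=2,
    init_sync=False,             # new: skip the weight sync at step 0
)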
69 changes: 66 additions & 3 deletions torchft/manager_integ_test.py
@@ -25,6 +25,8 @@

logger: logging.Logger = logging.getLogger(__name__)

INIT_LOCK: threading.Lock = threading.Lock()


class MyModel(nn.Module):
def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
@@ -191,7 +193,13 @@ def state_dict() -> Dict[str, Dict[str, object]]:
)
stack.callback(lambda: manager.shutdown(wait=False))

m: nn.Module = DistributedDataParallel(manager, MyModel())
with INIT_LOCK:
# We need to lock during init for testing init_sync=False as all
# threads share the same RNG
torch.manual_seed(42)
m: nn.Module = MyModel()

m: nn.Module = DistributedDataParallel(manager, m)
optimizer: optim.Optimizer = OptimizerWrapper(
manager, optim.Adam(m.parameters())
)
@@ -270,7 +278,11 @@ def test_ddp_healthy(self) -> None:
),
]
)
def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
def test_ddp_recovery(
self,
name: str,
use_async_quorum: bool,
) -> None:
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
@@ -302,7 +314,11 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
state_dicts = []

for fut in as_completed(futures):
state_dicts.append(fut.result())
try:
state_dicts.append(fut.result())
except Exception as e:
print(e)
raise

lighthouse.shutdown()

@@ -311,6 +327,53 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:

self.assertEqual(failure_injectors[1].count, 1)

def test_ddp_skip_init_sync(
self,
) -> None:
lighthouse = LighthouseServer(
bind="[::]:0",
min_replicas=2,
)
num_replicas = 2
futures = []

# no failures
failure_injectors = [
FailureInjector(),
FailureInjector(),
]

with ThreadPoolExecutor(max_workers=num_replicas) as executor:
for replica_id, failure_injector in zip(
range(num_replicas), failure_injectors
):
runner = Runner(
replica_id=replica_id,
num_replicas=num_replicas,
lighthouse_address=lighthouse.address(),
failure_injector=failure_injector,
manager_args={
"use_async_quorum": False,
"init_sync": False,
},
train_loop=ddp_train_loop,
)
futures.append(executor.submit(runner.run_replica))

state_dicts = []

for fut in as_completed(futures):
try:
state_dicts.append(fut.result())
except Exception as e:
print(e)
raise

lighthouse.shutdown()

for state_dict in state_dicts:
torch.testing.assert_close(state_dict, state_dicts[0])

def test_ddp_recovery_multi_rank(self) -> None:
lighthouse = LighthouseServer(
bind="[::]:0",
31 changes: 31 additions & 0 deletions torchft/manager_test.py
@@ -40,6 +40,7 @@ def _create_manager(
min_replica_size: int = 2,
world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
timeout: timedelta = timedelta(seconds=10),
init_sync: bool = True,
) -> Manager:
pg = create_autospec(ProcessGroup)
pg.errored.return_value = None
@@ -67,6 +68,7 @@ def _create_manager(
use_async_quorum=use_async_quorum,
world_size_mode=world_size_mode,
timeout=timeout,
init_sync=init_sync,
)
self.manager = manager
return manager
@@ -614,3 +616,32 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
client_mock().should_commit.call_args.kwargs["timeout"],
timedelta(seconds=23),
)

@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
manager = self._create_manager(
use_async_quorum=False,
init_sync=False,
)

self.assertFalse(manager._init_sync)

quorum = QuorumResult()
quorum.quorum_id = 123
quorum.replica_rank = 1
quorum.replica_world_size = 2
quorum.recover_src_manager_address = "manager address"
quorum.store_address = f"localhost:{self.store.port}"
quorum.max_step = 1
quorum.max_rank = 1
quorum.max_world_size = 2
quorum.heal = False

client_mock()._quorum.return_value = quorum

manager.start_quorum()
self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False)

manager._init_sync = True
manager.start_quorum()
self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)