From 9ef0785e56f3299b0a65c4034cbe18cacbc4b863 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Sat, 8 Mar 2025 17:03:57 -0800 Subject: [PATCH 1/6] add option to skip initial sync in Manager Summary: We currently always heal on step 0 to avoid synchronization issues. We want an option to support skipping this sync for users who set the PyTorch seed so all ranks are initialized with the same values. This diff adds an init_sync boolean flag that can be passed from the manager client in Python to the manager service in Rust. The manager service then decides whether to skip the sync based on the value of init_sync. Test Plan: Reviewers: Subscribers: Tasks: Tags: --- proto/torchft.proto | 1 + src/manager.rs | 5 ++++- torchft/_torchft.pyi | 1 + torchft/manager.py | 3 ++- 4 files changed, 8 insertions(+), 2 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index c4d0a81..7ffcaec 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -76,6 +76,7 @@ message ManagerQuorumRequest { int64 step = 2; string checkpoint_metadata = 3; bool shrink_only = 4; + optional bool init_sync = 5; } message ManagerQuorumResponse { diff --git a/src/manager.rs b/src/manager.rs index 08d0cc2..bc9654f 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -431,7 +431,10 @@ fn compute_quorum_results( .iter() .enumerate() .filter_map(|(i, p)| { - if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { + if init_sync + || p.step != max_step + || max_step == 0 && primary.replica_id != p.replica_id + { Some(i) } else { None diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 49bdcdd..1a99913 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,6 +11,7 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, + init_sync: Optional[bool] = False, ) -> QuorumResult: ... def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
def should_commit( diff --git a/torchft/manager.py b/torchft/manager.py index 0697bd4..426fb35 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -34,7 +34,7 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum -from typing import TYPE_CHECKING, Callable, Dict, List, Optional, TypeVar, cast +from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar import torch from torch.distributed import ReduceOp, TCPStore @@ -455,6 +455,7 @@ def _async_quorum( checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, timeout=quorum_timeout, + init_sync=self.init_sync, ) quorum_id = quorum.quorum_id From e768c0ab0ae7afe8ec45de1903ea01e62c8d4f18 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Sat, 8 Mar 2025 22:14:10 -0700 Subject: [PATCH 2/6] Fixed some rust files and tests Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- src/lib.rs | 1 + src/manager.rs | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index a29e00d..00254e2 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -183,6 +183,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, + init_sync: Some(false), }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index bc9654f..48db5a9 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -286,7 +286,12 @@ impl ManagerService for Arc { info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank); - let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?; + let reply = compute_quorum_results( + &self.replica_id, + rank, + &quorum, + req.init_sync.unwrap_or_default(), + )?; Ok(Response::new(reply)) } @@ -382,6 +387,7 @@ fn compute_quorum_results( replica_id: &str, rank: i64, quorum: &Quorum, + init_sync: bool, ) -> Result { let mut participants = quorum.participants.clone(); participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); @@ -608,6 +614,7 @@ mod tests { step: 123, checkpoint_metadata: "addr".to_string(), shrink_only: false, + init_sync: Some(false), }); request.set_timeout(Duration::from_secs(10)); let resp = client.quorum(request).await?.into_inner(); @@ -667,6 +674,7 @@ mod tests { step: 0, checkpoint_metadata: "addr".to_string(), shrink_only: false, + init_sync: Some(false), }); request.set_timeout(Duration::from_secs(10)); @@ -774,13 +782,13 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum)?; + let results = compute_quorum_results("replica_0", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![1]); - let results = compute_quorum_results("replica_1", 0, &quorum)?; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, Some(0)); @@ -788,7 +796,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above and the primary - let results = compute_quorum_results("replica_1", 1, &quorum)?; + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); @@ -853,21 +861,21 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum)?; + let results = 
compute_quorum_results("replica_0", 0, &quorum, false)?; assert!(results.heal); assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, Some(1)); assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum)?; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.recover_src_manager_address, "".to_string()); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![0, 4]); - let results = compute_quorum_results("replica_3", 0, &quorum)?; + let results = compute_quorum_results("replica_3", 0, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 3); assert_eq!(results.recover_src_rank, None); @@ -875,7 +883,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above - let results = compute_quorum_results("replica_1", 1, &quorum)?; + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); From f7794a2749a5f0882310c079a66289986ad1321d Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Mon, 10 Mar 2025 07:42:45 -0700 Subject: [PATCH 3/6] Fix init_sync logic Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- proto/torchft.proto | 2 +- src/lib.rs | 2 +- src/manager.rs | 61 +++++++++++++++++++++----------------------- torchft/_torchft.pyi | 2 +- 4 files changed, 32 insertions(+), 35 deletions(-) diff --git a/proto/torchft.proto b/proto/torchft.proto index 7ffcaec..15a96d0 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -76,7 +76,7 @@ message ManagerQuorumRequest { int64 step = 2; string checkpoint_metadata = 3; bool shrink_only = 4; - optional bool init_sync = 5; + bool init_sync = 5; } message ManagerQuorumResponse { diff --git a/src/lib.rs b/src/lib.rs index 00254e2..68c81bf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -183,7 +183,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, - init_sync: Some(false), + init_sync: true, }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index 48db5a9..7b31794 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -286,12 +286,7 @@ impl ManagerService for Arc { info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank); - let reply = compute_quorum_results( - &self.replica_id, - rank, - &quorum, - req.init_sync.unwrap_or_default(), - )?; + let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?; Ok(Response::new(reply)) } @@ -430,23 +425,25 @@ fn compute_quorum_results( // Compute recovery assignments - // Nodes are recovering if: - // 1. not at the max step - // 2. max_step == 0 and not the primary replica - let all_recover_dst_ranks: Vec = participants - .iter() - .enumerate() - .filter_map(|(i, p)| { - if init_sync - || p.step != max_step - || max_step == 0 && primary.replica_id != p.replica_id - { - Some(i) - } else { - None - } - }) - .collect(); + let all_recover_dst_ranks = if init_sync { + // Nodes are recovering if + // 1. not at the max step + // 2. 
max_step == 0 and not the primary replica + participants + .iter() + .enumerate() + .filter_map(|(i, p)| { + if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { + Some(i) + } else { + None + } + }) + .collect() + } else { + Vec::::new() + }; + let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::>(); let up_to_date_ranks: Vec = participants .iter() @@ -614,7 +611,7 @@ mod tests { step: 123, checkpoint_metadata: "addr".to_string(), shrink_only: false, - init_sync: Some(false), + init_sync: true, }); request.set_timeout(Duration::from_secs(10)); let resp = client.quorum(request).await?.into_inner(); @@ -674,7 +671,7 @@ mod tests { step: 0, checkpoint_metadata: "addr".to_string(), shrink_only: false, - init_sync: Some(false), + init_sync: true, }); request.set_timeout(Duration::from_secs(10)); @@ -782,13 +779,13 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![1]); - let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + let results = compute_quorum_results("replica_1", 0, &quorum, true)?; assert!(results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, Some(0)); @@ -796,7 +793,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above and the primary - let results = compute_quorum_results("replica_1", 1, &quorum, false)?; + let results = compute_quorum_results("replica_1", 1, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); @@ -861,21 +858,21 @@ mod tests { // rank 0 - let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(results.heal); assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); assert_eq!(results.replica_rank, 0); assert_eq!(results.recover_src_rank, Some(1)); assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + let results = compute_quorum_results("replica_1", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.recover_src_manager_address, "".to_string()); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); assert_eq!(results.recover_dst_ranks, vec![0, 4]); - let results = compute_quorum_results("replica_3", 0, &quorum, false)?; + let results = compute_quorum_results("replica_3", 0, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 3); assert_eq!(results.recover_src_rank, None); @@ -883,7 +880,7 @@ mod tests { // rank 1 assignments should be offset from rank 0 above - let results = compute_quorum_results("replica_1", 1, &quorum, false)?; + let results = compute_quorum_results("replica_1", 1, &quorum, true)?; assert!(!results.heal); assert_eq!(results.replica_rank, 1); assert_eq!(results.recover_src_rank, None); diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi index 1a99913..fdbd1fa 100644 --- a/torchft/_torchft.pyi +++ b/torchft/_torchft.pyi @@ -11,7 +11,7 @@ class ManagerClient: checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - init_sync: Optional[bool] = False, + init_sync: bool = True, ) -> QuorumResult: ... 
def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... def should_commit( From ab05ae7372f97235d7b6c240f887f35c8bf48b3a Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Tue, 11 Mar 2025 09:23:21 -0700 Subject: [PATCH 4/6] Add tests for manager.rs Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- src/manager.rs | 82 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 82 insertions(+) diff --git a/src/manager.rs b/src/manager.rs index 7b31794..519e7ca 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -888,4 +888,86 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_compute_quorum_results_skip_init_sync() -> Result<()> { + let quorum = Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica_0".to_string(), + address: "addr_0".to_string(), + store_address: "store_addr_0".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_1".to_string(), + address: "addr_1".to_string(), + store_address: "store_addr_1".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_2".to_string(), + address: "addr_2".to_string(), + store_address: "store_addr_2".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_3".to_string(), + address: "addr_3".to_string(), + store_address: "store_addr_3".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_4".to_string(), + address: "addr_4".to_string(), + store_address: "store_addr_4".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + ], + created: None, + }; + + // rank 0 + + let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 0); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_3", 0, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 3); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_1", 1, &quorum, false)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert!(results.recover_dst_ranks.is_empty()); + + Ok(()) + } } From c5940880a589e4b6bad22b6d2c278cf1a59d9eb5 Mon Sep 17 00:00:00 2001 From: Dave Lei Date: Wed, 12 Mar 2025 11:48:37 -0700 Subject: [PATCH 5/6] Added skip init_sync tests to python client Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags: --- torchft/manager_test.py | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/torchft/manager_test.py b/torchft/manager_test.py index fb13496..954b88a 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -614,3 +614,35 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: 
client_mock().should_commit.call_args.kwargs["timeout"], timedelta(seconds=23), ) + + @patch("torchft.manager.ManagerClient", autospec=True) + def test_quorum_skip_init(self, client_mock: MagicMock) -> None: + manager = self._create_manager(use_async_quorum=False) + + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock()._quorum.return_value = quorum + + manager.start_quorum() + self.assertEqual( + client_mock()._quorum.call_args.kwargs["init_sync"], True + ) + + manager.start_quorum(init_sync=True) + self.assertEqual( + client_mock()._quorum.call_args.kwargs["init_sync"], True + ) + + manager.start_quorum(init_sync=False) + self.assertEqual( + client_mock()._quorum.call_args.kwargs["init_sync"], False + ) From e06c6f278263604213246b255d52ea1b9e2deba0 Mon Sep 17 00:00:00 2001 From: Tristan Rice Date: Tue, 8 Apr 2025 14:31:08 -0700 Subject: [PATCH 6/6] manager: final changes for init_sync --- src/lib.rs | 3 +- src/manager.rs | 96 +++++++++++------------------------ torchft/manager.py | 9 +++- torchft/manager_integ_test.py | 69 +++++++++++++++++++++++-- torchft/manager_test.py | 25 +++++---- 5 files changed, 116 insertions(+), 86 deletions(-) diff --git a/src/lib.rs b/src/lib.rs index 68c81bf..5ef1bcf 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -175,6 +175,7 @@ impl ManagerClient { step: i64, checkpoint_metadata: String, shrink_only: bool, + init_sync: bool, timeout: Duration, ) -> Result { py.allow_threads(move || { @@ -183,7 +184,7 @@ impl ManagerClient { step: step, checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, - init_sync: true, + init_sync: init_sync, }); // This timeout is processed on the server side so we also enable diff --git a/src/manager.rs b/src/manager.rs index 519e7ca..358ff7a 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -425,24 +425,22 @@ fn compute_quorum_results( // Compute recovery assignments - let all_recover_dst_ranks = if init_sync { - // Nodes are recovering if - // 1. not at the max step - // 2. max_step == 0 and not the primary replica - participants - .iter() - .enumerate() - .filter_map(|(i, p)| { - if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { - Some(i) - } else { - None - } - }) - .collect() - } else { - Vec::::new() - }; + let force_recover = init_sync && max_step == 0; + + // Nodes are recovering if + // 1. not at the max step (init_sync) + // 2. 
max_step == 0 and not the primary replica + let all_recover_dst_ranks: Vec = participants + .iter() + .enumerate() + .filter_map(|(i, p)| { + if p.step != max_step || force_recover && primary.replica_id != p.replica_id { + Some(i) + } else { + None + } + }) + .collect(); let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::>(); let up_to_date_ranks: Vec = participants @@ -891,7 +889,7 @@ mod tests { #[tokio::test] async fn test_compute_quorum_results_skip_init_sync() -> Result<()> { - let quorum = Quorum { + let mut quorum = Quorum { quorum_id: 1, participants: vec![ QuorumMember { @@ -901,72 +899,36 @@ mod tests { step: 0, world_size: 1, shrink_only: false, + data: String::new(), }, QuorumMember { replica_id: "replica_1".to_string(), address: "addr_1".to_string(), store_address: "store_addr_1".to_string(), - step: 1, - world_size: 1, - shrink_only: false, - }, - QuorumMember { - replica_id: "replica_2".to_string(), - address: "addr_2".to_string(), - store_address: "store_addr_2".to_string(), - step: 0, - world_size: 1, - shrink_only: false, - }, - QuorumMember { - replica_id: "replica_3".to_string(), - address: "addr_3".to_string(), - store_address: "store_addr_3".to_string(), - step: 1, - world_size: 1, - shrink_only: false, - }, - QuorumMember { - replica_id: "replica_4".to_string(), - address: "addr_4".to_string(), - store_address: "store_addr_4".to_string(), step: 0, world_size: 1, shrink_only: false, + data: String::new(), }, ], created: None, }; - // rank 0 - - let results = compute_quorum_results("replica_0", 0, &quorum, false)?; + // baseline w/ init_sync=true + let results = compute_quorum_results("replica_0", 0, &quorum, true)?; assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 0); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 0, &quorum, false)?; - assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); + let results = compute_quorum_results("replica_1", 0, &quorum, true)?; + assert!(results.heal); - let results = compute_quorum_results("replica_3", 0, &quorum, false)?; + // init_sync=false + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 3); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); - let results = compute_quorum_results("replica_1", 1, &quorum, false)?; - assert!(!results.heal); - assert_eq!(results.recover_src_manager_address, "".to_string()); - assert_eq!(results.replica_rank, 1); - assert_eq!(results.recover_src_rank, None); - assert!(results.recover_dst_ranks.is_empty()); + // init_sync=false, step=1 + quorum.participants[0].step = 1; + let results = compute_quorum_results("replica_1", 0, &quorum, false)?; + assert!(results.heal); Ok(()) } diff --git a/torchft/manager.py b/torchft/manager.py index 426fb35..fa2760d 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -34,7 +34,7 @@ from contextlib import nullcontext from datetime import timedelta from enum import Enum -from typing import Callable, cast, Dict, List, Optional, TYPE_CHECKING, TypeVar +from typing import TYPE_CHECKING, Callable, Dict, List, 
Optional, TypeVar, cast import torch from torch.distributed import ReduceOp, TCPStore @@ -106,6 +106,7 @@ def __init__( hostname: str = socket.gethostname(), heartbeat_interval: timedelta = timedelta(milliseconds=100), checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, + init_sync: bool = True, ) -> None: """ Args: @@ -143,6 +144,9 @@ def __init__( hostname: if rank==0, the hostname to advertise to the lighthouse server checkpoint_transport: the checkpoint transport to use for transfering checkpoints to recovering replicas, defaults to HTTPTransport + init_sync: whether to synchronize the model weights on step 0. If + all of the model weights are initialized identically via + ``torch.set_seed`` you should set this to False. """ self._load_state_dict = load_state_dict self._user_state_dict = state_dict @@ -152,6 +156,7 @@ def __init__( self._quorum_timeout = quorum_timeout self._connect_timeout = connect_timeout self._world_size_mode = world_size_mode + self._init_sync = init_sync store_addr = store_addr or os.environ["MASTER_ADDR"] store_port = store_port or int(os.environ["MASTER_PORT"]) @@ -455,7 +460,7 @@ def _async_quorum( checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, timeout=quorum_timeout, - init_sync=self.init_sync, + init_sync=self._init_sync, ) quorum_id = quorum.quorum_id diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index d591d0d..e7622be 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -25,6 +25,8 @@ logger: logging.Logger = logging.getLogger(__name__) +INIT_LOCK: threading.Lock = threading.Lock() + class MyModel(nn.Module): def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None: @@ -191,7 +193,13 @@ def state_dict() -> Dict[str, Dict[str, object]]: ) stack.callback(lambda: manager.shutdown(wait=False)) - m: nn.Module = DistributedDataParallel(manager, MyModel()) + with INIT_LOCK: + # We need to lock during init for testing init_sync=False as all + # threads share the same RNG + torch.manual_seed(42) + m: nn.Module = MyModel() + + m: nn.Module = DistributedDataParallel(manager, m) optimizer: optim.Optimizer = OptimizerWrapper( manager, optim.Adam(m.parameters()) ) @@ -270,7 +278,11 @@ def test_ddp_healthy(self) -> None: ), ] ) - def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None: + def test_ddp_recovery( + self, + name: str, + use_async_quorum: bool, + ) -> None: lighthouse = LighthouseServer( bind="[::]:0", min_replicas=2, @@ -302,7 +314,11 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None: state_dicts = [] for fut in as_completed(futures): - state_dicts.append(fut.result()) + try: + state_dicts.append(fut.result()) + except Exception as e: + print(e) + raise lighthouse.shutdown() @@ -311,6 +327,53 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None: self.assertEqual(failure_injectors[1].count, 1) + def test_ddp_skip_init_sync( + self, + ) -> None: + lighthouse = LighthouseServer( + bind="[::]:0", + min_replicas=2, + ) + num_replicas = 2 + futures = [] + + # no failures + failure_injectors = [ + FailureInjector(), + FailureInjector(), + ] + + with ThreadPoolExecutor(max_workers=num_replicas) as executor: + for replica_id, failure_injector in zip( + range(num_replicas), failure_injectors + ): + runner = Runner( + replica_id=replica_id, + num_replicas=num_replicas, + lighthouse_address=lighthouse.address(), + failure_injector=failure_injector, + manager_args={ + "use_async_quorum": 
False, + "init_sync": False, + }, + train_loop=ddp_train_loop, + ) + futures.append(executor.submit(runner.run_replica)) + + state_dicts = [] + + for fut in as_completed(futures): + try: + state_dicts.append(fut.result()) + except Exception as e: + print(e) + raise + + lighthouse.shutdown() + + for state_dict in state_dicts: + torch.testing.assert_close(state_dict, state_dicts[0]) + def test_ddp_recovery_multi_rank(self) -> None: lighthouse = LighthouseServer( bind="[::]:0", diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 954b88a..2d421e6 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -40,6 +40,7 @@ def _create_manager( min_replica_size: int = 2, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, timeout: timedelta = timedelta(seconds=10), + init_sync: bool = True, ) -> Manager: pg = create_autospec(ProcessGroup) pg.errored.return_value = None @@ -67,6 +68,7 @@ def _create_manager( use_async_quorum=use_async_quorum, world_size_mode=world_size_mode, timeout=timeout, + init_sync=init_sync, ) self.manager = manager return manager @@ -617,7 +619,12 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: @patch("torchft.manager.ManagerClient", autospec=True) def test_quorum_skip_init(self, client_mock: MagicMock) -> None: - manager = self._create_manager(use_async_quorum=False) + manager = self._create_manager( + use_async_quorum=False, + init_sync=False, + ) + + self.assertFalse(manager._init_sync) quorum = QuorumResult() quorum.quorum_id = 123 @@ -633,16 +640,8 @@ def test_quorum_skip_init(self, client_mock: MagicMock) -> None: client_mock()._quorum.return_value = quorum manager.start_quorum() - self.assertEqual( - client_mock()._quorum.call_args.kwargs["init_sync"], True - ) + self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False) - manager.start_quorum(init_sync=True) - self.assertEqual( - client_mock()._quorum.call_args.kwargs["init_sync"], True - ) - - manager.start_quorum(init_sync=False) - self.assertEqual( - client_mock()._quorum.call_args.kwargs["init_sync"], False - ) + manager._init_sync = True + manager.start_quorum() + self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
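
To make the end state of this series concrete, here is a small, self-contained Python sketch of the recovery-assignment rule as it stands after PATCH 6/6. The Member class, the recovery_dst_ranks helper, and the primary_replica_id parameter are illustrative names only and are not part of the torchft API; primary selection is taken as an input because the patches do not show how the primary replica is chosen inside compute_quorum_results.

    from dataclasses import dataclass
    from typing import List


    @dataclass
    class Member:
        replica_id: str
        step: int


    def recovery_dst_ranks(
        participants: List[Member], primary_replica_id: str, init_sync: bool
    ) -> List[int]:
        """Indices of participants that must heal (recover weights).

        Mirrors the final Rust logic: a participant recovers if it is behind
        the max step, or -- only when init_sync is enabled and training has
        not started (max_step == 0) -- if it is not the primary replica.
        """
        max_step = max(p.step for p in participants)
        force_recover = init_sync and max_step == 0
        return [
            i
            for i, p in enumerate(participants)
            if p.step != max_step
            or (force_recover and p.replica_id != primary_replica_id)
        ]


    members = [Member("replica_0", 0), Member("replica_1", 0)]

    # init_sync=True: at step 0 every non-primary replica heals from the primary.
    assert recovery_dst_ranks(members, "replica_0", init_sync=True) == [1]

    # init_sync=False: identically seeded weights are assumed, so nobody heals at step 0.
    assert recovery_dst_ranks(members, "replica_0", init_sync=False) == []

    # A replica that is behind the max step always heals, regardless of init_sync.
    members[0].step = 1
    assert recovery_dst_ranks(members, "replica_0", init_sync=False) == [1]

These checks mirror the final test_compute_quorum_results_skip_init_sync: with init_sync the non-primary replica heals at step 0, without it nothing heals at step 0, and a replica behind the max step always heals regardless of init_sync. On the user side, the docstring and integration test above suggest the intended usage: seed every replica identically (e.g. torch.manual_seed(42)) before building the model and pass init_sync=False to the Manager constructor.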