diff --git a/proto/torchft.proto b/proto/torchft.proto
index c4d0a81..15a96d0 100644
--- a/proto/torchft.proto
+++ b/proto/torchft.proto
@@ -76,6 +76,7 @@ message ManagerQuorumRequest {
   int64 step = 2;
   string checkpoint_metadata = 3;
   bool shrink_only = 4;
+  bool init_sync = 5;
 }
 
 message ManagerQuorumResponse {
diff --git a/src/lib.rs b/src/lib.rs
index a29e00d..5ef1bcf 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -175,6 +175,7 @@ impl ManagerClient {
         step: i64,
         checkpoint_metadata: String,
         shrink_only: bool,
+        init_sync: bool,
         timeout: Duration,
     ) -> Result<QuorumResult, StatusError> {
         py.allow_threads(move || {
@@ -183,6 +184,7 @@ impl ManagerClient {
                 step: step,
                 checkpoint_metadata: checkpoint_metadata,
                 shrink_only: shrink_only,
+                init_sync: init_sync,
             });
 
             // This timeout is processed on the server side so we also enable
diff --git a/src/manager.rs b/src/manager.rs
index 08d0cc2..358ff7a 100644
--- a/src/manager.rs
+++ b/src/manager.rs
@@ -286,7 +286,7 @@ impl ManagerService for Arc<Manager> {
 
         info_with_replica!(self.replica_id, "Finished quorum for rank {}", rank);
 
-        let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?;
+        let reply = compute_quorum_results(&self.replica_id, rank, &quorum, req.init_sync)?;
 
         Ok(Response::new(reply))
     }
@@ -382,6 +382,7 @@ fn compute_quorum_results(
     replica_id: &str,
     rank: i64,
     quorum: &Quorum,
+    init_sync: bool,
 ) -> Result<ManagerQuorumResponse, Status> {
     let mut participants = quorum.participants.clone();
     participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id));
@@ -424,20 +425,23 @@ fn compute_quorum_results(
     // Compute recovery assignments
 
-    // Nodes are recovering if:
-    // 1. not at the max step
+    let force_recover = init_sync && max_step == 0;
+
+    // Nodes are recovering if
+    // 1. not at the max step (init_sync)
     // 2. max_step == 0 and not the primary replica
     let all_recover_dst_ranks: Vec<usize> = participants
         .iter()
         .enumerate()
         .filter_map(|(i, p)| {
-            if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id {
+            if p.step != max_step || force_recover && primary.replica_id != p.replica_id {
                 Some(i)
             } else {
                 None
             }
         })
         .collect();
+
     let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::<HashSet<_>>();
 
     let up_to_date_ranks: Vec<usize> = participants
         .iter()
@@ -605,6 +609,7 @@ mod tests {
             step: 123,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
+            init_sync: true,
         });
         request.set_timeout(Duration::from_secs(10));
         let resp = client.quorum(request).await?.into_inner();
@@ -664,6 +669,7 @@ mod tests {
             step: 0,
             checkpoint_metadata: "addr".to_string(),
             shrink_only: false,
+            init_sync: true,
         });
         request.set_timeout(Duration::from_secs(10));
 
@@ -771,13 +777,13 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![1]);
 
-        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
         assert!(results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, Some(0));
@@ -785,7 +791,7 @@ mod tests {
 
         // rank 1 assignments should be offset from rank 0 above and the primary
 
-        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
@@ -850,21 +856,21 @@ mod tests {
 
         // rank 0
 
-        let results = compute_quorum_results("replica_0", 0, &quorum)?;
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
         assert!(results.heal);
         assert_eq!(results.recover_src_manager_address, "addr_1".to_string());
         assert_eq!(results.replica_rank, 0);
         assert_eq!(results.recover_src_rank, Some(1));
         assert!(results.recover_dst_ranks.is_empty());
 
-        let results = compute_quorum_results("replica_1", 0, &quorum)?;
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.recover_src_manager_address, "".to_string());
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
         assert_eq!(results.recover_dst_ranks, vec![0, 4]);
 
-        let results = compute_quorum_results("replica_3", 0, &quorum)?;
+        let results = compute_quorum_results("replica_3", 0, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 3);
         assert_eq!(results.recover_src_rank, None);
@@ -872,7 +878,7 @@ mod tests {
 
         // rank 1 assignments should be offset from rank 0 above
 
-        let results = compute_quorum_results("replica_1", 1, &quorum)?;
+        let results = compute_quorum_results("replica_1", 1, &quorum, true)?;
         assert!(!results.heal);
         assert_eq!(results.replica_rank, 1);
         assert_eq!(results.recover_src_rank, None);
@@ -880,4 +886,50 @@ mod tests {
 
         Ok(())
     }
+
+    #[tokio::test]
+    async fn test_compute_quorum_results_skip_init_sync() -> Result<()> {
+        let mut quorum = Quorum {
+            quorum_id: 1,
+            participants: vec![
+                QuorumMember {
+                    replica_id: "replica_0".to_string(),
+                    address: "addr_0".to_string(),
+                    store_address: "store_addr_0".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                    data: String::new(),
+                },
+                QuorumMember {
+                    replica_id: "replica_1".to_string(),
+                    address: "addr_1".to_string(),
+                    store_address: "store_addr_1".to_string(),
+                    step: 0,
+                    world_size: 1,
+                    shrink_only: false,
+                    data: String::new(),
+                },
+            ],
+            created: None,
+        };
+
+        // baseline w/ init_sync=true
+        let results = compute_quorum_results("replica_0", 0, &quorum, true)?;
+        assert!(!results.heal);
+
+        let results = compute_quorum_results("replica_1", 0, &quorum, true)?;
+        assert!(results.heal);
+
+        // init_sync=false
+        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        assert!(!results.heal);
+
+        // init_sync=false, step=1
+        quorum.participants[0].step = 1;
+        let results = compute_quorum_results("replica_1", 0, &quorum, false)?;
+        assert!(results.heal);
+
+        Ok(())
+    }
 }
diff --git a/torchft/_torchft.pyi b/torchft/_torchft.pyi
index 49bdcdd..fdbd1fa 100644
--- a/torchft/_torchft.pyi
+++ b/torchft/_torchft.pyi
@@ -11,6 +11,7 @@ class ManagerClient:
         checkpoint_metadata: str,
         shrink_only: bool,
         timeout: timedelta,
+        init_sync: bool = True,
     ) -> QuorumResult: ...
     def _checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ...
     def should_commit(
diff --git a/torchft/manager.py b/torchft/manager.py
index 0697bd4..fa2760d 100644
--- a/torchft/manager.py
+++ b/torchft/manager.py
@@ -106,6 +106,7 @@ def __init__(
         hostname: str = socket.gethostname(),
         heartbeat_interval: timedelta = timedelta(milliseconds=100),
         checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
+        init_sync: bool = True,
     ) -> None:
         """
         Args:
@@ -143,6 +144,9 @@ def __init__(
            hostname: if rank==0, the hostname to advertise to the lighthouse server
            checkpoint_transport: the checkpoint transport to use for transfering
                checkpoints to recovering replicas, defaults to HTTPTransport
+           init_sync: whether to synchronize the model weights on step 0. If
+               all of the model weights are initialized identically via
+               ``torch.manual_seed`` you should set this to False.
         """
         self._load_state_dict = load_state_dict
         self._user_state_dict = state_dict
@@ -152,6 +156,7 @@ def __init__(
         self._quorum_timeout = quorum_timeout
         self._connect_timeout = connect_timeout
         self._world_size_mode = world_size_mode
+        self._init_sync = init_sync
 
         store_addr = store_addr or os.environ["MASTER_ADDR"]
         store_port = store_port or int(os.environ["MASTER_PORT"])
@@ -455,6 +460,7 @@ def _async_quorum(
             checkpoint_metadata=self._checkpoint_transport.metadata(),
             shrink_only=shrink_only,
             timeout=quorum_timeout,
+            init_sync=self._init_sync,
         )
 
         quorum_id = quorum.quorum_id
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py
index d591d0d..e7622be 100644
--- a/torchft/manager_integ_test.py
+++ b/torchft/manager_integ_test.py
@@ -25,6 +25,8 @@
 
 logger: logging.Logger = logging.getLogger(__name__)
 
+INIT_LOCK: threading.Lock = threading.Lock()
+
 
 class MyModel(nn.Module):
     def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
@@ -191,7 +193,13 @@ def state_dict() -> Dict[str, Dict[str, object]]:
         )
         stack.callback(lambda: manager.shutdown(wait=False))
 
-        m: nn.Module = DistributedDataParallel(manager, MyModel())
+        with INIT_LOCK:
+            # We need to lock during init for testing init_sync=False as all
+            # threads share the same RNG
+            torch.manual_seed(42)
+            m: nn.Module = MyModel()
+
+        m: nn.Module = DistributedDataParallel(manager, m)
         optimizer: optim.Optimizer = OptimizerWrapper(
             manager, optim.Adam(m.parameters())
         )
@@ -270,7 +278,11 @@ def test_ddp_healthy(self) -> None:
             ),
         ]
     )
-    def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
+    def test_ddp_recovery(
+        self,
+        name: str,
+        use_async_quorum: bool,
+    ) -> None:
         lighthouse = LighthouseServer(
             bind="[::]:0",
             min_replicas=2,
@@ -302,7 +314,11 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
             state_dicts = []
 
             for fut in as_completed(futures):
-                state_dicts.append(fut.result())
+                try:
+                    state_dicts.append(fut.result())
+                except Exception as e:
+                    print(e)
+                    raise
 
             lighthouse.shutdown()
 
@@ -311,6 +327,53 @@ def test_ddp_recovery(self, name: str, use_async_quorum: bool) -> None:
 
         self.assertEqual(failure_injectors[1].count, 1)
 
+    def test_ddp_skip_init_sync(
+        self,
+    ) -> None:
+        lighthouse = LighthouseServer(
+            bind="[::]:0",
+            min_replicas=2,
+        )
+        num_replicas = 2
+        futures = []
+
+        # no failures
+        failure_injectors = [
+            FailureInjector(),
+            FailureInjector(),
+        ]
+
+        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+            for replica_id, failure_injector in zip(
+                range(num_replicas), failure_injectors
+            ):
+                runner = Runner(
+                    replica_id=replica_id,
+                    num_replicas=num_replicas,
+                    lighthouse_address=lighthouse.address(),
+                    failure_injector=failure_injector,
+                    manager_args={
+                        "use_async_quorum": False,
+                        "init_sync": False,
+                    },
+                    train_loop=ddp_train_loop,
+                )
+                futures.append(executor.submit(runner.run_replica))
+
+            state_dicts = []
+
+            for fut in as_completed(futures):
+                try:
+                    state_dicts.append(fut.result())
+                except Exception as e:
+                    print(e)
+                    raise
+
+            lighthouse.shutdown()
+
+        for state_dict in state_dicts:
+            torch.testing.assert_close(state_dict, state_dicts[0])
+
     def test_ddp_recovery_multi_rank(self) -> None:
         lighthouse = LighthouseServer(
             bind="[::]:0",
diff --git a/torchft/manager_test.py b/torchft/manager_test.py
index fb13496..2d421e6 100644
--- a/torchft/manager_test.py
+++ b/torchft/manager_test.py
@@ -40,6 +40,7 @@ def _create_manager(
         min_replica_size: int = 2,
         world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
         timeout: timedelta = timedelta(seconds=10),
+        init_sync: bool = True,
     ) -> Manager:
         pg = create_autospec(ProcessGroup)
         pg.errored.return_value = None
@@ -67,6 +68,7 @@ def _create_manager(
             use_async_quorum=use_async_quorum,
             world_size_mode=world_size_mode,
             timeout=timeout,
+            init_sync=init_sync,
         )
         self.manager = manager
         return manager
@@ -614,3 +616,32 @@ def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
             client_mock().should_commit.call_args.kwargs["timeout"],
             timedelta(seconds=23),
         )
+
+    @patch("torchft.manager.ManagerClient", autospec=True)
+    def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
+        manager = self._create_manager(
+            use_async_quorum=False,
+            init_sync=False,
+        )
+
+        self.assertFalse(manager._init_sync)
+
+        quorum = QuorumResult()
+        quorum.quorum_id = 123
+        quorum.replica_rank = 1
+        quorum.replica_world_size = 2
+        quorum.recover_src_manager_address = "manager address"
+        quorum.store_address = f"localhost:{self.store.port}"
+        quorum.max_step = 1
+        quorum.max_rank = 1
+        quorum.max_world_size = 2
+        quorum.heal = False
+
+        client_mock()._quorum.return_value = quorum
+
+        manager.start_quorum()
+        self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False)
+
+        manager._init_sync = True
+        manager.start_quorum()
+        self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
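
Usage note (not part of the patch): the sketch below illustrates the assumption behind setting init_sync=False on the Manager, namely that every replica builds bitwise-identical initial weights, which is what the new integration test arranges by seeding the RNG before constructing the model. The build_model helper and the nn.Linear shape are hypothetical; only torch.manual_seed, torch.testing.assert_close, and the init_sync flag itself come from the diff above.

import torch
import torch.nn as nn


def build_model(seed: int = 42) -> nn.Module:
    # Seed before constructing the model so every replica gets the same init.
    torch.manual_seed(seed)
    return nn.Linear(3, 4)


# Two "replicas" built from the same seed start with identical parameters, so
# no step-0 checkpoint transfer is needed and a Manager constructed with
# init_sync=False can safely skip the initial weight synchronization.
replica_a, replica_b = build_model(), build_model()
for pa, pb in zip(replica_a.parameters(), replica_b.parameters()):
    torch.testing.assert_close(pa, pb)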