diff --git a/torchft/local_sgd.py b/torchft/local_sgd.py
index 1e409a7..de03452 100644
--- a/torchft/local_sgd.py
+++ b/torchft/local_sgd.py
@@ -213,8 +213,14 @@ def __init__(
         self.should_quantize = should_quantize
 
         self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
         self.original_parameters: Dict[str, torch.Tensor] = {}
 
+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
         for name, p in self._model_fragment.named_parameters():
             if isinstance(p, DTensor):
                 p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
                 param_to_local = extract_local_tensor(p.data)
                 self.original_parameters[name].copy_(param_to_local, non_blocking=True)
 
+    def _save_local_parameters(self) -> None:
+        """
+        Saves a copy of the model fragment's local parameters.
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = extract_local_tensor(p.data)
+
     @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
     def restore_parameters(self) -> None:
         with torch.no_grad():
@@ -293,6 +307,21 @@ def _set_grads(self) -> None:
                 # No longer needed
                 del self._grads[name]
 
+    def _clear_local_parameters(self) -> None:
+        """
+        Clears the saved copy of the model fragment's local parameters.
+        """
+        self._local_parameters = {}
+
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            p.data.lerp_(
+                self._local_parameters[name], 1 - self._fragment_update_alpha
+            )
+
     @torch.profiler.record_function("torchft::local_sgd::wait")
     def wait(self) -> None:
         """
@@ -382,6 +411,8 @@ def perform_sync(self) -> bool:
 
         self.wait()
 
+        # Save the local parameters so they can be used for merging
+        self._save_local_parameters()
         # Restore the parameters back to the previous state
         self.restore_parameters()
 
@@ -404,8 +435,12 @@ def perform_sync(self) -> bool:
             self._set_grads()
             self._outer_optimizer.step()
             self.save_parameters()
+            self._merge_parameters()
             self._outer_optimizer.zero_grad()
 
+            # Free up memory used by the saved local parameters
+            self._clear_local_parameters()
+
         return should_commit
 
     def _average_grads(self) -> None:
@@ -557,12 +592,6 @@ def __init__(
         if fragment_update_alpha < 0 or fragment_update_alpha > 1:
             raise ValueError("fragment_update_alpha must be between 0 and 1")
 
-        # TODO: Support `fragment_update_alpha`
-        if fragment_update_alpha != 0.0:
-            raise ValueError(
-                "Merging local parameters with global parameters is not supported yet"
-            )
-
         super().__init__()
 
         self._manager = manager
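Note on the merge arithmetic: _merge_parameters blends the global parameters left in p.data after the outer optimizer step with the local parameters saved by _save_local_parameters() just before restore_parameters(). The in-place lerp_ with weight 1 - fragment_update_alpha yields fragment_update_alpha * global + (1 - fragment_update_alpha) * local. The sketch below illustrates only that arithmetic with made-up tensors and a hypothetical alpha of 0.5; it is not part of the torchft API.

import torch

fragment_update_alpha = 0.5  # hypothetical mixing factor in [0, 1]

global_param = torch.tensor([1.0, 2.0, 3.0])  # stands in for p.data after the outer step
local_param = torch.tensor([3.0, 2.0, 1.0])   # stands in for the saved local parameters

# lerp_(end, weight) computes self + weight * (end - self) in place,
# i.e. alpha * global + (1 - alpha) * local when weight = 1 - alpha.
merged = global_param.clone()
merged.lerp_(local_param, 1 - fragment_update_alpha)

expected = fragment_update_alpha * global_param + (1 - fragment_update_alpha) * local_param
assert torch.allclose(merged, expected)  # tensor([2., 2., 2.]) for alpha == 0.5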