enable merging parameters for diloco #212

Draft · wants to merge 1 commit into main
41 changes: 35 additions & 6 deletions torchft/local_sgd.py
@@ -213,8 +213,14 @@ def __init__(
self.should_quantize = should_quantize

self._grads: Dict[str, torch.Tensor] = {}

# Used to save the global parameters so that they can be restored if
# the commit fails
self.original_parameters: Dict[str, torch.Tensor] = {}

# Used to save the local parameters that get mixed into the global parameters
self._local_parameters: Dict[str, torch.Tensor] = {}

for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
p = extract_local_tensor(p.data)
@@ -237,6 +243,14 @@ def save_parameters(self) -> None:
param_to_local = extract_local_tensor(p.data)
self.original_parameters[name].copy_(param_to_local, non_blocking=True)

def _save_local_parameters(self) -> None:
"""
Saves a copy of the model's parameters.
"""
with torch.no_grad():
for name, p in self._model_fragment.named_parameters():
self._local_parameters[name] = extract_local_tensor(p.data)
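# Note: extract_local_tensor is assumed to return a clone (its use in
# save_parameters suggests it does), so these saved copies are not
# overwritten when restore_parameters() rewrites the live parameters.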

@torch.profiler.record_function("torchft::local_sgd::restore_parameters")
def restore_parameters(self) -> None:
with torch.no_grad():
@@ -293,6 +307,21 @@ def _set_grads(self) -> None:
# No longer needed
del self._grads[name]

def _clear_local_parameters(self) -> None:
"""
Clears the saved copy of the model's local parameters.
"""
self._local_parameters = {}

def _merge_parameters(self) -> None:
"""
Merges the local and global parameters.
"""
for name, p in self._model_fragment.named_parameters():
# Use the in-place variant: a bare torch.lerp(...) returns a new
# tensor and would leave p.data unchanged.
p.data.lerp_(
self._local_parameters[name], 1 - self._fragment_update_alpha
)
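# With fragment_update_alpha = 1.0 this keeps the averaged (global)
# parameters unchanged; with 0.0 it fully restores the saved local parameters.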

@torch.profiler.record_function("torchft::local_sgd::wait")
def wait(self) -> None:
"""
@@ -382,6 +411,8 @@ def perform_sync(self) -> bool:

self.wait()

# Save the local parameters so they can later be merged with the global ones
self._save_local_parameters()
# Restore the parameters back to the previous state
self.restore_parameters()

@@ -404,8 +435,12 @@ def perform_sync(self) -> bool:
self._set_grads()
self._outer_optimizer.step()
self.save_parameters()
self._merge_parameters()
self._outer_optimizer.zero_grad()

# Free the saved local parameters now that merging is done
self._clear_local_parameters()

return should_commit

def _average_grads(self) -> None:
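A quick sanity check on the merge math (a standalone sketch, not part of the diff): after `restore_parameters()` the fragment holds its pre-sync state, the outer optimizer step then moves `p.data` to the averaged global parameters, and `_merge_parameters()` blends those with the copies saved by `_save_local_parameters()`:

```python
import torch

# Hypothetical values; alpha stands in for self._fragment_update_alpha.
alpha = 0.5
global_p = torch.tensor([1.0, 2.0])  # p.data after the outer optimizer step
local_p = torch.tensor([0.0, 0.0])   # copy saved by _save_local_parameters()

# lerp(a, b, w) = a + w * (b - a), so with w = 1 - alpha:
#   merged = alpha * global_p + (1 - alpha) * local_p
merged = torch.lerp(global_p, local_p, 1 - alpha)
print(merged)  # tensor([0.5000, 1.0000])
```

So `fragment_update_alpha` is the weight given to the averaged parameters: `1.0` keeps them unchanged, `0.0` keeps only the local ones.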
@@ -557,12 +592,6 @@ def __init__(
if fragment_update_alpha < 0 or fragment_update_alpha > 1:
raise ValueError("fragment_update_alpha must be between 0 and 1")

# TODO: Support `fragment_update_alpha`
if fragment_update_alpha != 0.0:
raise ValueError(
"Merging local parameters with global parameters is not supported yet"
)

super().__init__()
self._manager = manager

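For reference, a minimal usage sketch. The names and constructor signature below are assumptions based on the class shown in the diff, not taken verbatim from the repo:

```python
# Hypothetical wiring; `manager`, `model_fragment`, and the optimizers are
# placeholders, and the DiLoCo constructor arguments are assumed.
diloco = DiLoCo(
    manager,
    [model_fragment],
    inner_optimizer,
    outer_optimizer,
    sync_every=100,
    fragment_update_alpha=0.5,  # equal mix of global and local parameters
)
```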