fix gradient allreduce #215

Merged
1 commit merged on Jun 16, 2025
68 changes: 40 additions & 28 deletions torchft/local_sgd.py
@@ -257,17 +257,41 @@ def restore_parameters(self) -> None:
else:
p.data.copy_(self.original_parameters[name], non_blocking=False)

def _save_grads(self) -> None:
"""
Saves pseudo-gradients of the parameters
"""
with torch.no_grad():
for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
local_param = p.to_local()
else:
local_param = p
pseudogradient = local_param - self.original_parameters[name].to(
p.device
)
self._grads[name] = pseudogradient

def _set_grads(self) -> None:
"""
Sets the gradients of the model fragment from the allreduce result
"""
for name, p in self._model_fragment.named_parameters():
if isinstance(p, DTensor):
p.grad._local_tensor = self._grads[name]
else:
p.grad = self._grads[name]
with torch.no_grad():
for name, p in self._model_fragment.named_parameters():
# avoid copying the gradient, it should be on the same device
if isinstance(p, DTensor):
p.grad = DTensor.from_local(
self._grads[name],
p.device_mesh,
p.placements,
shape=p.shape,
stride=p.stride(),
)
else:
p.grad = self._grads[name]

del self._grads[name]
Review comment (Member): how come del was removed?

# No longer needed
del self._grads[name]

@torch.profiler.record_function("torchft::local_sgd::wait")
def wait(self) -> None:
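For orientation, a minimal sketch of the save-then-set pseudo-gradient pattern that `_save_grads` and `_set_grads` implement (my simplification: plain tensors only, with no DTensor, manager, or CUDA-stream handling): the pseudo-gradient is the difference between the current parameters and the previously saved copy, and after averaging it is attached to `.grad` so the outer optimizer can step on it.

import torch
import torch.nn as nn

model = nn.Linear(4, 2)
# snapshot of the parameters, as save_parameters() would keep
original = {name: p.detach().clone() for name, p in model.named_parameters()}

# ... inner optimizer steps would mutate the parameters here ...
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)

# _save_grads analogue: pseudo-gradient = current parameter - saved parameter
pseudo_grads = {}
with torch.no_grad():
    for name, p in model.named_parameters():
        pseudo_grads[name] = p.detach() - original[name]

# (a real run would allreduce/average pseudo_grads across replicas here)

# _set_grads analogue: attach the result to .grad for the outer optimizer
for name, p in model.named_parameters():
    p.grad = pseudo_grads[name]
    del pseudo_grads[name]  # the saved copy is no longer needed once attached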
@@ -304,14 +328,9 @@ def prepare_sync(self) -> None:
Calculates the pseudo-gradients, averages them across the manager group, and starts
an allreduce on the pseudo-gradients without waiting for it to finish.
"""
# Set the .grad field of each parameter to its pseudogradient
for name, p in self._model_fragment.named_parameters():
local_param = extract_local_tensor(p.data)
pseudogradient = local_param - self.original_parameters[name].to(p.device)
if isinstance(p, DTensor):
self._grads[name] = pseudogradient
else:
self._grads[name] = pseudogradient
self._save_grads()

assert len(self._allreduce_futures) == 0

# Make sure tensors are available to `_stream`
if self._stream is not None:
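The docstring above describes a start-now, wait-later pattern: prepare_sync only kicks off the allreduce and returns, and the futures are waited on later. A tiny illustration of that pattern with torch.futures, where the doubling callback is a made-up stand-in for the manager's allreduce (mirroring the fake_allreduce used in the tests below):

import torch

def start_fake_allreduce(tensor: torch.Tensor) -> torch.futures.Future:
    # Stand-in collective: completes immediately and just doubles the tensor.
    fut: torch.futures.Future = torch.futures.Future()
    tensor.mul_(2)
    fut.set_result(tensor)
    return fut

pending = [start_fake_allreduce(torch.ones(3))]  # prepare_sync: start, don't wait
# ... other work (e.g. further inner-optimizer steps) can overlap here ...
results = [fut.wait() for fut in pending]        # wait(): block on the futures later
print(results)  # [tensor([2., 2., 2.])]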
@@ -371,18 +390,12 @@ def _allreduce_per_param(self) -> None:
"""Performs allreduce on each gradient tensor separately (original method)."""
for name, p in self._model_fragment.named_parameters():
# Perform allreduce on the pseudogradients
assert p.grad is not None
if isinstance(p, DTensor):
work = self._manager.allreduce(
self._grads[name], should_quantize=self.should_quantize
)
else:
work = self._manager.allreduce(
self._grads[name], should_quantize=self.should_quantize
)
work = self._manager.allreduce(
self._grads[name], should_quantize=self.should_quantize
)
self._allreduce_futures.append(work)

def bucketize_and_allreduce(
def _bucketize_and_allreduce(
self,
tensors: List[torch.Tensor],
bucket_size_bytes: int,
@@ -439,10 +452,9 @@ def _allreduce_bucketized(self) -> None:
"""
Averages gradients using bucketized allreduce with a fixed buffer.
"""
grads = [
p.grad for p in self._model_fragment.parameters() if p.grad is not None
]
self.bucketize_and_allreduce(
grads = list(self._grads.values())
assert len(grads) > 0, "No gradients to allreduce"
self._bucketize_and_allreduce(
grads,
bucket_size_bytes=self.bucket_cap_mb,
)
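To make the bucketized path concrete, here is a self-contained sketch (my simplification; the real _bucketize_and_allreduce additionally splits the tensors into buckets capped at bucket_size_bytes and issues the collectives through the manager) of packing several gradients into one flat buffer, running a single collective on it, and copying the results back in place:

import torch
from typing import Callable, List

def bucketed_allreduce_sketch(
    grads: List[torch.Tensor], allreduce_fn: Callable[[torch.Tensor], torch.Tensor]
) -> None:
    flat = torch.cat([g.reshape(-1) for g in grads])  # pack into one buffer
    allreduce_fn(flat)                                # one collective instead of N
    offset = 0
    for g in grads:
        n = g.numel()
        g.copy_(flat[offset : offset + n].view_as(g))  # unpack back in place
        offset += n

grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
bucketed_allreduce_sketch(grads, lambda t: t.mul_(2))  # doubling stands in for allreduce
print(grads)  # [tensor([2., 4.]), tensor([6., 8., 10.])]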
85 changes: 71 additions & 14 deletions torchft/local_sgd_test.py
@@ -52,6 +52,16 @@ def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
return {name: value.clone().detach() for name, value in state_dict.items()}


class TinyModel(nn.Module):
def __init__(self) -> None:
super().__init__()
self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x @ self.w1.unsqueeze(0).T + self.w2.sum()


class LocalSGDTest(TestCase):
def test_local_sgd_healthy(self) -> None:
model = SimpleModel()
@@ -216,24 +226,10 @@ def test_diloco_allreduce_call_efficiency(
self.assertEqual(int(allreduce_calls), int(param_count))

def test_bucketization_correctness(self) -> None:
class TinyModel(nn.Module):
def __init__(self):
super().__init__()
self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))

def forward(self, x):
return x @ self.w1.unsqueeze(0).T + self.w2.sum()

model = TinyModel()
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)

# Manually assign fake gradients
grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
for p, g in zip(model.parameters(), grads):
p.grad = g.clone()

manager = create_autospec(Manager)
manager._use_async_quorum = False
manager.should_commit.return_value = True
@@ -254,10 +250,71 @@ def fake_allreduce(
)
diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024

# Manually assign fake gradients
grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
for g, (name, param) in zip(grads, model.named_parameters()):
diloco._fragments[0]._grads[name] = g.clone()

# Run only bucketized logic
diloco._fragments[0]._average_grads()

# The parameter gradients should not be set
for param in model.parameters():
self.assertEqual(param.grad, None)

diloco._fragments[0]._set_grads()

# Expect grads to have been doubled
expected_grads = [g * 2 for g in grads]
for param, expected in zip(model.parameters(), expected_grads):
torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)

def test_gradient_correctness(self) -> None:
model = TinyModel()
inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)

manager = create_autospec(Manager)
manager._use_async_quorum = False
manager.should_commit.return_value = True

# Define fake allreduce: multiplies buffer by 2
def fake_allreduce(
tensor: Tensor, should_quantize: bool
) -> torch.futures.Future[Tensor]:
tensor.mul_(2)
fut = torch.futures.Future() # pyre-fixme[29]: not a function
fut.set_result(tensor)
return fut

manager.allreduce.side_effect = fake_allreduce

diloco = DiLoCo(manager, [model], inner_opt, outer_opt, sync_every=2)

# save original parameters
diloco._fragments[0].save_parameters()

# change the model's parameters
for p in model.parameters():
p.data.add_(2)

# calculate and set the gradients
diloco._fragments[0]._save_grads()

# average the pseudo-gradients (the fake allreduce doubles each tensor)
diloco._fragments[0]._average_grads()

# The parameter gradients should not be set
for param in model.parameters():
self.assertEqual(param.grad, None)

diloco._fragments[0]._set_grads()

# we added 2 to the parameters, then multiplied the gradients by 2
# so we should expect the model's gradient to be 4
expected_grad = 4
for param in model.parameters():
assert param.grad is not None
t = torch.empty_like(param.grad)
t.fill_(expected_grad)
torch.testing.assert_close(param.grad, t)