Fix EDMEulerScheduler training with integer timesteps #11991

Open · wants to merge 1 commit into main
src/diffusers/modular_pipelines/modular_pipeline_utils.py: 2 changes (1 addition, 1 deletion)
@@ -209,7 +209,7 @@ def decode_load_id(cls, load_id: str) -> Dict[str, Optional[str]]:

# Get all loading fields in order
loading_fields = cls.loading_fields()
result = {f: None for f in loading_fields}
result = dict.fromkeys(loading_fields)

if load_id == "null":
return result
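As a quick check on the one-line change above, dict.fromkeys defaults every value to None, so it builds the same dictionary as the removed comprehension. A minimal sketch; the field names are hypothetical, not the real loading fields:

# dict.fromkeys(keys) is equivalent to {k: None for k in keys}.
loading_fields = ["repo", "subfolder", "variant"]  # hypothetical field names

assert dict.fromkeys(loading_fields) == {f: None for f in loading_fields}
assert dict.fromkeys(loading_fields) == {"repo": None, "subfolder": None, "variant": None}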
src/diffusers/schedulers/scheduling_edm_euler.py: 38 changes (22 additions, 16 deletions)
@@ -406,7 +406,6 @@ def step(

return EDMEulerSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)

# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
def add_noise(
self,
original_samples: torch.Tensor,
@@ -415,23 +414,30 @@ def add_noise(
) -> torch.Tensor:
# Make sure sigmas and timesteps have the same device and dtype as original_samples
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)
timesteps = timesteps.to(original_samples.device)

# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = [self.step_index] * timesteps.shape[0]
# Handle integer timesteps (training case)
if timesteps.dtype in (torch.int32, torch.int64):
# Training: reverse mapping since EDM sigmas are in descending order
# timestep 0 -> sigma_min, timestep 999 -> sigma_max
step_indices = self.config.num_train_timesteps - 1 - timesteps.long()
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = [self.begin_index] * timesteps.shape[0]
if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
# mps does not support float64
schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32)
timesteps = timesteps.to(original_samples.device, dtype=torch.float32)
else:
schedule_timesteps = self.timesteps.to(original_samples.device)

# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
if self.begin_index is None:
step_indices = torch.tensor([self.index_for_timestep(t, schedule_timesteps) for t in timesteps])
elif self.step_index is not None:
# add_noise is called after first denoising step (for inpainting)
step_indices = torch.tensor([self.step_index] * timesteps.shape[0])
else:
# add noise is called before first denoising step to create initial latent(img2img)
step_indices = torch.tensor([self.begin_index] * timesteps.shape[0])

sigma = sigmas[step_indices].flatten()
while len(sigma.shape) < len(original_samples.shape):
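The core of the scheduler change is the reverse index mapping for integer timesteps: EDM stores its sigmas in descending order (index 0 holds sigma_max), so an integer training timestep t is mapped to index num_train_timesteps - 1 - t, which makes larger timesteps select larger sigmas. A minimal sketch of that mapping, using an illustrative linear sigma table rather than the scheduler's actual Karras schedule:

import torch

# Stand-in for EDMEulerScheduler.sigmas: descending, sigmas[0] is the largest.
num_train_timesteps = 1000
sigmas = torch.linspace(80.0, 0.002, num_train_timesteps)

timesteps = torch.tensor([0, 500, 999], dtype=torch.int64)
step_indices = num_train_timesteps - 1 - timesteps.long()  # -> [999, 499, 0]

print(sigmas[step_indices])  # small sigma for t=0, large sigma for t=999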
tests/schedulers/test_scheduler_edm_euler.py: 63 changes (61 additions, 2 deletions)
@@ -80,6 +80,67 @@ def test_full_loop_device(self, num_inference_steps=10, seed=0):
assert abs(result_sum.item() - 34.1855) < 1e-3
assert abs(result_mean.item() - 0.044) < 1e-3

def test_add_noise_with_integer_timesteps(self):
"""Test that add_noise works with integer timesteps (training case) - Issue #7406"""
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)

batch_size = 4
channels = 3
height = width = 32

# Create dummy data
original_samples = torch.randn(batch_size, channels, height, width)
noise = torch.randn_like(original_samples)

# Test with integer timesteps (training case)
timesteps = torch.randint(0, scheduler_config["num_train_timesteps"], (batch_size,), dtype=torch.int64)

# This should not raise an error
noisy_samples = scheduler.add_noise(original_samples, noise, timesteps)

# Verify output shape
self.assertEqual(noisy_samples.shape, original_samples.shape)

# Verify that noise was actually added
self.assertFalse(torch.allclose(noisy_samples, original_samples))

# Verify noise levels are correct (higher timestep = more noise)
t_low = torch.tensor([0], dtype=torch.int64)
t_high = torch.tensor([scheduler_config["num_train_timesteps"] - 1], dtype=torch.int64)

noisy_low = scheduler.add_noise(original_samples[:1], noise[:1], t_low)
noisy_high = scheduler.add_noise(original_samples[:1], noise[:1], t_high)

noise_low = (noisy_low - original_samples[:1]).abs().mean()
noise_high = (noisy_high - original_samples[:1]).abs().mean()

# Higher timestep should have more noise
self.assertGreater(noise_high, noise_low)

def test_add_noise_with_float_timesteps(self):
"""Test that add_noise still works with float timesteps (inference case)"""
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(50)

batch_size = 2

# Create dummy data with correct shape
original_samples = torch.randn(batch_size, 3, 32, 32)
noise = torch.randn_like(original_samples)

# Use float timesteps from the scheduler
timesteps = scheduler.timesteps[:batch_size]

# This should not raise an error
noisy_samples = scheduler.add_noise(original_samples, noise, timesteps)

# Verify output shape
self.assertEqual(noisy_samples.shape, original_samples.shape)

# Override test_from_save_pretrained to use EDMEulerScheduler-specific logic
def test_from_save_pretrained(self):
kwargs = dict(self.forward_default_kwargs)
@@ -115,7 +176,6 @@ def test_from_save_pretrained(self):

assert torch.sum(torch.abs(output - new_output)) < 1e-5, "Scheduler outputs are not identical"

# Override test_from_save_pretrained to use EDMEulerScheduler-specific logic
def test_step_shape(self):
num_inference_steps = 10

@@ -137,7 +197,6 @@ def test_step_shape(self):
self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)

# Override test_from_save_pretrained to use EDMEulerScheduler-specific logic
def test_scheduler_outputs_equivalence(self):
def set_nan_tensor_to_zero(t):
t[t != t] = 0
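For context, the call pattern the new integer-timestep test exercises looks like a standard training step. A minimal sketch, assuming a diffusers install with this patch applied and a toy Conv2d standing in for a real denoiser:

import torch
from diffusers import EDMEulerScheduler

scheduler = EDMEulerScheduler()              # num_train_timesteps defaults to 1000
model = torch.nn.Conv2d(3, 3, 3, padding=1)  # toy stand-in for a denoiser

clean = torch.randn(4, 3, 32, 32)
noise = torch.randn_like(clean)
# Integer timesteps, sampled the way a training loop would sample them
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (4,), dtype=torch.int64)

# With the patch this no longer raises, and larger timesteps add more noise.
noisy = scheduler.add_noise(clean, noise, timesteps)

loss = torch.nn.functional.mse_loss(model(noisy), clean)
loss.backward()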