[refactor] Fix FreeInit behaviour (huggingface#7410)
* fix freeinit impl

* fix progress bar

* fix progress bar and remove old code

* fix num_inference_steps==1 case for freeinit by running at least 1 step when fast sampling is enabled
a-r-r-o-w authored Mar 22, 2024
1 parent 9613576 commit 3636990
Showing 4 changed files with 42 additions and 105 deletions.
@@ -792,7 +792,7 @@ def __call__(
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
             # 8. Denoising loop
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
@@ -944,7 +944,7 @@ def __call__(
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
             # 8. Denoising loop
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
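
Both hunks above fix the same progress-bar bug: the denoising loop iterates over `timesteps`, which FreeInit's fast sampling can shorten well below the requested `num_inference_steps`, so the bar must be sized from `len(timesteps)` (stored as `self._num_timesteps`). A minimal standalone sketch with hypothetical step counts, not pipeline code:

from tqdm import tqdm

# Suppose fast sampling cut a 20-step request down to 4 actual timesteps.
num_inference_steps = 20
timesteps = [800, 600, 400, 200]

with tqdm(total=num_inference_steps) as bar:  # old sizing: bar stalls at 4/20
    for t in timesteps:
        bar.update()

with tqdm(total=len(timesteps)) as bar:  # new sizing: bar completes at 4/4
    for t in timesteps:
        bar.update()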
61 changes: 31 additions & 30 deletions src/diffusers/pipelines/free_init_utils.py
@@ -146,39 +146,40 @@ def _apply_free_init(
     ):
         if free_init_iteration == 0:
             self._free_init_initial_noise = latents.detach().clone()
-            return latents, self.scheduler.timesteps
-
-        latent_shape = latents.shape
-
-        free_init_filter_shape = (1, *latent_shape[1:])
-        free_init_freq_filter = self._get_free_init_freq_filter(
-            shape=free_init_filter_shape,
-            device=device,
-            filter_type=self._free_init_method,
-            order=self._free_init_order,
-            spatial_stop_frequency=self._free_init_spatial_stop_frequency,
-            temporal_stop_frequency=self._free_init_temporal_stop_frequency,
-        )
-
-        current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
-        diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
-
-        z_t = self.scheduler.add_noise(
-            original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
-        ).to(dtype=torch.float32)
-
-        z_rand = randn_tensor(
-            shape=latent_shape,
-            generator=generator,
-            device=device,
-            dtype=torch.float32,
-        )
-        latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
-        latents = latents.to(dtype)
+        else:
+            latent_shape = latents.shape
+
+            free_init_filter_shape = (1, *latent_shape[1:])
+            free_init_freq_filter = self._get_free_init_freq_filter(
+                shape=free_init_filter_shape,
+                device=device,
+                filter_type=self._free_init_method,
+                order=self._free_init_order,
+                spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+                temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+            )
+
+            current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
+            diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
+
+            z_t = self.scheduler.add_noise(
+                original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
+            ).to(dtype=torch.float32)
+
+            z_rand = randn_tensor(
+                shape=latent_shape,
+                generator=generator,
+                device=device,
+                dtype=torch.float32,
+            )
+            latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
+            latents = latents.to(dtype)
 
         # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
         if self._free_init_use_fast_sampling:
-            num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            num_inference_steps = max(
+                1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+            )
             self.scheduler.set_timesteps(num_inference_steps, device=device)
 
         return latents, self.scheduler.timesteps
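
A worked example (assumed values) of the coarse-to-fine schedule above: with 20 steps and 5 FreeInit iterations, the per-iteration step counts grow 4, 8, 12, 16, 20; without the new `max(1, ...)` guard, a request of `num_inference_steps=1` would yield zero steps for every iteration but the last, and `scheduler.set_timesteps(0)` would produce an empty denoising loop.

def fast_sampling_schedule(num_inference_steps: int, num_iters: int) -> list:
    # Mirrors the computation in the updated _apply_free_init.
    return [
        max(1, int(num_inference_steps / num_iters * (i + 1)))
        for i in range(num_iters)
    ]

print(fast_sampling_schedule(20, 5))  # [4, 8, 12, 16, 20] -- coarse to fine
print(fast_sampling_schedule(1, 5))   # [1, 1, 1, 1, 1]    -- was [0, 0, 0, 0, 1]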
82 changes: 9 additions & 73 deletions src/diffusers/pipelines/pia/pipeline_pia.py
@@ -13,14 +13,12 @@
 # limitations under the License.
 
 import inspect
-import math
 from dataclasses import dataclass
-from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL
 import torch
-import torch.fft as fft
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
@@ -130,81 +128,16 @@ def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale
     return coef
 
 
-def _get_freeinit_freq_filter(
-    shape: Tuple[int, ...],
-    device: Union[str, torch.dtype],
-    filter_type: str,
-    order: float,
-    spatial_stop_frequency: float,
-    temporal_stop_frequency: float,
-) -> torch.Tensor:
-    r"""Returns the FreeInit filter based on filter type and other input conditions."""
-
-    time, height, width = shape[-3], shape[-2], shape[-1]
-    mask = torch.zeros(shape)
-
-    if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
-        return mask
-
-    if filter_type == "butterworth":
-
-        def retrieve_mask(x):
-            return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
-
-    elif filter_type == "gaussian":
-
-        def retrieve_mask(x):
-            return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
-
-    elif filter_type == "ideal":
-
-        def retrieve_mask(x):
-            return 1 if x <= spatial_stop_frequency * 2 else 0
-
-    else:
-        raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
-
-    for t in range(time):
-        for h in range(height):
-            for w in range(width):
-                d_square = (
-                    ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
-                    + (2 * h / height - 1) ** 2
-                    + (2 * w / width - 1) ** 2
-                )
-                mask[..., t, h, w] = retrieve_mask(d_square)
-
-    return mask.to(device)
-
-
-def _freq_mix_3d(x: torch.Tensor, noise: torch.Tensor, LPF: torch.Tensor) -> torch.Tensor:
-    r"""Noise reinitialization."""
-    # FFT
-    x_freq = fft.fftn(x, dim=(-3, -2, -1))
-    x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
-    noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
-    noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
-
-    # frequency mix
-    HPF = 1 - LPF
-    x_freq_low = x_freq * LPF
-    noise_freq_high = noise_freq * HPF
-    x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
-
-    # IFFT
-    x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
-    x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
-
-    return x_mixed
-
-
 @dataclass
 class PIAPipelineOutput(BaseOutput):
     r"""
     Output class for PIAPipeline.
 
     Args:
         frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
-        Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
-        NumPy array of shape `(batch_size, num_frames, channels, height, width)`,
-        or Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
+            Nested list of length `batch_size` with denoised PIL image sequences of length `num_frames`,
+            NumPy array of shape `(batch_size, num_frames, channels, height, width)`,
+            or Torch tensor of shape `(batch_size, num_frames, channels, height, width)`.
     """
 
     frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
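
The helpers deleted above were per-pipeline copies of logic that already lives in src/diffusers/pipelines/free_init_utils.py (`_get_free_init_freq_filter` and `_apply_freq_filter`), which PIA now reaches through the shared FreeInit mixin. For reference, a self-contained sketch of the noise reinitialization the removed `_freq_mix_3d` performed, using arbitrary tensor sizes and a random stand-in for the low-pass mask:

import torch
import torch.fft as fft

x = torch.randn(1, 4, 8, 32, 32)   # (batch, channels, frames, height, width)
noise = torch.randn_like(x)
lpf = torch.rand_like(x)           # stand-in for the FreeInit low-pass filter

# Move both signals to the frequency domain (centered spectrum).
x_freq = fft.fftshift(fft.fftn(x, dim=(-3, -2, -1)), dim=(-3, -2, -1))
noise_freq = fft.fftshift(fft.fftn(noise, dim=(-3, -2, -1)), dim=(-3, -2, -1))

# Keep the low frequencies of x, take the high frequencies from fresh noise.
mixed_freq = x_freq * lpf + noise_freq * (1 - lpf)

# Back to the latent domain.
x_mixed = fft.ifftn(fft.ifftshift(mixed_freq, dim=(-3, -2, -1)), dim=(-3, -2, -1)).real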
@@ -788,7 +721,8 @@ def __call__(
                 The input image to be used for video generation.
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            strength (`float`, *optional*, defaults to 1.0): Indicates extent to transform the reference `image`. Must be between 0 and 1.
+            strength (`float`, *optional*, defaults to 1.0):
+                Indicates extent to transform the reference `image`. Must be between 0 and 1.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated video.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -979,8 +913,10 @@ def __call__(
                     latents, free_init_iter, num_inference_steps, device, latents.dtype, generator
                 )
 
+            self._num_timesteps = len(timesteps)
             num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
+
+            with self.progress_bar(total=self._num_timesteps) as progress_bar:
                 for i, t in enumerate(timesteps):
                     # expand the latents if we are doing classifier free guidance
                     latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
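
Taken together, the changes can be exercised through FreeInit's public API; a usage sketch (model ids and argument values below are illustrative, not part of this commit) of the case the `max(1, ...)` guard fixes:

import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")

# Fast sampling with a tiny step budget previously yielded an empty
# denoising loop on early FreeInit iterations; now each runs >= 1 step.
pipe.enable_free_init(num_iters=3, use_fast_sampling=True)
frames = pipe(prompt="a panda surfing", num_inference_steps=2).frames[0]
pipe.disable_free_init()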
