Skip to content

Commit

Permalink
Fix Tiling in ConsistencyDecoderVAE (huggingface#7290)
Browse files Browse the repository at this point in the history
* Fix typos

* Add docstring to `decode` method in `ConsistencyDecoderVAE`

* Fix tiling

* Enable tiled VAE decoding with customizable tile sample size and overlap factor

* Revert "Enable tiled VAE decoding with customizable tile sample size and overlap factor"

This reverts commit 1810496.

* Add VAE tiling test for `ConsistencyDecoderVAE`

---------

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
tolgacangoz and sayakpaul authored Mar 26, 2024
1 parent 288632a commit 443aa14
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 3 deletions.
33 changes: 30 additions & 3 deletions src/diffusers/models/autoencoders/consistency_decoder_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
... "runwayml/stable-diffusion-v1-5", vae=vae, torch_dtype=torch.float16
... ).to("cuda")
>>> pipe("horse", generator=torch.manual_seed(0)).images
>>> image = pipe("horse", generator=torch.manual_seed(0)).images[0]
>>> image
```
"""

Expand All @@ -72,6 +73,7 @@ def __init__(
self,
scaling_factor: float = 0.18215,
latent_channels: int = 4,
sample_size: int = 32,
encoder_act_fn: str = "silu",
encoder_block_out_channels: Tuple[int, ...] = (128, 256, 512, 512),
encoder_double_z: bool = True,
Expand Down Expand Up @@ -153,6 +155,16 @@ def __init__(
self.use_slicing = False
self.use_tiling = False

# only relevant if vae tiling is enabled
self.tile_sample_min_size = self.config.sample_size
sample_size = (
self.config.sample_size[0]
if isinstance(self.config.sample_size, (list, tuple))
else self.config.sample_size
)
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
self.tile_overlap_factor = 0.25

# Copied from diffusers.models.autoencoders.autoencoder_kl.AutoencoderKL.enable_tiling
def enable_tiling(self, use_tiling: bool = True):
r"""
Expand Down Expand Up @@ -272,7 +284,7 @@ def encode(
Args:
x (`torch.FloatTensor`): Input batch of images.
return_dict (`bool`, *optional*, defaults to `True`):
Whether to return a [`~models.consistecy_decoder_vae.ConsistencyDecoderOoutput`] instead of a plain
Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
tuple.
Returns:
Expand Down Expand Up @@ -305,6 +317,19 @@ def decode(
return_dict: bool = True,
num_inference_steps: int = 2,
) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
"""
Decodes the input latent vector `z` using the consistency decoder VAE model.
Args:
z (torch.FloatTensor): The input latent vector.
generator (Optional[torch.Generator]): The random number generator. Default is None.
return_dict (bool): Whether to return the output as a dictionary. Default is True.
num_inference_steps (int): The number of inference steps. Default is 2.
Returns:
Union[DecoderOutput, Tuple[torch.FloatTensor]]: The decoded output.
"""
z = (z * self.config.scaling_factor - self.means) / self.stds

scale_factor = 2 ** (len(self.config.block_out_channels) - 1)
Expand Down Expand Up @@ -345,7 +370,9 @@ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.
b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
return b

def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> ConsistencyDecoderVAEOutput:
def tiled_encode(
self, x: torch.FloatTensor, return_dict: bool = True
) -> Union[ConsistencyDecoderVAEOutput, Tuple]:
r"""Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
Expand Down
30 changes: 30 additions & 0 deletions tests/models/autoencoders/test_models_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,3 +1116,33 @@ def test_sd_f16(self):
)

assert torch_all_close(actual_output, expected_output, atol=5e-3)

def test_vae_tiling(self):
vae = ConsistencyDecoderVAE.from_pretrained("openai/consistency-decoder")
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", vae=vae, safety_checker=None)
pipe.to(torch_device)
pipe.set_progress_bar_config(disable=None)

out_1 = pipe(
"horse",
num_inference_steps=2,
output_type="pt",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]

# make sure tiled vae decode yields the same result
pipe.enable_vae_tiling()
out_2 = pipe(
"horse",
num_inference_steps=2,
output_type="pt",
generator=torch.Generator("cpu").manual_seed(0),
).images[0]

assert torch_all_close(out_1, out_2, atol=5e-3)

# test that tiled decode works with various shapes
shapes = [(1, 4, 73, 97), (1, 4, 97, 73), (1, 4, 49, 65), (1, 4, 65, 49)]
for shape in shapes:
image = torch.zeros(shape, device=torch_device)
pipe.vae.decode(image)

0 comments on commit 443aa14

Please sign in to comment.