diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
index d84a0861e984..c450eaf8a79f 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -1289,6 +1289,10 @@ def tiled_encode(self, x: torch.Tensor) -> AutoencoderKLOutput:
         blend_height = tile_latent_min_height - tile_latent_stride_height
         blend_width = tile_latent_min_width - tile_latent_stride_width
 
+        # Apply patchify if patch_size is specified
+        if self.config.patch_size is not None:
+            x = patchify(x, patch_size=self.config.patch_size)
+
         # Split x into overlapping tiles and encode them separately.
         # The tiles have an overlap to avoid seams between tiles.
         rows = []
@@ -1392,6 +1396,10 @@ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[Decod
 
         dec = torch.cat(result_rows, dim=3)[:, :, :, :sample_height, :sample_width]
 
+        # Apply unpatchify if patch_size is specified
+        if self.config.patch_size is not None:
+            dec = unpatchify(dec, patch_size=self.config.patch_size)
+
         if not return_dict:
             return (dec,)
         return DecoderOutput(sample=dec)
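
For context on the change: `tiled_encode` now patchifies the input before it is split into overlapping tiles, and `tiled_decode` unpatchifies only after the blended tiles have been stitched back together and cropped, mirroring the non-tiled code paths. The sketch below illustrates the kind of space-to-channel rearrangement that `patchify`/`unpatchify` are assumed to perform on `(B, C, T, H, W)` video tensors; it is not the implementation from `autoencoder_kl_wan.py`, and the exact channel ordering there may differ.

```python
# Hypothetical sketch of a patchify/unpatchify pair (space-to-channel rearrangement),
# assuming (B, C, T, H, W) inputs whose spatial dims are divisible by patch_size.
import torch


def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # (B, C, T, H, W) -> (B, C * p * p, T, H // p, W // p)
    if patch_size == 1:
        return x
    b, c, t, h, w = x.shape
    x = x.reshape(b, c, t, h // patch_size, patch_size, w // patch_size, patch_size)
    # Move the p x p sub-pixel positions next to the channel dimension.
    x = x.permute(0, 1, 4, 6, 2, 3, 5)
    return x.reshape(b, c * patch_size * patch_size, t, h // patch_size, w // patch_size)


def unpatchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Inverse of patchify: (B, C * p * p, T, H, W) -> (B, C, T, H * p, W * p)
    if patch_size == 1:
        return x
    b, cpp, t, h, w = x.shape
    c = cpp // (patch_size * patch_size)
    x = x.reshape(b, c, patch_size, patch_size, t, h, w)
    # Interleave the sub-pixel positions back into the spatial dims.
    x = x.permute(0, 1, 4, 5, 2, 6, 3)
    return x.reshape(b, c, t, h * patch_size, w * patch_size)


# Round-trip check on a dummy video tensor.
video = torch.randn(1, 3, 4, 8, 8)
assert torch.equal(unpatchify(patchify(video, 2), 2), video)
```

Because the rearrangement is applied before the tile loop in `tiled_encode` and after the final `torch.cat` in `tiled_decode`, the per-tile encode/decode and the overlap blending all operate on a single, consistent (patchified) resolution rather than mixing patchified and unpatchified tiles.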