From f6f7afa1d7c6f45f8568c5603b1e6300d4583f04 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Wed, 20 Nov 2024 17:30:17 +0530 Subject: [PATCH] Flux latents fix (#9929) * update * update * update * update * update * update --------- Co-authored-by: Sayak Paul --- src/diffusers/pipelines/flux/pipeline_flux.py | 22 ++++++++----- .../flux/pipeline_flux_controlnet.py | 22 ++++++++----- ...pipeline_flux_controlnet_image_to_image.py | 24 ++++++++------ .../pipeline_flux_controlnet_inpainting.py | 32 +++++++++++-------- .../pipelines/flux/pipeline_flux_img2img.py | 23 +++++++------ .../pipelines/flux/pipeline_flux_inpaint.py | 32 +++++++++++-------- .../controlnet_flux/test_controlnet_flux.py | 22 +++++++++++++ .../test_controlnet_flux_img2img.py | 29 +++++++++++++++++ .../test_controlnet_flux_inpaint.py | 32 +++++++++++++++++++ tests/pipelines/flux/test_pipeline_flux.py | 14 ++++++++ .../flux/test_pipeline_flux_img2img.py | 14 ++++++++ .../flux/test_pipeline_flux_inpaint.py | 14 ++++++++ 12 files changed, 219 insertions(+), 61 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux.py b/src/diffusers/pipelines/flux/pipeline_flux.py index 040d935f1b88..12996f3f3e92 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux.py +++ b/src/diffusers/pipelines/flux/pipeline_flux.py @@ -197,7 +197,9 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -386,9 +388,9 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -451,8 +453,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = height // vae_scale_factor - width = width // vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) @@ -501,8 +505,10 @@ def prepare_latents( generator, latents=None, ): - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 771150b085d5..904173852ee4 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -218,7 +218,9 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -410,9 +412,9 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): - if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. 
Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -478,8 +480,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = height // vae_scale_factor - width = width // vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) @@ -500,8 +504,10 @@ def prepare_latents( generator, latents=None, ): - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py index 04582b71d780..5d65df0b768e 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py @@ -230,7 +230,9 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -453,9 +455,9 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -521,8 +523,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = height // vae_scale_factor - width = width // vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) @@ -551,9 +555,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators."
) - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor - + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -873,7 +878,6 @@ def __call__( timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) - latents, latent_image_ids = self.prepare_latents( init_image, latent_timestep, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py index 947e97e272f8..5d5c8f73762c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py @@ -233,9 +233,11 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. 
So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.vae.config.latent_channels, do_normalize=False, do_binarize=True, @@ -467,9 +469,9 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -548,8 +550,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = height // vae_scale_factor - width = width // vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) @@ -578,9 +582,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor - + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -624,8 +629,10 @@ def prepare_mask_latents( device, generator, ): - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision @@ -663,7 +670,6 @@ def prepare_mask_latents( # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - masked_image_latents = self._pack_latents( masked_image_latents, batch_size, diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index 4fbac51eadb1..d34d9b53aa6b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -214,7 +214,9 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Flux latents are turned into 2x2 patches 
and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.tokenizer_max_length = ( self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77 ) @@ -437,9 +439,9 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." + if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -505,8 +507,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = height // vae_scale_factor - width = width // vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) @@ -534,9 +538,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
) - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor - + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 766f9864839e..3fcf6ace8a79 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -211,9 +211,11 @@ def __init__( self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 ) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible + # by the patch size. So the vae scale factor is multiplied by the patch size to account for this + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2) self.mask_processor = VaeImageProcessor( - vae_scale_factor=self.vae_scale_factor, + vae_scale_factor=self.vae_scale_factor * 2, vae_latent_channels=self.vae.config.latent_channels, do_normalize=False, do_binarize=True, @@ -445,9 +447,9 @@ def check_inputs( if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") - if height % self.vae_scale_factor != 0 or width % self.vae_scale_factor != 0: - raise ValueError( - f"`height` and `width` have to be divisible by {self.vae_scale_factor} but are {height} and {width}." 
+ if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0: + logger.warning( + f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly" ) if callback_on_step_end_tensor_inputs is not None and not all( @@ -526,8 +528,10 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): def _unpack_latents(latents, height, width, vae_scale_factor): batch_size, num_patches, channels = latents.shape - height = height // vae_scale_factor - width = width // vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (vae_scale_factor * 2)) + width = 2 * (int(width) // (vae_scale_factor * 2)) latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2) latents = latents.permute(0, 3, 1, 4, 2, 5) @@ -555,9 +559,10 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor - + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. + height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) @@ -600,8 +605,10 @@ def prepare_mask_latents( device, generator, ): - height = int(height) // self.vae_scale_factor - width = int(width) // self.vae_scale_factor + # VAE applies 8x compression on images but we must also account for packing which requires + # latent height and width to be divisible by 2. 
+ height = 2 * (int(height) // (self.vae_scale_factor * 2)) + width = 2 * (int(width) // (self.vae_scale_factor * 2)) # resize the mask to latents shape as we concatenate the mask to the latents # we do that before converting to dtype to avoid breaking in case we're using cpu_offload # and half precision @@ -639,7 +646,6 @@ def prepare_mask_latents( # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) - masked_image_latents = self._pack_latents( masked_image_latents, batch_size, diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux.py b/tests/pipelines/controlnet_flux/test_controlnet_flux.py index 89540232f9cf..ee3984dcd3e2 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux.py @@ -181,6 +181,28 @@ def test_controlnet_flux(self): def test_xformers_attention_forwardGenerator_pass(self): pass + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 56)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update( + { + "control_image": randn_tensor( + (1, 3, height, width), + device=torch_device, + dtype=torch.float16, + ) + } + ) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + @slow @require_big_gpu_with_torch_cuda diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py index 9b33d4b46d04..02270d7fbd00 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py +++ 
b/tests/pipelines/controlnet_flux/test_controlnet_flux_img2img.py @@ -14,6 +14,7 @@ from diffusers.utils.testing_utils import ( torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from ..test_pipelines_common import ( PipelineTesterMixin, @@ -218,3 +219,31 @@ def test_fused_qkv_projections(self): assert np.allclose( original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 56)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + inputs.update( + { + "control_image": randn_tensor( + (1, 3, height, width), + device=torch_device, + dtype=torch.float16, + ), + "image": randn_tensor( + (1, 3, height, width), + device=torch_device, + dtype=torch.float16, + ), + "height": height, + "width": width, + } + ) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) diff --git a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py index d66eaaf6a76f..94d97e9962b7 100644 --- a/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py +++ b/tests/pipelines/controlnet_flux/test_controlnet_flux_inpaint.py @@ -23,7 +23,9 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, floats_tensor, + torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from ..test_pipelines_common import PipelineTesterMixin @@ -192,3 +194,33 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): 
super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 56)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update( + { + "control_image": randn_tensor( + (1, 3, height, width), + device=torch_device, + dtype=torch.float16, + ), + "image": randn_tensor( + (1, 3, height, width), + device=torch_device, + dtype=torch.float16, + ), + "mask_image": torch.ones((1, 1, height, width)).to(torch_device), + "height": height, + "width": width, + } + ) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) diff --git a/tests/pipelines/flux/test_pipeline_flux.py b/tests/pipelines/flux/test_pipeline_flux.py index 3ccf3f80ba3c..df9021ee0adb 100644 --- a/tests/pipelines/flux/test_pipeline_flux.py +++ b/tests/pipelines/flux/test_pipeline_flux.py @@ -191,6 +191,20 @@ def test_fused_qkv_projections(self): original_image_slice, image_slice_disabled, atol=1e-2, rtol=1e-2 ), "Original outputs should match when fused QKV projections are disabled." 
+ def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) + @slow @require_big_gpu_with_torch_cuda diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index a038b1725812..a1336fabdb89 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -147,3 +147,17 @@ def test_flux_prompt_embeds(self): max_diff = np.abs(output_with_prompt - output_with_embeds).max() assert max_diff < 1e-4 + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width) diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py index ac2eb1fa261b..3e68d39004b6 100644 --- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py @@ -149,3 +149,17 @@ def test_flux_inpaint_prompt_embeds(self): max_diff = 
np.abs(output_with_prompt - output_with_embeds).max() assert max_diff < 1e-4 + + def test_flux_image_output_shape(self): + pipe = self.pipeline_class(**self.get_dummy_components()).to(torch_device) + inputs = self.get_dummy_inputs(torch_device) + + height_width_pairs = [(32, 32), (72, 57)] + for height, width in height_width_pairs: + expected_height = height - height % (pipe.vae_scale_factor * 2) + expected_width = width - width % (pipe.vae_scale_factor * 2) + + inputs.update({"height": height, "width": width}) + image = pipe(**inputs).images[0] + output_height, output_width, _ = image.shape + assert (output_height, output_width) == (expected_height, expected_width)