add skip_layers argument to SD3 transformer model class (huggingface#9880)

* add skip_layers argument to SD3 transformer model class

* add unit test for skip_layers in stable diffusion 3

* sd3: pipeline should support skip layer guidance

* up

---------

Co-authored-by: bghira <[email protected]>
Co-authored-by: yiyixuxu <[email protected]>
3 people authored Nov 19, 2024
1 parent cc7d88f commit 99c0483
Showing 3 changed files with 82 additions and 6 deletions.
15 changes: 10 additions & 5 deletions src/diffusers/models/transformers/transformer_sd3.py
@@ -268,6 +268,7 @@ def forward(
block_controlnet_hidden_states: List = None,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
skip_layers: Optional[List[int]] = None,
) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
"""
The [`SD3Transformer2DModel`] forward method.
@@ -279,9 +280,9 @@ def forward(
Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
from the embeddings of input conditions.
timestep ( `torch.LongTensor`):
timestep (`torch.LongTensor`):
Used to indicate denoising step.
block_controlnet_hidden_states: (`list` of `torch.Tensor`):
block_controlnet_hidden_states (`list` of `torch.Tensor`):
A list of tensors that if specified are added to the residuals of transformer blocks.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
@@ -290,6 +291,8 @@ def forward(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
tuple.
skip_layers (`list` of `int`, *optional*):
A list of layer indices to skip during the forward pass.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
@@ -317,7 +320,10 @@ def forward(
encoder_hidden_states = self.context_embedder(encoder_hidden_states)

for index_block, block in enumerate(self.transformer_blocks):
if torch.is_grad_enabled() and self.gradient_checkpointing:
# Skip specified layers
is_skip = skip_layers is not None and index_block in skip_layers

if torch.is_grad_enabled() and self.gradient_checkpointing and not is_skip:

def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
@@ -336,8 +342,7 @@ def custom_forward(*inputs):
temb,
**ckpt_kwargs,
)

else:
elif not is_skip:
encoder_hidden_states, hidden_states = block(
hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
)
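For intuition, the control flow added above reduces to an index check in front of each transformer block. Below is a minimal, self-contained sketch of that pattern; TinyBlockStack and its linear blocks are illustrative stand-ins for the real SD3 joint-attention blocks, not diffusers code.

import torch
import torch.nn as nn
from typing import List, Optional

class TinyBlockStack(nn.Module):
    def __init__(self, num_layers: int = 4, dim: int = 8):
        super().__init__()
        self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(num_layers))

    def forward(self, x: torch.Tensor, skip_layers: Optional[List[int]] = None) -> torch.Tensor:
        for index_block, block in enumerate(self.blocks):
            # Same test as in the diff: skip a block when its index is listed
            if skip_layers is not None and index_block in skip_layers:
                continue
            x = block(x)
        return x

model = TinyBlockStack()
x = torch.randn(1, 8)
# Skipping a layer changes the output but not its shape, mirroring the unit test below
assert model(x).shape == model(x, skip_layers=[1]).shape
assert not torch.allclose(model(x), model(x, skip_layers=[1]))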
47 changes: 52 additions & 1 deletion src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -642,6 +642,10 @@ def prepare_latents(
def guidance_scale(self):
return self._guidance_scale

@property
def skip_guidance_layers(self):
return self._skip_guidance_layers

@property
def clip_skip(self):
return self._clip_skip
@@ -694,6 +698,10 @@ def __call__(
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
max_sequence_length: int = 256,
skip_guidance_layers: Optional[List[int]] = None,
skip_layer_guidance_scale: float = 2.8,
skip_layer_guidance_stop: float = 0.2,
skip_layer_guidance_start: float = 0.01,
):
r"""
Function invoked when calling the pipeline for generation.
@@ -778,6 +786,22 @@ def __call__(
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
`._callback_tensor_inputs` attribute of your pipeline class.
max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
skip_guidance_layers (`List[int]`, *optional*):
A list of transformer layer indices to skip during skip-layer guidance. If not provided, skip-layer
guidance is disabled. If provided, an additional conditional forward pass is run with these layers
skipped, and the difference between the full and the skipped prediction is added to the noise
prediction, scaled by `skip_layer_guidance_scale`. The value recommended by Stability AI for Stable
Diffusion 3.5 Medium is [7, 8, 9].
skip_layer_guidance_scale (`float`, *optional*, defaults to 2.8): The scale of the skip-layer guidance
term. The difference between the full conditional prediction and the prediction computed with
`skip_guidance_layers` skipped is multiplied by this value before being added to the noise prediction.
skip_layer_guidance_stop (`float`, *optional*, defaults to 0.2): The fraction of the total number of
denoising steps after which skip-layer guidance stops being applied. The value recommended by
Stability AI for Stable Diffusion 3.5 Medium is 0.2.
skip_layer_guidance_start (`float`, *optional*, defaults to 0.01): The fraction of the total number of
denoising steps after which skip-layer guidance starts being applied. The value recommended by
Stability AI for Stable Diffusion 3.5 Medium is 0.01.
Examples:
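Taken together, the new arguments translate into a call along these lines. This is an illustrative sketch, not part of the diff: the checkpoint id, prompt, dtype, and step count are assumptions, while the skip-layer values are the Stable Diffusion 3.5 Medium recommendations from the docstring above.

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-medium", torch_dtype=torch.bfloat16
).to("cuda")

image = pipe(
    "a photo of a cat",
    num_inference_steps=28,
    guidance_scale=4.5,
    skip_guidance_layers=[7, 8, 9],    # layers to skip in the extra conditional pass
    skip_layer_guidance_scale=2.8,     # weight of the skip-layer guidance term
    skip_layer_guidance_start=0.01,    # apply after 1% of the steps...
    skip_layer_guidance_stop=0.2,      # ...and stop after 20% of the steps
).images[0]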
@@ -809,6 +833,7 @@ def __call__(
)

self._guidance_scale = guidance_scale
self._skip_layer_guidance_scale = skip_layer_guidance_scale
self._skip_guidance_layers = skip_guidance_layers
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False
@@ -851,6 +876,9 @@ def __call__(
)

if self.do_classifier_free_guidance:
if skip_guidance_layers is not None:
original_prompt_embeds = prompt_embeds
original_pooled_prompt_embeds = pooled_prompt_embeds
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)

@@ -879,7 +907,11 @@ def __call__(
continue

# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
latent_model_input = (
torch.cat([latents] * 2)
if self.do_classifier_free_guidance and skip_guidance_layers is None
else latents
)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timestep = t.expand(latent_model_input.shape[0])

@@ -896,6 +928,25 @@ def __call__(
if self.do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
should_skip_layers = (
num_inference_steps * skip_layer_guidance_start < i < num_inference_steps * skip_layer_guidance_stop
)
if skip_guidance_layers is not None and should_skip_layers:
noise_pred_skip_layers = self.transformer(
hidden_states=latent_model_input,
timestep=timestep,
encoder_hidden_states=original_prompt_embeds,
pooled_projections=original_pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
skip_layers=skip_guidance_layers,
)[0]
noise_pred = (
noise_pred + (noise_pred_text - noise_pred_skip_layers) * self._skip_layer_guidance_scale
)

# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
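The denoising loop above now combines two guidance terms: the usual classifier-free guidance update, plus a skip-layer term that pushes the prediction away from the weaker output produced when the layers in `skip_guidance_layers` are dropped. A small numeric sketch of that arithmetic, with random tensors standing in for the three noise predictions:

import torch

guidance_scale, slg_scale = 7.0, 2.8
uncond, text, skipped = torch.randn(3, 4)

# classifier-free guidance, as in the unchanged code path
noise_pred = uncond + guidance_scale * (text - uncond)
# skip-layer guidance term added on top, as in the new code path
noise_pred = noise_pred + slg_scale * (text - skipped)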
20 changes: 20 additions & 0 deletions tests/models/transformers/test_models_transformer_sd3.py
@@ -147,3 +147,23 @@ def test_set_attn_processor_for_determinism(self):
def test_gradient_checkpointing_is_applied(self):
expected_set = {"SD3Transformer2DModel"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

def test_skip_layers(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict).to(torch_device)

# Forward pass without skipping layers
output_full = model(**inputs_dict).sample

# Forward pass skipping layer 0 (the only layer in this test setup)
inputs_dict_with_skip = inputs_dict.copy()
inputs_dict_with_skip["skip_layers"] = [0]
output_skip = model(**inputs_dict_with_skip).sample

# Check that the outputs are different
self.assertFalse(
torch.allclose(output_full, output_skip, atol=1e-5), "Outputs should differ when layers are skipped"
)

# Check that the outputs have the same shape
self.assertEqual(output_full.shape, output_skip.shape, "Outputs should have the same shape")
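Assuming a standard diffusers development checkout, the new test can be run in isolation with:

pytest tests/models/transformers/test_models_transformer_sd3.py -k test_skip_layers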
