Skip to content

Commit

Permalink
[Sana-4K] (huggingface#10537)
Browse files Browse the repository at this point in the history
* [Sana 4K]
add 4K support for Sana

* [Sana-4K] fix SanaPAGPipeline

* add automatic VAE tiling function;

* set clean_caption to False;

* add warnings for VAE OOM.

* style

---------

Co-authored-by: yiyixuxu <[email protected]>
  • Loading branch information
lawrence-cj and yiyixuxu authored Jan 14, 2025
1 parent 6b72784 commit 3d70777
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 4 deletions.
17 changes: 14 additions & 3 deletions src/diffusers/pipelines/pag/pipeline_pag_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Callable, Dict, List, Optional, Tuple, Union

import torch
Expand All @@ -41,6 +42,7 @@
ASPECT_RATIO_1024_BIN,
)
from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
from ..sana.pipeline_sana import ASPECT_RATIO_4096_BIN
from .pag_utils import PAGMixin


Expand Down Expand Up @@ -639,7 +641,7 @@ def __call__(
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
clean_caption: bool = True,
clean_caption: bool = False,
use_resolution_binning: bool = True,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
Expand Down Expand Up @@ -755,7 +757,9 @@ def __call__(
callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

if use_resolution_binning:
if self.transformer.config.sample_size == 64:
if self.transformer.config.sample_size == 128:
aspect_ratio_bin = ASPECT_RATIO_4096_BIN
elif self.transformer.config.sample_size == 64:
aspect_ratio_bin = ASPECT_RATIO_2048_BIN
elif self.transformer.config.sample_size == 32:
aspect_ratio_bin = ASPECT_RATIO_1024_BIN
Expand Down Expand Up @@ -912,7 +916,14 @@ def __call__(
image = latents
else:
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)

Expand Down
10 changes: 9 additions & 1 deletion src/diffusers/pipelines/sana/pipeline_sana.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import inspect
import re
import urllib.parse as ul
import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
Expand Down Expand Up @@ -953,7 +954,14 @@ def __call__(
image = latents
else:
latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
try:
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
except torch.cuda.OutOfMemoryError as e:
warnings.warn(
f"{e}. \n"
f"Try to use VAE tiling for large images. For example: \n"
f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
)
if use_resolution_binning:
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)

Expand Down

0 comments on commit 3d70777

Please sign in to comment.