Skip to content

Commit

Permalink
[Community pipeline] Marigold depth estimation update -- align with m…
Browse files Browse the repository at this point in the history
…arigold v0.1.5 (huggingface#7524)

* add resample option; check denoise_step; update ckpt path

* Add seeding in pipeline to increase reproducibility

* fix typo

* fix typo
  • Loading branch information
markkua authored Mar 30, 2024
1 parent ca61287 commit c2e8786
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 21 deletions.
24 changes: 22 additions & 2 deletions examples/community/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,25 @@ This depth estimation pipeline processes a single input image through multiple d

```python
import numpy as np
import torch
from PIL import Image
from diffusers import DiffusionPipeline
from diffusers.utils import load_image

# Original DDIM version (higher quality)
pipe = DiffusionPipeline.from_pretrained(
"prs-eth/marigold-v1-0",
custom_pipeline="marigold_depth_estimation"
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
# variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
)

# (New) LCM version (faster speed)
pipe = DiffusionPipeline.from_pretrained(
"Bingxin/Marigold",
"prs-eth/marigold-lcm-v1-0",
custom_pipeline="marigold_depth_estimation"
# torch_dtype=torch.float16, # (optional) Run with half-precision (16-bit float).
# variant="fp16", # (optional) Use with `torch_dtype=torch.float16`, to directly load fp16 checkpoint
)

pipe.to("cuda")
Expand All @@ -101,12 +112,21 @@ img_path_or_url = "https://share.phys.ethz.ch/~pf/bingkedata/marigold/pipeline_e
image: Image.Image = load_image(img_path_or_url)

pipeline_output = pipe(
image, # Input image.
image, # Input image.
# ----- recommended setting for DDIM version -----
# denoising_steps=10, # (optional) Number of denoising steps of each inference pass. Default: 10.
# ensemble_size=10, # (optional) Number of inference passes in the ensemble. Default: 10.
# ------------------------------------------------

# ----- recommended setting for LCM version ------
# denoising_steps=4,
# ensemble_size=5,
# -------------------------------------------------

# processing_res=768, # (optional) Maximum resolution of processing. If set to 0: will not resize at all. Defaults to 768.
# match_input_res=True, # (optional) Resize depth prediction to match input resolution.
# batch_size=0, # (optional) Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. Defaults to 0.
# seed=2024, # (optional) Random seed can be set to ensure additional reproducibility. Default: None (unseeded). Note: forcing --batch_size 1 helps to increase reproducibility. To ensure full reproducibility, deterministic mode needs to be used.
# color_map="Spectral", # (optional) Colormap used to colorize the depth map. Defaults to "Spectral". Set to `None` to skip colormap generation.
# show_progress_bar=True, # (optional) If true, will show progress bars of the inference progress.
)
Expand Down
106 changes: 87 additions & 19 deletions examples/community/marigold_depth_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
# --------------------------------------------------------------------------


import logging
import math
from typing import Dict, Union

import matplotlib
import numpy as np
import torch
from PIL import Image
from PIL.Image import Resampling
from scipy.optimize import minimize
from torch.utils.data import DataLoader, TensorDataset
from tqdm.auto import tqdm
Expand All @@ -34,13 +36,14 @@
AutoencoderKL,
DDIMScheduler,
DiffusionPipeline,
LCMScheduler,
UNet2DConditionModel,
)
from diffusers.utils import BaseOutput, check_min_version


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.28.0.dev0")
check_min_version("0.25.0")


class MarigoldDepthOutput(BaseOutput):
Expand All @@ -61,6 +64,19 @@ class MarigoldDepthOutput(BaseOutput):
uncertainty: Union[None, np.ndarray]


def get_pil_resample_method(method_str: str) -> Resampling:
    """
    Translate a resampling method name into the matching `PIL.Image.Resampling` member.

    Args:
        method_str (`str`):
            Name of the resampling method. Supported values are `"bilinear"`, `"bicubic"` and `"nearest"`.

    Returns:
        `PIL.Image.Resampling`: The resampling filter that corresponds to `method_str`.

    Raises:
        ValueError: If `method_str` is not one of the supported names.
    """
    resample_method_dic = {
        "bilinear": Resampling.BILINEAR,
        "bicubic": Resampling.BICUBIC,
        "nearest": Resampling.NEAREST,
    }
    resample_method = resample_method_dic.get(method_str)
    # Compare against None explicitly: a truthiness test would misfire if a
    # valid enum member happened to be falsy.
    if resample_method is None:
        # Report the caller's input; the lookup result is always None on this
        # path and would make the message useless.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method


class MarigoldPipeline(DiffusionPipeline):
"""
Pipeline for monocular depth estimation using Marigold: https://marigoldmonodepth.github.io.
Expand Down Expand Up @@ -113,7 +129,9 @@ def __call__(
ensemble_size: int = 10,
processing_res: int = 768,
match_input_res: bool = True,
resample_method: str = "bilinear",
batch_size: int = 0,
seed: Union[int, None] = None,
color_map: str = "Spectral",
show_progress_bar: bool = True,
ensemble_kwargs: Dict = None,
Expand All @@ -129,14 +147,18 @@ def __call__(
If set to 0: will not resize at all.
match_input_res (`bool`, *optional*, defaults to `True`):
Resize depth prediction to match input resolution.
Only valid if `limit_input_res` is not None.
Only valid if `processing_res` > 0.
resample_method: (`str`, *optional*, defaults to `bilinear`):
Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`.
denoising_steps (`int`, *optional*, defaults to `10`):
Number of diffusion denoising steps (DDIM) during inference.
ensemble_size (`int`, *optional*, defaults to `10`):
Number of predictions to be ensembled.
batch_size (`int`, *optional*, defaults to `0`):
Inference batch size, no bigger than `num_ensemble`.
If set to 0, the script will automatically decide the proper batch size.
seed (`int`, *optional*, defaults to `None`)
Reproducibility seed.
show_progress_bar (`bool`, *optional*, defaults to `True`):
Display a progress bar of diffusion denoising.
color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation):
Expand All @@ -146,8 +168,7 @@ def __call__(
Returns:
`MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
- **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1]
- **depth_colored** (`None` or `PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and
values in [0, 1]. None if `color_map` is `None`
- **depth_colored** (`PIL.Image.Image`) Colorized depth map, with the shape of [3, H, W] and values in [0, 1], None if `color_map` is `None`
- **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation)
coming from ensembling. None if `ensemble_size = 1`
"""
Expand All @@ -158,13 +179,21 @@ def __call__(
if not match_input_res:
assert processing_res is not None, "Value error: `resize_output_back` is only valid with "
assert processing_res >= 0
assert denoising_steps >= 1
assert ensemble_size >= 1

# Check if denoising step is reasonable
self._check_inference_step(denoising_steps)

resample_method: Resampling = get_pil_resample_method(resample_method)

# ----------------- Image Preprocess -----------------
# Resize image
if processing_res > 0:
input_image = self.resize_max_res(input_image, max_edge_resolution=processing_res)
input_image = self.resize_max_res(
input_image,
max_edge_resolution=processing_res,
resample_method=resample_method,
)
# Convert the image to RGB, to 1.remove the alpha channel 2.convert B&W to 3-channel
input_image = input_image.convert("RGB")
image = np.asarray(input_image)
Expand Down Expand Up @@ -203,9 +232,10 @@ def __call__(
rgb_in=batched_img,
num_inference_steps=denoising_steps,
show_pbar=show_progress_bar,
seed=seed,
)
depth_pred_ls.append(depth_pred_raw.detach().clone())
depth_preds = torch.concat(depth_pred_ls, axis=0).squeeze()
depth_pred_ls.append(depth_pred_raw.detach())
depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
torch.cuda.empty_cache() # clear vram cache for ensembling

# ----------------- Test-time ensembling -----------------
Expand All @@ -227,7 +257,7 @@ def __call__(
# Resize back to original resolution
if match_input_res:
pred_img = Image.fromarray(depth_pred)
pred_img = pred_img.resize(input_size)
pred_img = pred_img.resize(input_size, resample=resample_method)
depth_pred = np.asarray(pred_img)

# Clip output range
Expand All @@ -243,12 +273,32 @@ def __call__(
depth_colored_img = Image.fromarray(depth_colored_hwc)
else:
depth_colored_img = None

return MarigoldDepthOutput(
depth_np=depth_pred,
depth_colored=depth_colored_img,
uncertainty=pred_uncert,
)

def _check_inference_step(self, n_step: int):
    """
    Sanity-check the requested number of denoising steps against the active scheduler.

    A warning is logged when the step count falls outside the range recommended for the
    scheduler in use; the call itself still proceeds.

    Args:
        n_step (`int`): Number of denoising steps requested for inference. Must be at least 1.

    Raises:
        AssertionError: If `n_step` is smaller than 1.
        RuntimeError: If `self.scheduler` is neither a `DDIMScheduler` nor an `LCMScheduler`.
    """
    assert n_step >= 1

    scheduler = self.scheduler

    # DDIM: few-step inference is discouraged, point the user at the LCM checkpoint.
    if isinstance(scheduler, DDIMScheduler):
        if n_step < 10:
            logging.warning(
                f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
            )
        return

    # LCM: only the 1-4 step range is considered optimal.
    if isinstance(scheduler, LCMScheduler):
        if n_step < 1 or n_step > 4:
            logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.")
        return

    raise RuntimeError(f"Unsupported scheduler type: {type(scheduler)}")

def _encode_empty_text(self):
"""
Encode text embedding for empty prompt.
Expand All @@ -265,7 +315,13 @@ def _encode_empty_text(self):
self.empty_text_embed = self.text_encoder(text_input_ids)[0].to(self.dtype)

@torch.no_grad()
def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar: bool) -> torch.Tensor:
def single_infer(
self,
rgb_in: torch.Tensor,
num_inference_steps: int,
seed: Union[int, None],
show_pbar: bool,
) -> torch.Tensor:
"""
Perform an individual depth prediction without ensembling.
Expand All @@ -286,10 +342,20 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
timesteps = self.scheduler.timesteps # [T]

# Encode image
rgb_latent = self._encode_rgb(rgb_in)
rgb_latent = self.encode_rgb(rgb_in)

# Initial depth map (noise)
depth_latent = torch.randn(rgb_latent.shape, device=device, dtype=self.dtype) # [B, 4, h, w]
if seed is None:
rand_num_generator = None
else:
rand_num_generator = torch.Generator(device=device)
rand_num_generator.manual_seed(seed)
depth_latent = torch.randn(
rgb_latent.shape,
device=device,
dtype=self.dtype,
generator=rand_num_generator,
) # [B, 4, h, w]

# Batched empty text embedding
if self.empty_text_embed is None:
Expand All @@ -314,9 +380,9 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar
noise_pred = self.unet(unet_input, t, encoder_hidden_states=batch_empty_text_embed).sample # [B, 4, h, w]

# compute the previous noisy sample x_t -> x_t-1
depth_latent = self.scheduler.step(noise_pred, t, depth_latent).prev_sample
torch.cuda.empty_cache()
depth = self._decode_depth(depth_latent)
depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=rand_num_generator).prev_sample

depth = self.decode_depth(depth_latent)

# clip prediction
depth = torch.clip(depth, -1.0, 1.0)
Expand All @@ -325,7 +391,7 @@ def single_infer(self, rgb_in: torch.Tensor, num_inference_steps: int, show_pbar

return depth

def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
"""
Encode RGB image into latent.
Expand All @@ -344,7 +410,7 @@ def _encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
rgb_latent = mean * self.rgb_latent_scale_factor
return rgb_latent

def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
"""
Decode depth latent into depth map.
Expand All @@ -365,7 +431,7 @@ def _decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
return depth_mean

@staticmethod
def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
"""
Resize image to limit maximum edge length while keeping aspect ratio.
Expand All @@ -374,6 +440,8 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
Image to be resized.
max_edge_resolution (`int`):
Maximum edge length (pixel).
resample_method (`PIL.Image.Resampling`):
Resampling method used to resize images.
Returns:
`Image.Image`: Resized image.
Expand All @@ -384,7 +452,7 @@ def resize_max_res(img: Image.Image, max_edge_resolution: int) -> Image.Image:
new_width = int(original_width * downscale_factor)
new_height = int(original_height * downscale_factor)

resized_img = img.resize((new_width, new_height))
resized_img = img.resize((new_width, new_height), resample=resample_method)
return resized_img

@staticmethod
Expand Down

0 comments on commit c2e8786

Please sign in to comment.