Skip to content

Commit

Permalink
Allow more arguments to be passed to convert_from_ckpt (huggingface#7222
Browse files Browse the repository at this point in the history
)

Allow safety and feature extractor arguments to be passed to convert_from_ckpt

Allows management of safety checker and feature extractor
from outside of the convert ckpt class.

Co-authored-by: Sayak Paul <[email protected]>
  • Loading branch information
w4ffl35 and sayakpaul authored Apr 8, 2024
1 parent 56a7608 commit 7e39516
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,6 +1153,8 @@ def download_from_original_stable_diffusion_ckpt(
controlnet: Optional[bool] = None,
adapter: Optional[bool] = None,
load_safety_checker: bool = True,
safety_checker: Optional[StableDiffusionSafetyChecker] = None,
feature_extractor: Optional[AutoFeatureExtractor] = None,
pipeline_class: DiffusionPipeline = None,
local_files_only=False,
vae_path=None,
Expand Down Expand Up @@ -1205,6 +1207,12 @@ def download_from_original_stable_diffusion_ckpt(
If `checkpoint_path` is in `safetensors` format, load checkpoint with safetensors instead of PyTorch.
load_safety_checker (`bool`, *optional*, defaults to `True`):
Whether to load the safety checker or not. Defaults to `True`.
safety_checker (`StableDiffusionSafetyChecker`, *optional*, defaults to `None`):
Safety checker to use. If this parameter is `None`, the function will load a new instance of
[StableDiffusionSafetyChecker] by itself, if needed.
feature_extractor (`AutoFeatureExtractor`, *optional*, defaults to `None`):
Feature extractor to use. If this parameter is `None`, the function will load a new instance of
[AutoFeatureExtractor] by itself, if needed.
pipeline_class (`str`, *optional*, defaults to `None`):
The pipeline class to use. Pass `None` to determine automatically.
local_files_only (`bool`, *optional*, defaults to `False`):
Expand Down Expand Up @@ -1530,8 +1538,8 @@ def download_from_original_stable_diffusion_ckpt(
unet=unet,
scheduler=scheduler,
controlnet=controlnet,
safety_checker=None,
feature_extractor=None,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False
Expand All @@ -1551,8 +1559,8 @@ def download_from_original_stable_diffusion_ckpt(
unet=unet,
scheduler=scheduler,
low_res_scheduler=low_res_scheduler,
safety_checker=None,
feature_extractor=None,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)

else:
Expand All @@ -1562,8 +1570,8 @@ def download_from_original_stable_diffusion_ckpt(
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None,
feature_extractor=None,
safety_checker=safety_checker,
feature_extractor=feature_extractor,
)
if hasattr(pipe, "requires_safety_checker"):
pipe.requires_safety_checker = False
Expand Down Expand Up @@ -1684,9 +1692,6 @@ def download_from_original_stable_diffusion_ckpt(
feature_extractor = AutoFeatureExtractor.from_pretrained(
"CompVis/stable-diffusion-safety-checker", local_files_only=local_files_only
)
else:
safety_checker = None
feature_extractor = None

if controlnet:
pipe = pipeline_class(
Expand Down

0 comments on commit 7e39516

Please sign in to comment.