diff --git a/src/diffusers/modular_pipelines/modular_pipeline.py b/src/diffusers/modular_pipelines/modular_pipeline.py index 8a05cce209c5..bae6ec89b38a 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline.py +++ b/src/diffusers/modular_pipelines/modular_pipeline.py @@ -1409,7 +1409,7 @@ def set_progress_bar_config(self, **kwargs): # YiYi TODO: # 1. look into the serialization of modular_model_index.json, make sure the items are properly ordered like model_index.json (currently a mess) # 2. do we need ConfigSpec? the are basically just key/val kwargs -# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_default_components(), load_components() +# 3. imnprove docstring and potentially add validator for methods where we accpet kwargs to be passed to from_pretrained/save_pretrained/load_components() class ModularPipeline(ConfigMixin, PushToHubMixin): """ Base class for all Modular pipelines. @@ -1478,7 +1478,7 @@ def __init__( - Components with default_creation_method="from_config" are created immediately, its specs are not included in config dict and will not be saved in `modular_model_index.json` - Components with default_creation_method="from_pretrained" are set to None and can be loaded later with - `load_default_components()`/`load_components()` + `load_components()` (with or without specific component names) - The pipeline's config dict is populated with component specs (only for from_pretrained components) and config values, which will be saved as `modular_model_index.json` during `save_pretrained` - The pipeline's config dict is also used to store the pipeline blocks's class name, which will be saved as @@ -1548,12 +1548,9 @@ def load_default_components(self, **kwargs): Args: **kwargs: Additional arguments passed to `from_pretrained` method, e.g. torch_dtype, cache_dir, etc. """ - names = [ - name - for name in self._component_specs.keys() - if self._component_specs[name].default_creation_method == "from_pretrained" - ] - self.load_components(names=names, **kwargs) + # Consolidated into load_components - just call it without names parameter + logger.warning("`load_default_components` is deprecated. Please use `load_components()` instead") + self.load_components(**kwargs) @classmethod @validate_hf_hub_args @@ -1682,8 +1679,8 @@ def register_components(self, **kwargs): - non from_pretrained components are created during __init__ and registered as the object itself - Components are updated with the `update_components()` method: e.g. loader.update_components(unet=unet) or loader.update_components(guider=guider_spec) - - (from_pretrained) Components are loaded with the `load_default_components()` method: e.g. - loader.load_default_components(names=["unet"]) + - (from_pretrained) Components are loaded with the `load_components()` method: e.g. + loader.load_components(names=["unet"]) or loader.load_components() to load all default components Args: **kwargs: Keyword arguments where keys are component names and values are component objects. @@ -1995,13 +1992,14 @@ def update_components(self, **kwargs): self.register_to_config(**config_to_register) # YiYi TODO: support map for additional from_pretrained kwargs - # YiYi/Dhruv TODO: consolidate load_components and load_default_components? - def load_components(self, names: Union[List[str], str], **kwargs): + def load_components(self, names: Optional[Union[List[str], str]] = None, **kwargs): """ Load selected components from specs. Args: - names: List of component names to load; by default will not load any components + names: List of component names to load. If None, will load all components with + default_creation_method == "from_pretrained". If provided as a list or string, + will load only the specified components. **kwargs: additional kwargs to be passed to `from_pretrained()`.Can be: - a single value to be applied to all components to be loaded, e.g. torch_dtype=torch.bfloat16 - a dict, e.g. torch_dtype={"unet": torch.bfloat16, "default": torch.float32} @@ -2009,7 +2007,13 @@ def load_components(self, names: Union[List[str], str], **kwargs): `variant`, `revision`, etc. """ - if isinstance(names, str): + if names is None: + names = [ + name + for name in self._component_specs.keys() + if self._component_specs[name].default_creation_method == "from_pretrained" + ] + elif isinstance(names, str): names = [names] elif not isinstance(names, list): raise ValueError(f"Invalid type for names: {type(names)}")