diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py index 58dcf8c81caa..8e1a0a6f208b 100644 --- a/src/diffusers/training_utils.py +++ b/src/diffusers/training_utils.py @@ -350,11 +350,11 @@ def __init__( self.model_config = model_config @classmethod - def from_pretrained(cls, path, model_cls) -> "EMAModel": + def from_pretrained(cls, path, model_cls, foreach) -> "EMAModel": _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True) model = model_cls.from_pretrained(path) - ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config) + ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach) ema_model.load_state_dict(ema_kwargs) return ema_model