Skip to content

Commit

Permalink
add foreach to from_pretrained
Browse files Browse the repository at this point in the history
  • Loading branch information
drhead authored Jun 22, 2024
1 parent e9ed284 commit 8bdca08
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions src/diffusers/training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,11 +350,11 @@ def __init__(
self.model_config = model_config

@classmethod
def from_pretrained(cls, path, model_cls) -> "EMAModel":
def from_pretrained(cls, path, model_cls, foreach) -> "EMAModel":
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
model = model_cls.from_pretrained(path)

ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

ema_model.load_state_dict(ema_kwargs)
return ema_model
Expand Down

0 comments on commit 8bdca08

Please sign in to comment.