diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py
index f4318fc3cd39..c1d4f0b46e15 100644
--- a/src/diffusers/models/activations.py
+++ b/src/diffusers/models/activations.py
@@ -18,7 +18,7 @@
 from torch import nn
 
 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version
 
 
 if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)
 
     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
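
Note: the sketch below is a minimal standalone illustration of the gating pattern this diff introduces, not the diffusers module itself. It shows GELU being upcast to float32 only when running on MPS with a torch build older than 2.0, where fp16 gelu was unsupported; on newer torch (or any other device) the input dtype is used directly. The helper names _is_torch_version and gelu_with_mps_fallback are illustrative stand-ins; the real check in diffusers is utils.import_utils.is_torch_version, and the packaging dependency is assumed to be installed.

import torch
import torch.nn.functional as F
from packaging import version


def _is_torch_version(op: str, ver: str) -> bool:
    # Minimal stand-in for diffusers.utils.import_utils.is_torch_version;
    # only the "<" comparison needed here is implemented.
    current = version.parse(version.parse(torch.__version__).base_version)
    if op != "<":
        raise NotImplementedError(op)
    return current < version.parse(ver)


def gelu_with_mps_fallback(gate: torch.Tensor, approximate: str = "none") -> torch.Tensor:
    if gate.device.type == "mps" and _is_torch_version("<", "2.0.0"):
        # fp16 gelu not supported on mps before torch 2.0: compute in fp32, then restore dtype.
        return F.gelu(gate.to(dtype=torch.float32), approximate=approximate).to(dtype=gate.dtype)
    # On CPU, CUDA, NPU, or MPS with torch >= 2.0, call gelu directly in the input dtype.
    return F.gelu(gate, approximate=approximate)


if __name__ == "__main__":
    x = torch.randn(2, 8)  # float32 on CPU; on Apple Silicon this could be a float16 MPS tensor
    print(gelu_with_mps_fallback(x).shape)  # torch.Size([2, 8])

The practical effect of the change is that MPS users on torch >= 2.0 no longer pay the cost of the float32 round-trip, while the fallback is preserved for older installs.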