
Commit

Remove mps workaround for fp16 GELU, which is now supported natively (huggingface#10133)

* Remove mps workaround for fp16 GELU, which is now supported natively

---------

Co-authored-by: hlky <[email protected]>
skotapati and hlky authored Dec 13, 2024
1 parent bdbaea8 commit ec9bfa9
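
For context (not part of this commit), a minimal sketch of the native support the title refers to: on torch >= 2.0, F.gelu accepts float16 tensors on the mps backend directly, so the float32 round-trip is only kept for older torch versions. The check below assumes a machine where the mps backend is available.

import torch
import torch.nn.functional as F

# Sanity check: with torch >= 2.0, fp16 GELU runs natively on mps,
# so no cast to float32 and back is needed.
if torch.backends.mps.is_available():
    x = torch.randn(4, 8, device="mps", dtype=torch.float16)
    y = F.gelu(x, approximate="tanh")
    print(y.dtype)  # torch.float16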
Showing 1 changed file with 9 additions and 9 deletions.
src/diffusers/models/activations.py (18 changes: 9 additions & 9 deletions)
@@ -18,7 +18,7 @@
 from torch import nn
 
 from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available
+from ..utils.import_utils import is_torch_npu_available, is_torch_version
 
 
 if is_torch_npu_available():
@@ -79,10 +79,10 @@ def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: b
         self.approximate = approximate
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate, approximate=self.approximate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
+        return F.gelu(gate, approximate=self.approximate)
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
@@ -105,10 +105,10 @@ def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
         self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
 
     def gelu(self, gate: torch.Tensor) -> torch.Tensor:
-        if gate.device.type != "mps":
-            return F.gelu(gate)
-        # mps: gelu is not implemented for float16
-        return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        if gate.device.type == "mps" and is_torch_version("<", "2.0.0"):
+            # fp16 gelu not supported on mps before torch 2.0
+            return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+        return F.gelu(gate)
 
     def forward(self, hidden_states, *args, **kwargs):
         if len(args) > 0 or kwargs.get("scale", None) is not None:
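
As a usage sketch (not part of the commit), the two modules whose gelu helpers change above can be exercised directly; this assumes the standard GELU and GEGLU classes exported from diffusers.models.activations, with the GELU(dim_in, dim_out, approximate=...) and GEGLU(dim_in, dim_out) signatures shown in the diff context.

import torch
from diffusers.models.activations import GEGLU, GELU

x = torch.randn(2, 64)

# GELU: linear projection to dim_out followed by gelu
gelu_proj = GELU(dim_in=64, dim_out=128, approximate="tanh")
print(gelu_proj(x).shape)  # torch.Size([2, 128])

# GEGLU: projects to 2 * dim_out, splits, and gates one half with gelu of the other
geglu = GEGLU(dim_in=64, dim_out=128)
print(geglu(x).shape)  # torch.Size([2, 128])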
