diff --git a/comfy/float.py b/comfy/float.py index 57fd070995e..4a6ae677680 100644 --- a/comfy/float.py +++ b/comfy/float.py @@ -41,9 +41,8 @@ def manual_stochastic_round_to_float8(x, dtype, generator=None): (2.0 ** (exponent - EXPONENT_BIAS)) * (1.0 + abs_x), (2.0 ** (-EXPONENT_BIAS + 1)) * abs_x ) - del abs_x - return sign.to(dtype=dtype) + return sign @@ -57,6 +56,11 @@ def stochastic_rounding(value, dtype, seed=0): if dtype == torch.float8_e4m3fn or dtype == torch.float8_e5m2: generator = torch.Generator(device=value.device) generator.manual_seed(seed) - return manual_stochastic_round_to_float8(value, dtype, generator=generator) + output = torch.empty_like(value, dtype=dtype) + num_slices = max(1, (value.numel() / (4096 * 4096))) + slice_size = max(1, round(value.shape[0] / num_slices)) + for i in range(0, value.shape[0], slice_size): + output[i:i+slice_size].copy_(manual_stochastic_round_to_float8(value[i:i+slice_size], dtype, generator=generator)) + return output return value.to(dtype=dtype)