Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert fmaf to a generic implementation #499

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 2 additions & 94 deletions src/math/fmaf.rs
Original file line number Diff line number Diff line change
@@ -1,103 +1,11 @@
/* origin: FreeBSD /usr/src/lib/msun/src/s_fmaf.c */
/*-
* Copyright (c) 2005-2011 David Schultz <[email protected]>
* All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions
* are met:
* 1. Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* 2. Redistributions in binary form must reproduce the above copyright
* notice, this list of conditions and the following disclaimer in the
* documentation and/or other materials provided with the distribution.
*
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
* SUCH DAMAGE.
*/

use core::f32;
use core::ptr::read_volatile;

use super::fenv::{
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
};

/*
* Fused multiply-add: Compute x * y + z with a single rounding error.
*
* A double has more than twice as much precision than a float, so
* direct double-precision arithmetic suffices, except where double
* rounding occurs.
*/

/// Floating multiply add (f32)
///
/// Computes `(x*y)+z`, rounded as one ternary operation:
/// Computes the value (as if) to infinite precision and rounds once to the result format,
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
let xy: f64;
let mut result: f64;
let mut ui: u64;
let e: i32;

xy = x as f64 * y as f64;
result = xy + z as f64;
ui = result.to_bits();
e = (ui >> 52) as i32 & 0x7ff;
/* Common case: The double precision result is fine. */
if (
/* not a halfway case */
ui & 0x1fffffff) != 0x10000000 ||
/* NaN */
e == 0x7ff ||
/* exact */
(result - xy == z as f64 && result - z as f64 == xy) ||
/* not round-to-nearest */
fegetround() != FE_TONEAREST
{
/*
underflow may not be raised correctly, example:
fmaf(0x1p-120f, 0x1p-120f, 0x1p-149f)
*/
if ((0x3ff - 149)..(0x3ff - 126)).contains(&e) && fetestexcept(FE_INEXACT) != 0 {
feclearexcept(FE_INEXACT);
// prevent `xy + vz` from being CSE'd with `xy + z` above
let vz: f32 = unsafe { read_volatile(&z) };
result = xy + vz as f64;
if fetestexcept(FE_INEXACT) != 0 {
feraiseexcept(FE_UNDERFLOW);
} else {
feraiseexcept(FE_INEXACT);
}
}
z = result as f32;
return z;
}

/*
* If result is inexact, and exactly halfway between two float values,
* we need to adjust the low-order bit in the direction of the error.
*/
let neg = ui >> 63 != 0;
let err = if neg == (z as f64 > xy) { xy - result + z as f64 } else { z as f64 - result + xy };
if neg == (err < 0.0) {
ui += 1;
} else {
ui -= 1;
}
f64::from_bits(ui) as f32
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
super::generic::fma_wide(x, y, z)
}

#[cfg(test)]
Expand Down
67 changes: 65 additions & 2 deletions src/math/generic/fma.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
/* SPDX-License-Identifier: MIT */
/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */
/* origin: musl src/math/{fma,fmaf}.c. Ported to generic Rust algorithm in 2025, TG. */

use core::{f32, f64};

use super::super::fenv::{
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
};
use super::super::support::{DInt, HInt, IntTy};
use super::super::{CastFrom, CastInto, Float, Int, MinInt};
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, Int, MinInt};

/// Fused multiply-add that works when there is not a larger float size available. Currently this
/// is still specialized only for `f64`. Computes `(x * y) + z`.
Expand Down Expand Up @@ -212,6 +215,66 @@ where
super::scalbn(r, e)
}

/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`,
/// `f64` has enough precision to represent the `f32` in its entirety, except for double rounding.
pub fn fma_wide<F, B>(x: F, y: F, z: F) -> F
where
F: Float + HFloat<D = B>,
B: Float + DFloat<H = F>,
B::Int: CastInto<i32>,
i32: CastFrom<i32>,
{
let one = IntTy::<B>::ONE;

let xy: B = x.widen() * y.widen();
let mut result: B = xy + z.widen();
let mut ui: B::Int = result.to_bits();
let re = result.exp();
let zb: B = z.widen();

let prec_diff = B::SIG_BITS - F::SIG_BITS;
let excess_prec = ui & ((one << prec_diff) - one);
let halfway = one << (prec_diff - 1);

// Common case: the larger precision is fine if...
// This is not a halfway case
if excess_prec != halfway
// Or the result is NaN
|| re == B::EXP_SAT
// Or the result is exact
|| (result - xy == zb && result - zb == xy)
// Or the mode is something other than round to nearest
|| fegetround() != FE_TONEAREST
{
let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32;
let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32;

if (min_inexact_exp..max_inexact_exp).contains(&re) && fetestexcept(FE_INEXACT) != 0 {
feclearexcept(FE_INEXACT);
// prevent `xy + vz` from being CSE'd with `xy + z` above
let vz: F = force_eval!(z);
result = xy + vz.widen();
if fetestexcept(FE_INEXACT) != 0 {
feraiseexcept(FE_UNDERFLOW);
} else {
feraiseexcept(FE_INEXACT);
}
}

return result.narrow();
}

let neg = ui >> (B::BITS - 1) != IntTy::<B>::ZERO;
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
if neg == (err < B::ZERO) {
ui += one;
} else {
ui -= one;
}

B::from_bits(ui).narrow()
}

/// Representation of `F` that has handled subnormals.
#[derive(Clone, Copy, Debug)]
struct Norm<F: Float> {
Expand Down
2 changes: 1 addition & 1 deletion src/math/generic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub use copysign::copysign;
pub use fabs::fabs;
pub use fdim::fdim;
pub use floor::floor;
pub use fma::fma;
pub use fma::{fma, fma_wide};
pub use fmax::fmax;
pub use fmin::fmin;
pub use fmod::fmod;
Expand Down
2 changes: 1 addition & 1 deletion src/math/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ use self::rem_pio2::rem_pio2;
use self::rem_pio2_large::rem_pio2_large;
use self::rem_pio2f::rem_pio2f;
#[allow(unused_imports)]
use self::support::{CastFrom, CastInto, DInt, Float, HInt, Int, IntTy, MinInt};
use self::support::{CastFrom, CastInto, DFloat, DInt, Float, HFloat, HInt, Int, IntTy, MinInt};

// Public modules
mod acos;
Expand Down
58 changes: 58 additions & 0 deletions src/math/support/float_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,64 @@ pub const fn f64_from_bits(bits: u64) -> f64 {
unsafe { mem::transmute::<u64, f64>(bits) }
}

/// Trait for floats twice the bit width of another integer.
pub trait DFloat: Float {
/// Float that is half the bit width of the floatthis trait is implemented for.
type H: HFloat<D = Self>;

/// Narrow the float type.
fn narrow(self) -> Self::H;
}

/// Trait for floats half the bit width of another float.
pub trait HFloat: Float {
/// Float that is double the bit width of the float this trait is implemented for.
type D: DFloat<H = Self>;

/// Widen the float type.
fn widen(self) -> Self::D;
}

macro_rules! impl_d_float {
($($X:ident $D:ident),*) => {
$(
impl DFloat for $D {
type H = $X;

fn narrow(self) -> Self::H {
self as $X
}
}
)*
};
}

macro_rules! impl_h_float {
($($H:ident $X:ident),*) => {
$(
impl HFloat for $H {
type D = $X;

fn widen(self) -> Self::D {
self as $X
}
}
)*
};
}

impl_d_float!(f32 f64);
#[cfg(f16_enabled)]
impl_d_float!(f16 f32);
#[cfg(f128_enabled)]
impl_d_float!(f64 f128);

impl_h_float!(f32 f64);
#[cfg(f16_enabled)]
impl_h_float!(f16 f32);
#[cfg(f128_enabled)]
impl_h_float!(f64 f128);

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 2 additions & 1 deletion src/math/support/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@ mod float_traits;
pub mod hex_float;
mod int_traits;

pub use float_traits::{Float, IntTy};
#[allow(unused_imports)]
pub use float_traits::{DFloat, Float, HFloat, IntTy};
pub(crate) use float_traits::{f32_from_bits, f64_from_bits};
#[cfg(f16_enabled)]
#[allow(unused_imports)]
Expand Down