Improve performance scaling of fmod using modular exponentiation #898

Draft · wants to merge 2 commits into master
93 changes: 93 additions & 0 deletions libm/src/math/generic/fmod.rs
@@ -1,5 +1,6 @@
/* SPDX-License-Identifier: MIT OR Apache-2.0 */
use super::super::{CastFrom, Float, Int, MinInt};
use crate::support::{DInt, HInt, Reducer};

#[inline]
pub fn fmod<F: Float>(x: F, y: F) -> F {
@@ -59,10 +60,102 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) {

/// Compute the remainder `(x * 2.pow(e)) % y` without overflow.
fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I {
// FIXME: This is a temporary hack to get around the lack of `u256 / u256`.
// Actually, the algorithm only needs the operation `(x << I::BITS) / y`
// where `x < y`. That is, a division `u256 / u128` where the quotient must
// not overflow `u128` would be sufficient for `f128`.
Comment on lines +63 to +66 (Contributor):

Would this be easier with u256? We have an implementation here, but only shifts and widening multiplication are supported currently.

@quaternic (Contributor Author) · May 1, 2025

If u256 already implemented Div, then the generic code here could just use that, so it would be easy, yes. But implementing that sounds like extra work just to achieve a suboptimal solution.

> a division u2N / uN where the quotient must not overflow uN

For context, x86's `div` instruction works just like this: it takes the double-wide dividend in a pair of registers, divides by a value in a third register, and replaces the low and high halves of the dividend with the quotient and remainder respectively. If the quotient would overflow (which is exactly when `x.hi() >= y`), it signals a divide error.

So that's the abstraction I'd like to use; something like

unsafe fn unchecked_wide_div_rem<U: HInt>(U::D, U) -> (U, U);

But of course, that would be even more work to implement since it doesn't exist yet, and I don't think other arches have a native operation for it.
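
A minimal `u64` instantiation of that signature, for illustration only (the name is made up here, and it falls back to the compiler's `u128` division instead of a native instruction):

```rust
/// Illustrative sketch of the proposed primitive for `U = u64`:
/// a narrowing division `u128 / u64 -> (quotient: u64, remainder: u64)`.
///
/// # Safety
/// The quotient must fit in `u64`, i.e. `(x >> 64) < y as u128`
/// (the same precondition under which x86's `div` does not fault).
unsafe fn unchecked_wide_div_rem_u64(x: u128, y: u64) -> (u64, u64) {
    debug_assert!((x >> 64) < y as u128);
    ((x / y as u128) as u64, (x % y as u128) as u64)
}
```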

Another idea would be to get rid of the integer division altogether, and compute the reciprocal from the original floating point value. I expect that this could have better performance, but it needs more careful analysis.

unsafe {
use core::mem::transmute_copy;
if I::BITS == 64 {
let x = transmute_copy::<I, u64>(&x);
let y = transmute_copy::<I, u64>(&y);
let r = fast_reduction::<f64, u64>(x, e, y);
return transmute_copy::<u64, I>(&r);
}
if I::BITS == 32 {
let x = transmute_copy::<I, u32>(&x);
let y = transmute_copy::<I, u32>(&y);
let r = fast_reduction::<f32, u32>(x, e, y);
return transmute_copy::<u32, I>(&r);
}
#[cfg(f16_enabled)]
if I::BITS == 16 {
let x = transmute_copy::<I, u16>(&x);
let y = transmute_copy::<I, u16>(&y);
let r = fast_reduction::<f16, u16>(x, e, y);
return transmute_copy::<u16, I>(&r);
}
}

x %= y;
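// Fallback: one shift and conditional subtraction per iteration, so O(e) steps;
// the fast paths above reduce this to O(log e) modular squarings.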
for _ in 0..e {
x <<= 1;
x = x.checked_sub(y).unwrap_or(x);
}
x
}

trait SafeShift: Float {
// How many guaranteed leading zeros do the values have?
// A normalized floating point mantissa has `EXP_BITS` guaranteed leading
// zeros (excludes the implicit bit, but includes the now-zeroed sign bit)
// `-1` because we want to shift by either `BASE_SHIFT` or `BASE_SHIFT + 1`
const BASE_SHIFT: u32 = Self::EXP_BITS - 1;
}
impl<F: Float> SafeShift for F {}

fn fast_reduction<F, I>(x: I, e: u32, y: I) -> I
where
F: Float<Int = I>,
I: Int + HInt,
I::D: Int + DInt<H = I>,
{
let _0 = I::ZERO;
let _1 = I::ONE;

if y == _1 {
return _0;
}

if e <= F::BASE_SHIFT {
return (x << e) % y;
}

// Find least depth s.t. `(e >> depth) < I::BITS`
let depth = (I::BITS - 1)
.leading_zeros()
.saturating_sub(e.leading_zeros());
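// (`depth` saturates to 0 when `e < I::BITS`; otherwise it equals
// `e.ilog2() - (I::BITS - 1).ilog2()`)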

let initial = (e >> depth) - F::BASE_SHIFT;

let max_rem = y.wrapping_sub(_1);
let max_ilog2 = max_rem.ilog2();
let mut pow2 = _1 << max_ilog2.min(initial);
for _ in max_ilog2..initial {
pow2 <<= 1;
pow2 = pow2.checked_sub(y).unwrap_or(pow2);
}

// At each step `k in [depth, ..., 0]`,
// `p` is `(e >> k) - BASE_SHIFT`
// `m` is `(1 << p) % y`
let mut k = depth;
let mut p = initial;
let mut m = Reducer::new(pow2, y);

while k > 0 {
k -= 1;
p = p + p + F::BASE_SHIFT;
if e & (1 << k) != 0 {
m = m.squared_with_shift(F::BASE_SHIFT + 1);
p += 1;
} else {
m = m.squared_with_shift(F::BASE_SHIFT);
};

debug_assert!(p == (e >> k) - F::BASE_SHIFT);
}

// (x << BASE_SHIFT) * (1 << p) == x << e
m.mul_into_div_rem(x << F::BASE_SHIFT).1
}
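
For intuition, `fast_reduction(x, e, y)` computes the same value as a plain square-and-multiply modular exponentiation of `2^e mod y`. A naive reference sketch (not part of the diff; `u64` only, assuming `y != 0`):

```rust
/// Naive model of `fast_reduction`: computes `(x * 2^e) % y` by
/// square-and-multiply on `2^e mod y`, using `u128` intermediates.
fn reference_reduction(x: u64, e: u32, y: u64) -> u64 {
    let y = y as u128;
    let mut pow2 = 1 % y; // holds 2^e mod y once the loop finishes
    let mut base = 2 % y; // 2^(2^i) mod y at step i
    let mut k = e;
    while k > 0 {
        if k & 1 == 1 {
            pow2 = pow2 * base % y;
        }
        base = base * base % y;
        k >>= 1;
    }
    // (x * 2^e) % y == ((x % y) * (2^e % y)) % y
    ((x as u128 % y) * pow2 % y) as u64
}
```

The PR's version runs the same squaring ladder over the bits of `e`, but keeps each intermediate in Barrett form (`Reducer`), so no wide division is needed after construction.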
3 changes: 3 additions & 0 deletions libm/src/math/support/int_traits.rs
@@ -1,5 +1,8 @@
use core::{cmp, fmt, ops};

mod mod_mul;
pub(crate) use mod_mul::Reducer;

/// Minimal integer implementations needed on all integer types, including wide integers.
pub trait MinInt:
Copy
222 changes: 222 additions & 0 deletions libm/src/math/support/int_traits/mod_mul.rs
@@ -0,0 +1,222 @@
use super::{DInt, HInt, Int};

/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`
///
/// For a more detailed description, see
/// <https://en.wikipedia.org/wiki/Barrett_reduction>.
///
/// After constructing as `Reducer::new(b, n)`,
/// has operations to efficiently compute
/// - `(a * b) / n` and `(a * b) % n`
/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) struct Reducer<U> {
// the multiplying factor `b in 0..n`
num: U,
// the modulus `n in 1..=R/2`
div: U,
// the precomputed quotient, `q = (b << K) / n`
quo: U,
// the remainder of that division, `r = (b << K) % n`,
// (could always be recomputed as `(b << K) - q * n`,
// but it is convenient to save)
rem: U,
}

impl<U> Reducer<U>
where
U: Int + HInt,
U::D: core::ops::Div<Output = U::D>,
U::D: core::ops::Rem<Output = U::D>,
{
/// Requires `num < div <= R/2`, will panic otherwise
#[inline]
pub fn new(num: U, div: U) -> Self {
let _0 = U::ZERO;
let _1 = U::ONE;

assert!(num < div);
assert!(div.wrapping_sub(_1).leading_zeros() >= 1);

let bk = num.widen_hi();
let n = div.widen();
let quo = (bk / n).lo();
let rem = (bk % n).lo();

Self { num, div, quo, rem }
}
}

impl<U> Reducer<U>
where
U: Int + HInt,
U::D: Int,
{
/// Return the unique pair `(quotient, remainder)`
/// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n`
#[inline]
pub fn mul_into_div_rem(&self, a: U) -> (U, U) {
let (q, mut r) = self.mul_into_unnormalized_div_rem(a);
// The unnormalized remainder is still guaranteed to be less than `2n`, so
// one checked subtraction is sufficient.
(q + U::cast_from(self.fixup(&mut r) as u8), r)
}

#[inline(always)]
pub fn fixup(&self, x: &mut U) -> bool {
x.checked_sub(self.div).map(|r| *x = r).is_some()
}

/// Return some pair `(quotient, remainder)`
/// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n`
#[inline]
pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) {
// General idea: estimate the quotient `t in 0..=a` s.t. the remainder
// `ab - tn` is close to zero, so `t ~= ab / n`

// Note: we use `R == 1 << U::BITS`, which means that
// - wrapping arithmetic with `U` is modulo `R`
// - all inputs are less than `R`

// Range analysis:
//
// Using the definition of euclidean division on the two divisions done:
// ```
// bR = qn + r, with 0 <= r < n
// aq = tR + s, with 0 <= s < R
// ```
let (_s, t) = a.widen_mul(self.quo).lo_hi();
// Then
// ```
// (ab - tn)R
// = abR - ntR
// = a(qn + r) - n(aq - s)
// = ar + ns
// ```
#[cfg(debug_assertions)]
{
assert!(t < a || (a == t && t.is_zero()));
let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div);
let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div);
assert!(ab_tn.hi().is_zero());
assert!(ar_ns.lo().is_zero());
assert!(ab_tn.lo() == ar_ns.hi());
}
// Since `s < R` and `r < n`,
// ```
// 0 <= ns < nR
// 0 <= ar < an
// 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R)
// ```
// Since `a < R` and we check on construction that `n <= R/2`, the result
// is `0 <= ab - tn < R`, so it can be computed modulo `R`
// even though the intermediate terms generally wrap.
let ab = a.wrapping_mul(self.num);
let tn = t.wrapping_mul(self.div);
(t, ab.wrapping_sub(tn))
}

/// Constructs a new reducer with `b` set to `(ab * b) % n`
///
/// Requires `r * ab == ra * b`, where `r = bR % n`.
#[inline(always)]
fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self {
debug_assert!(ab.widen_mul(self.rem) == ra.widen_mul(self.num));
// The new factor `v = abb mod n`:
let (_, v) = self.mul_into_div_rem(ab);

// `rab = cn + d`, where `0 <= d < n`
let (c, d) = self.mul_into_div_rem(ra);

// We need `abbR = Xn + Y`:
// abbR
// = ab(qn + r)
// = abqn + rab
// = abqn + cn + d
// = (abq + c)n + d

Self {
num: v,
div: self.div,
quo: self.quo.wrapping_mul(ab).wrapping_add(c),
rem: d,
}
}

/// Computes the reducer with the factor `b` set to `(a * b * b) % n`
/// Requires that `a * (n - 1)` does not overflow.
#[allow(dead_code)]
#[inline]
pub fn squared_with_scale(&self, a: U) -> Self {
debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero());
self.with_scaled_num_rem(a * self.num, a * self.rem)
}

/// Computes the reducer with the factor `b` set to `(b * b << s) % n`
/// Requires that `(n - 1) << s` does not overflow.
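/// (Example for `U = u8`: with `b = 7`, `n = 10`, `s = 2`, the new factor is `(49 << 2) % 10 = 6`.)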
#[inline]
pub fn squared_with_shift(&self, s: u32) -> Self {
debug_assert!((self.div - U::ONE).leading_zeros() >= s);
self.with_scaled_num_rem(self.num << s, self.rem << s)
}
}

#[cfg(test)]
mod test {
use super::Reducer;

#[test]
fn u8_all() {
for y in 1..=128_u8 {
for r in 0..y {
let m = Reducer::new(r, y);
assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8);
for x in 0..=u8::MAX {
let (quo, rem) = m.mul_into_div_rem(x);

let q0 = x as u32 * r as u32 / y as u32;
let r0 = x as u32 * r as u32 % y as u32;
assert_eq!(
(quo as u32, rem as u32),
(q0, r0),
"\n\
{x} * {r} = {xr}\n\
expected: = {q0} * {y} + {r0}\n\
returned: = {quo} * {y} + {rem} (== {})\n",
quo as u32 * y as u32 + rem as u32,
xr = x as u32 * r as u32,
);
}
for s in 0..=y.leading_zeros() {
assert_eq!(
m.squared_with_shift(s),
Reducer::new(((r << s) as u32 * r as u32 % y as u32) as u8, y)
);
}
for a in 0..=u8::MAX {
if a.checked_mul(y).is_some() {
let abb = a as u32 * r as u32 * r as u32;
assert_eq!(
m.squared_with_scale(a),
Reducer::new((abb % y as u32) as u8, y)
);
} else {
break;
}
}
for x0 in 0..=u8::MAX {
if m.num == 0 || x0 as u32 * m.rem as u32 % m.num as u32 != 0 {
continue;
}
let y0 = x0 as u32 * m.rem as u32 / m.num as u32;
let Ok(y0) = u8::try_from(y0) else { continue };

assert_eq!(
m.with_scaled_num_rem(x0, y0),
Reducer::new((x0 as u32 * m.num as u32 % y as u32) as u8, y)
);
}
}
}
}
}
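
To make the Barrett step concrete, a worked trace for `U = u8` (so `R = 256`), written as it could appear inside the test module above; the numbers are illustrative:

```rust
// Reducer::new(7, 10): quo = (7 << 8) / 10 = 179, rem = (7 << 8) % 10 = 2.
let m = Reducer::new(7u8, 10u8);
assert_eq!((m.quo, m.rem), (179, 2));

// mul_into_div_rem(13): the estimated quotient is
//   t = hi(13 * 179) = 2327 >> 8 = 9,
// and the unnormalized remainder is
//   13 * 7 - 9 * 10 = 91 - 90 = 1, already < n, so no fixup is needed.
assert_eq!(m.mul_into_div_rem(13), (9, 1)); // 13 * 7 == 9 * 10 + 1
```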
1 change: 1 addition & 0 deletions libm/src/math/support/mod.rs
@@ -20,6 +20,7 @@ pub use hex_float::hf16;
pub use hex_float::hf128;
#[allow(unused_imports)]
pub use hex_float::{Hexf, hf32, hf64};
pub(crate) use int_traits::Reducer;
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};

/// Hint to the compiler that the current path is cold.