From bf85502f3dcee1ac7c82d1d528789d0bf5578201 Mon Sep 17 00:00:00 2001 From: quaternic <57393910+quaternic@users.noreply.github.com> Date: Thu, 1 May 2025 18:06:31 +0300 Subject: [PATCH 1/2] Implement Barrett reduction for modular multiplication --- libm/src/math/support/int_traits.rs | 3 + libm/src/math/support/int_traits/mod_mul.rs | 225 ++++++++++++++++++++ libm/src/math/support/mod.rs | 1 + 3 files changed, 229 insertions(+) create mode 100644 libm/src/math/support/int_traits/mod_mul.rs diff --git a/libm/src/math/support/int_traits.rs b/libm/src/math/support/int_traits.rs index 3ec1faba1..1ef84d52a 100644 --- a/libm/src/math/support/int_traits.rs +++ b/libm/src/math/support/int_traits.rs @@ -1,5 +1,8 @@ use core::{cmp, fmt, ops}; +mod mod_mul; +pub(crate) use mod_mul::Reducer; + /// Minimal integer implementations needed on all integer types, including wide integers. pub trait MinInt: Copy diff --git a/libm/src/math/support/int_traits/mod_mul.rs b/libm/src/math/support/int_traits/mod_mul.rs new file mode 100644 index 000000000..af8f4a398 --- /dev/null +++ b/libm/src/math/support/int_traits/mod_mul.rs @@ -0,0 +1,225 @@ +use super::{DInt, HInt, Int}; + +/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)` +/// +/// More specifically, implements single-word [Barrett multiplication] +/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication) +/// and [division] +/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division) +/// for unsigned integers. 
+/// +/// After constructing as `Reducer::new(b, n)`, +/// provides operations to efficiently compute +/// - `(a * b) / n` and `(a * b) % n` +/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R` +#[derive(Clone, Copy, PartialEq, Eq, Debug)] +pub(crate) struct Reducer { + // the multiplying factor `b in 0..n` + num: U, + // the modulus `n in 1..=R/2` + div: U, + // the precomputed quotient, `q = (b << K) / n` + quo: U, + // the remainder of that division, `r = (b << K) % n`, + // (could always be recomputed as `(b << K) - q * n`, + // but it is convenient to save) + rem: U, +} + +impl Reducer +where + U: Int + HInt, + U::D: core::ops::Div, + U::D: core::ops::Rem, +{ + /// Requires `num < div <= R/2`, will panic otherwise + #[inline] + pub fn new(num: U, div: U) -> Self { + let _0 = U::ZERO; + let _1 = U::ONE; + + assert!(num < div); + assert!(div.wrapping_sub(_1).leading_zeros() >= 1); + + let bk = num.widen_hi(); + let n = div.widen(); + let quo = (bk / n).lo(); + let rem = (bk % n).lo(); + + Self { num, div, quo, rem } + } +} + +impl Reducer +where + U: Int + HInt, + U::D: Int, +{ + /// Return the unique pair `(quotient, remainder)` + /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n` + #[inline] + pub fn mul_into_div_rem(&self, a: U) -> (U, U) { + let (q, mut r) = self.mul_into_unnormalized_div_rem(a); + // The unnormalized remainder is still guaranteed to be less than `2n`, so + // one checked subtraction is sufficient. + (q + U::cast_from(self.fixup(&mut r) as u8), r) + } + + #[inline(always)] + pub fn fixup(&self, x: &mut U) -> bool { + x.checked_sub(self.div).map(|r| *x = r).is_some() + } + + /// Return some pair `(quotient, remainder)` + /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n` + #[inline] + pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) { + // General idea: Estimate the quotient `quotient = t in 0..a` s.t. 
+ // the remainder `ab - tn` is close to zero, so `t ~= ab / n` + + // Note: we use `R == 1 << U::BITS`, which means that + // - wrapping arithmetic with `U` is modulo `R` + // - all inputs are less than `R` + + // Range analysis: + // + // Using the definition of euclidean division on the two divisions done: + // ``` + // bR = qn + r, with 0 <= r < n + // aq = tR + s, with 0 <= s < R + // ``` + let (_s, t) = a.widen_mul(self.quo).lo_hi(); + // Then + // ``` + // (ab - tn)R + // = abR - ntR + // = a(qn + r) - n(aq - s) + // = ar + ns + // ``` + #[cfg(debug_assertions)] + { + assert!(t < a || (a == t && t.is_zero())); + let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div); + let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div); + assert!(ab_tn.hi().is_zero()); + assert!(ar_ns.lo().is_zero()); + assert_eq!(ab_tn.lo(), ar_ns.hi()); + } + // Since `s < R` and `r < n`, + // ``` + // 0 <= ns < nR + // 0 <= ar < an + // 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R) + // ``` + // Since `a < R` and we check on construction that `n <= R/2`, the result + // is `0 <= ab - tn < R`, so it can be computed modulo `R` + // even though the intermediate terms generally wrap. + let ab = a.wrapping_mul(self.num); + let tn = t.wrapping_mul(self.div); + (t, ab.wrapping_sub(tn)) + } + + /// Constructs a new reducer with `b` set to `(ab * b) % n` + /// + /// Requires `r * ab == ra * b`, where `r = bR % n`. 
+ #[inline(always)] + fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self { + debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num)); + // The new factor `v = abb mod n`: + let (_, v) = self.mul_into_div_rem(ab); + + // `rab = cn + d`, where `0 <= d < n` + let (c, d) = self.mul_into_div_rem(ra); + + // We need `abbR = Xn + Y`: + // abbR + // = ab(qn + r) + // = abqn + rab + // = abqn + cn + d + // = (abq + c)n + d + + Self { + num: v, + div: self.div, + quo: self.quo.wrapping_mul(ab).wrapping_add(c), + rem: d, + } + } + + /// Computes the reducer with the factor `b` set to `(a * b * b) % n` + /// Requires that `a * (n - 1)` does not overflow. + #[allow(dead_code)] + #[inline] + pub fn squared_with_scale(&self, a: U) -> Self { + debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero()); + self.with_scaled_num_rem(a * self.num, a * self.rem) + } + + /// Computes the reducer with the factor `b` set to `(b * b << s) % n` + /// Requires that `(n - 1) << s` does not overflow. + #[inline] + pub fn squared_with_shift(&self, s: u32) -> Self { + debug_assert!((self.div - U::ONE).leading_zeros() >= s); + self.with_scaled_num_rem(self.num << s, self.rem << s) + } +} + +#[cfg(test)] +mod test { + use super::Reducer; + + #[test] + fn u8_all() { + for y in 1..=128_u8 { + for r in 0..y { + let m = Reducer::new(r, y); + assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8); + for x in 0..=u8::MAX { + let (quo, rem) = m.mul_into_div_rem(x); + + let q0 = x as u32 * r as u32 / y as u32; + let r0 = x as u32 * r as u32 % y as u32; + assert_eq!( + (quo as u32, rem as u32), + (q0, r0), + "\n\ + {x} * {r} = {xr}\n\ + expected: = {q0} * {y} + {r0}\n\ + returned: = {quo} * {y} + {rem} (== {})\n", + quo as u32 * y as u32 + rem as u32, + xr = x as u32 * r as u32, + ); + } + for s in 0..=y.leading_zeros() { + assert_eq!( + m.squared_with_shift(s), + Reducer::new(((r << s) as u32 * r as u32 % y as u32) as u8, y) + ); + } + for a in 0..=u8::MAX { + if a.checked_mul(y).is_some() 
{ + let abb = a as u32 * r as u32 * r as u32; + assert_eq!( + m.squared_with_scale(a), + Reducer::new((abb % y as u32) as u8, y) + ); + } else { + break; + } + } + for x0 in 0..=u8::MAX { + if m.num == 0 || x0 as u32 * m.rem as u32 % m.num as u32 != 0 { + continue; + } + let y0 = x0 as u32 * m.rem as u32 / m.num as u32; + let Ok(y0) = u8::try_from(y0) else { continue }; + + assert_eq!( + m.with_scaled_num_rem(x0, y0), + Reducer::new((x0 as u32 * m.num as u32 % y as u32) as u8, y) + ); + } + } + } + } +} diff --git a/libm/src/math/support/mod.rs b/libm/src/math/support/mod.rs index ee3f2bbdf..330217a7a 100644 --- a/libm/src/math/support/mod.rs +++ b/libm/src/math/support/mod.rs @@ -20,6 +20,7 @@ pub use hex_float::hf16; pub use hex_float::hf128; #[allow(unused_imports)] pub use hex_float::{Hexf, hf32, hf64}; +pub(crate) use int_traits::Reducer; pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt}; /// Hint to the compiler that the current path is cold. From c811592b46797aae0769852d93fb1f9a90a8c9cb Mon Sep 17 00:00:00 2001 From: quaternic <57393910+quaternic@users.noreply.github.com> Date: Thu, 1 May 2025 18:07:48 +0300 Subject: [PATCH 2/2] Optimize performance of fmod with Barrett multiplication --- libm/src/math/generic/fmod.rs | 93 +++++++++++++++++++++ libm/src/math/support/int_traits/mod_mul.rs | 13 ++- 2 files changed, 98 insertions(+), 8 deletions(-) diff --git a/libm/src/math/generic/fmod.rs b/libm/src/math/generic/fmod.rs index e9898012f..183250c03 100644 --- a/libm/src/math/generic/fmod.rs +++ b/libm/src/math/generic/fmod.rs @@ -1,5 +1,6 @@ /* SPDX-License-Identifier: MIT OR Apache-2.0 */ use super::super::{CastFrom, Float, Int, MinInt}; +use crate::support::{DInt, HInt, Reducer}; #[inline] pub fn fmod(x: F, y: F) -> F { @@ -59,6 +60,33 @@ fn into_sig_exp(mut bits: F::Int) -> (F::Int, u32) { /// Compute the remainder `(x * 2.pow(e)) % y` without overflow. 
fn reduction(mut x: I, e: u32, y: I) -> I { + // FIXME: This is a temporary hack to get around the lack of `u256 / u256`. + // Actually, the algorithm only needs the operation `(x << I::BITS) / y` + // where `x < y`. That is, a division `u256 / u128` where the quotient must + // not overflow `u128` would be sufficient for `f128`. + unsafe { + use core::mem::transmute_copy; + if I::BITS == 64 { + let x = transmute_copy::(&x); + let y = transmute_copy::(&y); + let r = fast_reduction::(x, e, y); + return transmute_copy::(&r); + } + if I::BITS == 32 { + let x = transmute_copy::(&x); + let y = transmute_copy::(&y); + let r = fast_reduction::(x, e, y); + return transmute_copy::(&r); + } + #[cfg(f16_enabled)] + if I::BITS == 16 { + let x = transmute_copy::(&x); + let y = transmute_copy::(&y); + let r = fast_reduction::(x, e, y); + return transmute_copy::(&r); + } + } + x %= y; for _ in 0..e { x <<= 1; @@ -66,3 +94,68 @@ fn reduction(mut x: I, e: u32, y: I) -> I { } x } + +trait SafeShift: Float { + // How many guaranteed leading zeros do the values have? + // A normalized floating point mantissa has `EXP_BITS` guaranteed leading + // zeros (excludes the implicit bit, but includes the now-zeroed sign bit) + // `-1` because we want to shift by either `BASE_SHIFT` or `BASE_SHIFT + 1` + const BASE_SHIFT: u32 = Self::EXP_BITS - 1; +} +impl SafeShift for F {} + +fn fast_reduction(x: I, e: u32, y: I) -> I +where + F: Float, + I: Int + HInt, + I::D: Int + DInt, +{ + let _0 = I::ZERO; + let _1 = I::ONE; + + if y == _1 { + return _0; + } + + if e <= F::BASE_SHIFT { + return (x << e) % y; + } + + // Find least depth s.t. 
`(e >> depth) < I::BITS` + let depth = (I::BITS - 1) + .leading_zeros() + .saturating_sub(e.leading_zeros()); + + let initial = (e >> depth) - F::BASE_SHIFT; + + let max_rem = y.wrapping_sub(_1); + let max_ilog2 = max_rem.ilog2(); + let mut pow2 = _1 << max_ilog2.min(initial); + for _ in max_ilog2..initial { + pow2 <<= 1; + pow2 = pow2.checked_sub(y).unwrap_or(pow2); + } + + // At each step `k in [depth, ..., 0]`, + // `p` is `(e >> k) - BASE_SHIFT` + // `m` is `(1 << p) % y` + let mut k = depth; + let mut p = initial; + let mut m = Reducer::new(pow2, y); + + while k > 0 { + k -= 1; + p = p + p + F::BASE_SHIFT; + if e & (1 << k) != 0 { + m = m.squared_with_shift(F::BASE_SHIFT + 1); + p += 1; + } else { + m = m.squared_with_shift(F::BASE_SHIFT); + }; + + debug_assert!(p == (e >> k) - F::BASE_SHIFT); + } + + // (x << BASE_SHIFT) * (1 << p) == x << e + m.mul_into_div_rem(x << F::BASE_SHIFT).1 +} diff --git a/libm/src/math/support/int_traits/mod_mul.rs b/libm/src/math/support/int_traits/mod_mul.rs index af8f4a398..c770a122b 100644 --- a/libm/src/math/support/int_traits/mod_mul.rs +++ b/libm/src/math/support/int_traits/mod_mul.rs @@ -2,14 +2,11 @@ use super::{DInt, HInt, Int}; /// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)` /// -/// More specifically, implements single-word [Barrett multiplication] -/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication) -/// and [division] -/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division) -/// for unsigned integers. +/// For a more detailed description, see +/// . 
/// /// After constructing as `Reducer::new(b, n)`, -/// provides operations to efficiently compute +/// has operations to efficiently compute /// - `(a * b) / n` and `(a * b) % n` /// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R` #[derive(Clone, Copy, PartialEq, Eq, Debug)] @@ -103,7 +100,7 @@ where let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div); assert!(ab_tn.hi().is_zero()); assert!(ar_ns.lo().is_zero()); - assert_eq!(ab_tn.lo(), ar_ns.hi()); + assert!(ab_tn.lo() == ar_ns.hi()); } // Since `s < R` and `r < n`, // ``` @@ -124,7 +121,7 @@ where /// Requires `r * ab == ra * b`, where `r = bR % n`. #[inline(always)] fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self { - debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num)); + debug_assert!(ab.widen_mul(self.rem) == ra.widen_mul(self.num)); // The new factor `v = abb mod n`: let (_, v) = self.mul_into_div_rem(ab);