Skip to content

Commit bf85502

Browse files
committed
Implement Barrett reduction for modular multiplication
1 parent f456aa8 commit bf85502

File tree

3 files changed

+229
-0
lines changed

3 files changed

+229
-0
lines changed

libm/src/math/support/int_traits.rs

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use core::{cmp, fmt, ops};
22

3+
mod mod_mul;
4+
pub(crate) use mod_mul::Reducer;
5+
36
/// Minimal integer implementations needed on all integer types, including wide integers.
47
pub trait MinInt:
58
Copy
+225
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
use super::{DInt, HInt, Int};
2+
3+
/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`
4+
///
5+
/// More specifically, implements single-word [Barrett multiplication]
6+
/// (https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication)
7+
/// and [division]
8+
/// (https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division)
9+
/// for unsigned integers.
10+
///
11+
/// After constructing as `Reducer::new(b, n)`,
12+
/// provides operations to efficiently compute
13+
/// - `(a * b) / n` and `(a * b) % n`
14+
/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
15+
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
16+
pub(crate) struct Reducer<U> {
17+
// the multiplying factor `b in 0..n`
18+
num: U,
19+
// the modulus `n in 1..=R/2`
20+
div: U,
21+
// the precomputed quotient, `q = (b << K) / n`
22+
quo: U,
23+
// the remainder of that division, `r = (b << K) % n`,
24+
// (could always be recomputed as `(b << K) - q * n`,
25+
// but it is convenient to save)
26+
rem: U,
27+
}
28+
29+
impl<U> Reducer<U>
30+
where
31+
U: Int + HInt,
32+
U::D: core::ops::Div<Output = U::D>,
33+
U::D: core::ops::Rem<Output = U::D>,
34+
{
35+
/// Requires `num < div <= R/2`, will panic otherwise
36+
#[inline]
37+
pub fn new(num: U, div: U) -> Self {
38+
let _0 = U::ZERO;
39+
let _1 = U::ONE;
40+
41+
assert!(num < div);
42+
assert!(div.wrapping_sub(_1).leading_zeros() >= 1);
43+
44+
let bk = num.widen_hi();
45+
let n = div.widen();
46+
let quo = (bk / n).lo();
47+
let rem = (bk % n).lo();
48+
49+
Self { num, div, quo, rem }
50+
}
51+
}
52+
53+
impl<U> Reducer<U>
54+
where
55+
U: Int + HInt,
56+
U::D: Int,
57+
{
58+
/// Return the unique pair `(quotient, remainder)`
59+
/// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n`
60+
#[inline]
61+
pub fn mul_into_div_rem(&self, a: U) -> (U, U) {
62+
let (q, mut r) = self.mul_into_unnormalized_div_rem(a);
63+
// The unnormalized remainder is still guaranteed to be less than `2n`, so
64+
// one checked subtraction is sufficient.
65+
(q + U::cast_from(self.fixup(&mut r) as u8), r)
66+
}
67+
68+
#[inline(always)]
69+
pub fn fixup(&self, x: &mut U) -> bool {
70+
x.checked_sub(self.div).map(|r| *x = r).is_some()
71+
}
72+
73+
/// Return some pair `(quotient, remainder)`
74+
/// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n`
75+
#[inline]
76+
pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) {
77+
// General idea: Estimate the quotient `quotient = t in 0..a` s.t.
78+
// the remainder `ab - tn` is close to zero, so `t ~= ab / n`
79+
80+
// Note: we use `R == 1 << U::BITS`, which means that
81+
// - wrapping arithmetic with `U` is modulo `R`
82+
// - all inputs are less than `R`
83+
84+
// Range analysis:
85+
//
86+
// Using the definition of euclidean division on the two divisions done:
87+
// ```
88+
// bR = qn + r, with 0 <= r < n
89+
// aq = tR + s, with 0 <= s < R
90+
// ```
91+
let (_s, t) = a.widen_mul(self.quo).lo_hi();
92+
// Then
93+
// ```
94+
// (ab - tn)R
95+
// = abR - ntR
96+
// = a(qn + r) - n(aq - s)
97+
// = ar + ns
98+
// ```
99+
#[cfg(debug_assertions)]
100+
{
101+
assert!(t < a || (a == t && t.is_zero()));
102+
let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div);
103+
let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div);
104+
assert!(ab_tn.hi().is_zero());
105+
assert!(ar_ns.lo().is_zero());
106+
assert_eq!(ab_tn.lo(), ar_ns.hi());
107+
}
108+
// Since `s < R` and `r < n`,
109+
// ```
110+
// 0 <= ns < nR
111+
// 0 <= ar < an
112+
// 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R)
113+
// ```
114+
// Since `a < R` and we check on construction that `n <= R/2`, the result
115+
// is `0 <= ab - tn < R`, so it can be computed modulo `R`
116+
// even though the intermediate terms generally wrap.
117+
let ab = a.wrapping_mul(self.num);
118+
let tn = t.wrapping_mul(self.div);
119+
(t, ab.wrapping_sub(tn))
120+
}
121+
122+
/// Constructs a new reducer with `b` set to `(ab * b) % n`
123+
///
124+
/// Requires `r * ab == ra * b`, where `r = bR % n`.
125+
#[inline(always)]
126+
fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self {
127+
debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num));
128+
// The new factor `v = abb mod n`:
129+
let (_, v) = self.mul_into_div_rem(ab);
130+
131+
// `rab = cn + d`, where `0 <= d < n`
132+
let (c, d) = self.mul_into_div_rem(ra);
133+
134+
// We need `abbR = Xn + Y`:
135+
// abbR
136+
// = ab(qn + r)
137+
// = abqn + rab
138+
// = abqn + cn + d
139+
// = (abq + c)n + d
140+
141+
Self {
142+
num: v,
143+
div: self.div,
144+
quo: self.quo.wrapping_mul(ab).wrapping_add(c),
145+
rem: d,
146+
}
147+
}
148+
149+
/// Computes the reducer with the factor `b` set to `(a * b * b) % n`
150+
/// Requires that `a * (n - 1)` does not overflow.
151+
#[allow(dead_code)]
152+
#[inline]
153+
pub fn squared_with_scale(&self, a: U) -> Self {
154+
debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero());
155+
self.with_scaled_num_rem(a * self.num, a * self.rem)
156+
}
157+
158+
/// Computes the reducer with the factor `b` set to `(b * b << s) % n`
159+
/// Requires that `(n - 1) << s` does not overflow.
160+
#[inline]
161+
pub fn squared_with_shift(&self, s: u32) -> Self {
162+
debug_assert!((self.div - U::ONE).leading_zeros() >= s);
163+
self.with_scaled_num_rem(self.num << s, self.rem << s)
164+
}
165+
}
166+
167+
#[cfg(test)]
168+
mod test {
169+
use super::Reducer;
170+
171+
#[test]
172+
fn u8_all() {
173+
for y in 1..=128_u8 {
174+
for r in 0..y {
175+
let m = Reducer::new(r, y);
176+
assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8);
177+
for x in 0..=u8::MAX {
178+
let (quo, rem) = m.mul_into_div_rem(x);
179+
180+
let q0 = x as u32 * r as u32 / y as u32;
181+
let r0 = x as u32 * r as u32 % y as u32;
182+
assert_eq!(
183+
(quo as u32, rem as u32),
184+
(q0, r0),
185+
"\n\
186+
{x} * {r} = {xr}\n\
187+
expected: = {q0} * {y} + {r0}\n\
188+
returned: = {quo} * {y} + {rem} (== {})\n",
189+
quo as u32 * y as u32 + rem as u32,
190+
xr = x as u32 * r as u32,
191+
);
192+
}
193+
for s in 0..=y.leading_zeros() {
194+
assert_eq!(
195+
m.squared_with_shift(s),
196+
Reducer::new(((r << s) as u32 * r as u32 % y as u32) as u8, y)
197+
);
198+
}
199+
for a in 0..=u8::MAX {
200+
if a.checked_mul(y).is_some() {
201+
let abb = a as u32 * r as u32 * r as u32;
202+
assert_eq!(
203+
m.squared_with_scale(a),
204+
Reducer::new((abb % y as u32) as u8, y)
205+
);
206+
} else {
207+
break;
208+
}
209+
}
210+
for x0 in 0..=u8::MAX {
211+
if m.num == 0 || x0 as u32 * m.rem as u32 % m.num as u32 != 0 {
212+
continue;
213+
}
214+
let y0 = x0 as u32 * m.rem as u32 / m.num as u32;
215+
let Ok(y0) = u8::try_from(y0) else { continue };
216+
217+
assert_eq!(
218+
m.with_scaled_num_rem(x0, y0),
219+
Reducer::new((x0 as u32 * m.num as u32 % y as u32) as u8, y)
220+
);
221+
}
222+
}
223+
}
224+
}
225+
}

libm/src/math/support/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ pub use hex_float::hf16;
2020
pub use hex_float::hf128;
2121
#[allow(unused_imports)]
2222
pub use hex_float::{Hexf, hf32, hf64};
23+
pub(crate) use int_traits::Reducer;
2324
pub use int_traits::{CastFrom, CastInto, DInt, HInt, Int, MinInt};
2425

2526
/// Hint to the compiler that the current path is cold.

0 commit comments

Comments
 (0)