|
| 1 | +use super::{DInt, HInt, Int}; |
| 2 | + |
/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)`.
///
/// More specifically, implements single-word [Barrett multiplication] and
/// [division] for unsigned integers.
///
/// After constructing as `Reducer::new(b, n)`, provides operations to
/// efficiently compute
/// - `(a * b) / n` and `(a * b) % n`
/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R`
///
/// [Barrett multiplication]: https://en.wikipedia.org/wiki/Barrett_reduction#Single-word_Barrett_multiplication
/// [division]: https://en.wikipedia.org/wiki/Barrett_reduction#Barrett_Division
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub(crate) struct Reducer<U> {
    /// The multiplying factor `b in 0..n`.
    num: U,
    /// The modulus `n in 1..=R/2`.
    div: U,
    /// The precomputed quotient, `q = (b << K) / n`.
    quo: U,
    /// The remainder of that division, `r = (b << K) % n`.
    ///
    /// It could always be recomputed as `(b << K) - q * n`, but it is
    /// convenient to keep it around.
    rem: U,
}
| 28 | + |
| 29 | +impl<U> Reducer<U> |
| 30 | +where |
| 31 | + U: Int + HInt, |
| 32 | + U::D: core::ops::Div<Output = U::D>, |
| 33 | + U::D: core::ops::Rem<Output = U::D>, |
| 34 | +{ |
| 35 | + /// Requires `num < div <= R/2`, will panic otherwise |
| 36 | + #[inline] |
| 37 | + pub fn new(num: U, div: U) -> Self { |
| 38 | + let _0 = U::ZERO; |
| 39 | + let _1 = U::ONE; |
| 40 | + |
| 41 | + assert!(num < div); |
| 42 | + assert!(div.wrapping_sub(_1).leading_zeros() >= 1); |
| 43 | + |
| 44 | + let bk = num.widen_hi(); |
| 45 | + let n = div.widen(); |
| 46 | + let quo = (bk / n).lo(); |
| 47 | + let rem = (bk % n).lo(); |
| 48 | + |
| 49 | + Self { num, div, quo, rem } |
| 50 | + } |
| 51 | +} |
| 52 | + |
| 53 | +impl<U> Reducer<U> |
| 54 | +where |
| 55 | + U: Int + HInt, |
| 56 | + U::D: Int, |
| 57 | +{ |
| 58 | + /// Return the unique pair `(quotient, remainder)` |
| 59 | + /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n` |
| 60 | + #[inline] |
| 61 | + pub fn mul_into_div_rem(&self, a: U) -> (U, U) { |
| 62 | + let (q, mut r) = self.mul_into_unnormalized_div_rem(a); |
| 63 | + // The unnormalized remainder is still guaranteed to be less than `2n`, so |
| 64 | + // one checked subtraction is sufficient. |
| 65 | + (q + U::cast_from(self.fixup(&mut r) as u8), r) |
| 66 | + } |
| 67 | + |
| 68 | + #[inline(always)] |
| 69 | + pub fn fixup(&self, x: &mut U) -> bool { |
| 70 | + x.checked_sub(self.div).map(|r| *x = r).is_some() |
| 71 | + } |
| 72 | + |
| 73 | + /// Return some pair `(quotient, remainder)` |
| 74 | + /// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n` |
| 75 | + #[inline] |
| 76 | + pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) { |
| 77 | + // General idea: Estimate the quotient `quotient = t in 0..a` s.t. |
| 78 | + // the remainder `ab - tn` is close to zero, so `t ~= ab / n` |
| 79 | + |
| 80 | + // Note: we use `R == 1 << U::BITS`, which means that |
| 81 | + // - wrapping arithmetic with `U` is modulo `R` |
| 82 | + // - all inputs are less than `R` |
| 83 | + |
| 84 | + // Range analysis: |
| 85 | + // |
| 86 | + // Using the definition of euclidean division on the two divisions done: |
| 87 | + // ``` |
| 88 | + // bR = qn + r, with 0 <= r < n |
| 89 | + // aq = tR + s, with 0 <= s < R |
| 90 | + // ``` |
| 91 | + let (_s, t) = a.widen_mul(self.quo).lo_hi(); |
| 92 | + // Then |
| 93 | + // ``` |
| 94 | + // (ab - tn)R |
| 95 | + // = abR - ntR |
| 96 | + // = a(qn + r) - n(aq - s) |
| 97 | + // = ar + ns |
| 98 | + // ``` |
| 99 | + #[cfg(debug_assertions)] |
| 100 | + { |
| 101 | + assert!(t < a || (a == t && t.is_zero())); |
| 102 | + let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div); |
| 103 | + let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div); |
| 104 | + assert!(ab_tn.hi().is_zero()); |
| 105 | + assert!(ar_ns.lo().is_zero()); |
| 106 | + assert_eq!(ab_tn.lo(), ar_ns.hi()); |
| 107 | + } |
| 108 | + // Since `s < R` and `r < n`, |
| 109 | + // ``` |
| 110 | + // 0 <= ns < nR |
| 111 | + // 0 <= ar < an |
| 112 | + // 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R) |
| 113 | + // ``` |
| 114 | + // Since `a < R` and we check on construction that `n <= R/2`, the result |
| 115 | + // is `0 <= ab - tn < R`, so it can be computed modulo `R` |
| 116 | + // even though the intermediate terms generally wrap. |
| 117 | + let ab = a.wrapping_mul(self.num); |
| 118 | + let tn = t.wrapping_mul(self.div); |
| 119 | + (t, ab.wrapping_sub(tn)) |
| 120 | + } |
| 121 | + |
| 122 | + /// Constructs a new reducer with `b` set to `(ab * b) % n` |
| 123 | + /// |
| 124 | + /// Requires `r * ab == ra * b`, where `r = bR % n`. |
| 125 | + #[inline(always)] |
| 126 | + fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self { |
| 127 | + debug_assert_eq!(ab.widen_mul(self.rem), ra.widen_mul(self.num)); |
| 128 | + // The new factor `v = abb mod n`: |
| 129 | + let (_, v) = self.mul_into_div_rem(ab); |
| 130 | + |
| 131 | + // `rab = cn + d`, where `0 <= d < n` |
| 132 | + let (c, d) = self.mul_into_div_rem(ra); |
| 133 | + |
| 134 | + // We need `abbR = Xn + Y`: |
| 135 | + // abbR |
| 136 | + // = ab(qn + r) |
| 137 | + // = abqn + rab |
| 138 | + // = abqn + cn + d |
| 139 | + // = (abq + c)n + d |
| 140 | + |
| 141 | + Self { |
| 142 | + num: v, |
| 143 | + div: self.div, |
| 144 | + quo: self.quo.wrapping_mul(ab).wrapping_add(c), |
| 145 | + rem: d, |
| 146 | + } |
| 147 | + } |
| 148 | + |
| 149 | + /// Computes the reducer with the factor `b` set to `(a * b * b) % n` |
| 150 | + /// Requires that `a * (n - 1)` does not overflow. |
| 151 | + #[allow(dead_code)] |
| 152 | + #[inline] |
| 153 | + pub fn squared_with_scale(&self, a: U) -> Self { |
| 154 | + debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero()); |
| 155 | + self.with_scaled_num_rem(a * self.num, a * self.rem) |
| 156 | + } |
| 157 | + |
| 158 | + /// Computes the reducer with the factor `b` set to `(b * b << s) % n` |
| 159 | + /// Requires that `(n - 1) << s` does not overflow. |
| 160 | + #[inline] |
| 161 | + pub fn squared_with_shift(&self, s: u32) -> Self { |
| 162 | + debug_assert!((self.div - U::ONE).leading_zeros() >= s); |
| 163 | + self.with_scaled_num_rem(self.num << s, self.rem << s) |
| 164 | + } |
| 165 | +} |
| 166 | + |
#[cfg(test)]
mod test {
    use super::Reducer;

    /// Exhaustively checks every valid `u8` reducer (`0 <= r < y <= 128`)
    /// against plain `u32` arithmetic.
    #[test]
    fn u8_all() {
        for y in 1..=128_u8 {
            let modulus = u32::from(y);
            for r in 0..y {
                let factor = u32::from(r);
                let m = Reducer::new(r, y);
                assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8);

                // Multiply-then-divide must agree with wide integer math.
                for x in 0..=u8::MAX {
                    let (quo, rem) = m.mul_into_div_rem(x);

                    let xr = u32::from(x) * factor;
                    let q0 = xr / modulus;
                    let r0 = xr % modulus;
                    assert_eq!(
                        (quo as u32, rem as u32),
                        (q0, r0),
                        "\n\
                         {x} * {r} = {xr}\n\
                         expected: = {q0} * {y} + {r0}\n\
                         returned: = {quo} * {y} + {rem} (== {})\n",
                        quo as u32 * y as u32 + rem as u32,
                    );
                }

                // Squaring with a shift must equal a freshly built reducer.
                for s in 0..=y.leading_zeros() {
                    let expected_num = (u32::from(r << s) * factor % modulus) as u8;
                    assert_eq!(m.squared_with_shift(s), Reducer::new(expected_num, y));
                }

                // Squaring with a scale, for every `a` where `a * y` still fits.
                for a in (0..=u8::MAX).take_while(|a| a.checked_mul(y).is_some()) {
                    let abb = u32::from(a) * factor * factor;
                    assert_eq!(
                        m.squared_with_scale(a),
                        Reducer::new((abb % modulus) as u8, y)
                    );
                }

                // Direct check of the helper for every pair `(x0, y0)` that
                // satisfies its precondition `rem * x0 == y0 * num`.
                if m.num != 0 {
                    let num = u32::from(m.num);
                    for x0 in 0..=u8::MAX {
                        let scaled = u32::from(x0) * u32::from(m.rem);
                        if scaled % num != 0 {
                            continue;
                        }
                        let Ok(y0) = u8::try_from(scaled / num) else {
                            continue;
                        };

                        assert_eq!(
                            m.with_scaled_num_rem(x0, y0),
                            Reducer::new((u32::from(x0) * num % modulus) as u8, y)
                        );
                    }
                }
            }
        }
    }
}
0 commit comments