-
Notifications
You must be signed in to change notification settings - Fork 247
Improve performance scaling of fmod
using modular exponentiation
#898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,6 @@ | ||
/* SPDX-License-Identifier: MIT OR Apache-2.0 */ | ||
use super::super::{CastFrom, Float, Int, MinInt}; | ||
use crate::support::{DInt, HInt, Reducer}; | ||
|
||
#[inline] | ||
pub fn fmod<F: Float>(x: F, y: F) -> F { | ||
|
@@ -59,10 +60,102 @@ fn into_sig_exp<F: Float>(mut bits: F::Int) -> (F::Int, u32) { | |
|
||
/// Compute the remainder `(x * 2.pow(e)) % y` without overflow. | ||
fn reduction<I: Int>(mut x: I, e: u32, y: I) -> I { | ||
// FIXME: This is a temporary hack to get around the lack of `u256 / u256`. | ||
// Actually, the algorithm only needs the operation `(x << I::BITS) / y` | ||
// where `x < y`. That is, a division `u256 / u128` where the quotient must | ||
// not overflow `u128` would be sufficient for `f128`. | ||
Comment on lines
+63
to
+66
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would this be easier with u256? We have an implementation here
pub struct u256 {
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If
For context, x86's So that's the abstraction I'd like to use; something like unsafe fn unchecked_wide_div_rem<U: HInt>(U::D, U) -> (U, U); But of course, that would be even more work to implement since it doesn't exist yet, and I don't think other arches have a native operation for it. Another idea would be to get rid of the integer division altogether, and compute the reciprocal from the original floating point value. I expect that this could have better performance, but it needs more careful analysis. |
||
unsafe { | ||
use core::mem::transmute_copy; | ||
if I::BITS == 64 { | ||
let x = transmute_copy::<I, u64>(&x); | ||
let y = transmute_copy::<I, u64>(&y); | ||
let r = fast_reduction::<f64, u64>(x, e, y); | ||
return transmute_copy::<u64, I>(&r); | ||
} | ||
if I::BITS == 32 { | ||
let x = transmute_copy::<I, u32>(&x); | ||
let y = transmute_copy::<I, u32>(&y); | ||
let r = fast_reduction::<f32, u32>(x, e, y); | ||
return transmute_copy::<u32, I>(&r); | ||
} | ||
#[cfg(f16_enabled)] | ||
if I::BITS == 16 { | ||
let x = transmute_copy::<I, u16>(&x); | ||
let y = transmute_copy::<I, u16>(&y); | ||
let r = fast_reduction::<f16, u16>(x, e, y); | ||
return transmute_copy::<u16, I>(&r); | ||
} | ||
} | ||
|
||
x %= y; | ||
for _ in 0..e { | ||
x <<= 1; | ||
x = x.checked_sub(y).unwrap_or(x); | ||
} | ||
x | ||
} | ||
|
||
trait SafeShift: Float { | ||
// How many guaranteed leading zeros do the values have? | ||
// A normalized floating point mantissa has `EXP_BITS` guaranteed leading | ||
// zeros (exludes the implicit bit, but includes the now-zeroed sign bit) | ||
// `-1` because we want to shift by either `BASE_SHIFT` or `BASE_SHIFT + 1` | ||
const BASE_SHIFT: u32 = Self::EXP_BITS - 1; | ||
} | ||
impl<F: Float> SafeShift for F {} | ||
|
||
fn fast_reduction<F, I>(x: I, e: u32, y: I) -> I | ||
where | ||
F: Float<Int = I>, | ||
I: Int + HInt, | ||
I::D: Int + DInt<H = I>, | ||
{ | ||
let _0 = I::ZERO; | ||
let _1 = I::ONE; | ||
|
||
if y == _1 { | ||
return _0; | ||
} | ||
|
||
if e <= F::BASE_SHIFT { | ||
return (x << e) % y; | ||
} | ||
|
||
// Find least depth s.t. `(e >> depth) < I::BITS` | ||
let depth = (I::BITS - 1) | ||
.leading_zeros() | ||
.saturating_sub(e.leading_zeros()); | ||
|
||
let initial = (e >> depth) - F::BASE_SHIFT; | ||
|
||
let max_rem = y.wrapping_sub(_1); | ||
let max_ilog2 = max_rem.ilog2(); | ||
let mut pow2 = _1 << max_ilog2.min(initial); | ||
for _ in max_ilog2..initial { | ||
pow2 <<= 1; | ||
pow2 = pow2.checked_sub(y).unwrap_or(pow2); | ||
} | ||
|
||
// At each step `k in [depth, ..., 0]`, | ||
// `p` is `(e >> k) - BASE_SHIFT` | ||
// `m` is `(1 << p) % y` | ||
let mut k = depth; | ||
let mut p = initial; | ||
let mut m = Reducer::new(pow2, y); | ||
|
||
while k > 0 { | ||
k -= 1; | ||
p = p + p + F::BASE_SHIFT; | ||
if e & (1 << k) != 0 { | ||
m = m.squared_with_shift(F::BASE_SHIFT + 1); | ||
p += 1; | ||
} else { | ||
m = m.squared_with_shift(F::BASE_SHIFT); | ||
}; | ||
|
||
debug_assert!(p == (e >> k) - F::BASE_SHIFT); | ||
} | ||
|
||
// (x << BASE_SHIFT) * (1 << p) == x << e | ||
m.mul_into_div_rem(x << F::BASE_SHIFT).1 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
use super::{DInt, HInt, Int}; | ||
|
||
/// Barrett reduction using the constant `R == (1 << K) == (1 << U::BITS)` | ||
/// | ||
/// For a more detailed description, see | ||
/// <https://en.wikipedia.org/wiki/Barrett_reduction>. | ||
/// | ||
/// After constructing as `Reducer::new(b, n)`, | ||
/// has operations to efficiently compute | ||
/// - `(a * b) / n` and `(a * b) % n` | ||
/// - `Reducer::new((a * b * b) % n, n)`, as long as `a * (n - 1) < R` | ||
#[derive(Clone, Copy, PartialEq, Eq, Debug)] | ||
pub(crate) struct Reducer<U> { | ||
// the multiplying factor `b in 0..n` | ||
num: U, | ||
// the modulus `n in 1..=R/2` | ||
div: U, | ||
// the precomputed quotient, `q = (b << K) / n` | ||
quo: U, | ||
// the remainder of that division, `r = (b << K) % n`, | ||
// (could always be recomputed as `(b << K) - q * n`, | ||
// but it is convenient to save) | ||
rem: U, | ||
} | ||
|
||
impl<U> Reducer<U> | ||
where | ||
U: Int + HInt, | ||
U::D: core::ops::Div<Output = U::D>, | ||
U::D: core::ops::Rem<Output = U::D>, | ||
{ | ||
/// Requires `num < div <= R/2`, will panic otherwise | ||
#[inline] | ||
pub fn new(num: U, div: U) -> Self { | ||
let _0 = U::ZERO; | ||
let _1 = U::ONE; | ||
|
||
assert!(num < div); | ||
assert!(div.wrapping_sub(_1).leading_zeros() >= 1); | ||
|
||
let bk = num.widen_hi(); | ||
let n = div.widen(); | ||
let quo = (bk / n).lo(); | ||
let rem = (bk % n).lo(); | ||
|
||
Self { num, div, quo, rem } | ||
} | ||
} | ||
|
||
impl<U> Reducer<U> | ||
where | ||
U: Int + HInt, | ||
U::D: Int, | ||
{ | ||
/// Return the unique pair `(quotient, remainder)` | ||
/// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < n` | ||
#[inline] | ||
pub fn mul_into_div_rem(&self, a: U) -> (U, U) { | ||
let (q, mut r) = self.mul_into_unnormalized_div_rem(a); | ||
// The unnormalized remainder is still guaranteed to be less than `2n`, so | ||
// one checked subtraction is sufficient. | ||
(q + U::cast_from(self.fixup(&mut r) as u8), r) | ||
} | ||
|
||
#[inline(always)] | ||
pub fn fixup(&self, x: &mut U) -> bool { | ||
x.checked_sub(self.div).map(|r| *x = r).is_some() | ||
} | ||
|
||
/// Return some pair `(quotient, remainder)` | ||
/// s.t. `a * b == quotient * n + remainder`, and `0 <= remainder < 2n` | ||
#[inline] | ||
pub fn mul_into_unnormalized_div_rem(&self, a: U) -> (U, U) { | ||
// General idea: Estimate the quotient `quotient = t in 0..a` s.t. | ||
// the remainder `ab - tn` is close to zero, so `t ~= ab / n` | ||
|
||
// Note: we use `R == 1 << U::BITS`, which means that | ||
// - wrapping arithmetic with `U` is modulo `R` | ||
// - all inputs are less than `R` | ||
|
||
// Range analysis: | ||
// | ||
// Using the definition of euclidean division on the two divisions done: | ||
// ``` | ||
// bR = qn + r, with 0 <= r < n | ||
// aq = tR + s, with 0 <= s < R | ||
// ``` | ||
let (_s, t) = a.widen_mul(self.quo).lo_hi(); | ||
// Then | ||
// ``` | ||
// (ab - tn)R | ||
// = abR - ntR | ||
// = a(qn + r) - n(aq - s) | ||
// = ar + ns | ||
// ``` | ||
#[cfg(debug_assertions)] | ||
{ | ||
assert!(t < a || (a == t && t.is_zero())); | ||
let ab_tn = a.widen_mul(self.num) - t.widen_mul(self.div); | ||
let ar_ns = a.widen_mul(self.rem) + _s.widen_mul(self.div); | ||
assert!(ab_tn.hi().is_zero()); | ||
assert!(ar_ns.lo().is_zero()); | ||
assert!(ab_tn.lo() == ar_ns.hi()); | ||
} | ||
// Since `s < R` and `r < n`, | ||
// ``` | ||
// 0 <= ns < nR | ||
// 0 <= ar < an | ||
// 0 <= (ab - tn) == (ar + ns)/R < n(1 + a/R) | ||
// ``` | ||
// Since `a < R` and we check on construction that `n <= R/2`, the result | ||
// is `0 <= ab - tn < R`, so it can be computed modulo `R` | ||
// even though the intermediate terms generally wrap. | ||
let ab = a.wrapping_mul(self.num); | ||
let tn = t.wrapping_mul(self.div); | ||
(t, ab.wrapping_sub(tn)) | ||
} | ||
|
||
/// Constructs a new reducer with `b` set to `(ab * b) % n` | ||
/// | ||
/// Requires `r * ab == ra * b`, where `r = bR % n`. | ||
#[inline(always)] | ||
fn with_scaled_num_rem(&self, ab: U, ra: U) -> Self { | ||
debug_assert!(ab.widen_mul(self.rem) == ra.widen_mul(self.num)); | ||
// The new factor `v = abb mod n`: | ||
let (_, v) = self.mul_into_div_rem(ab); | ||
|
||
// `rab = cn + d`, where `0 <= d < n` | ||
let (c, d) = self.mul_into_div_rem(ra); | ||
|
||
// We need `abbR = Xn + Y`: | ||
// abbR | ||
// = ab(qn + r) | ||
// = abqn + rab | ||
// = abqn + cn + d | ||
// = (abq + c)n + d | ||
|
||
Self { | ||
num: v, | ||
div: self.div, | ||
quo: self.quo.wrapping_mul(ab).wrapping_add(c), | ||
rem: d, | ||
} | ||
} | ||
|
||
/// Computes the reducer with the factor `b` set to `(a * b * b) % n` | ||
/// Requires that `a * (n - 1)` does not overflow. | ||
#[allow(dead_code)] | ||
#[inline] | ||
pub fn squared_with_scale(&self, a: U) -> Self { | ||
debug_assert!(a.widen_mul(self.div - U::ONE).hi().is_zero()); | ||
self.with_scaled_num_rem(a * self.num, a * self.rem) | ||
} | ||
|
||
/// Computes the reducer with the factor `b` set to `(b * b << s) % n` | ||
/// Requires that `(n - 1) << s` does not overflow. | ||
#[inline] | ||
pub fn squared_with_shift(&self, s: u32) -> Self { | ||
debug_assert!((self.div - U::ONE).leading_zeros() >= s); | ||
self.with_scaled_num_rem(self.num << s, self.rem << s) | ||
} | ||
} | ||
|
||
#[cfg(test)] | ||
mod test { | ||
use super::Reducer; | ||
|
||
#[test] | ||
fn u8_all() { | ||
for y in 1..=128_u8 { | ||
for r in 0..y { | ||
let m = Reducer::new(r, y); | ||
assert_eq!(m.quo, ((r as f32 * 256.0) / (y as f32)) as u8); | ||
for x in 0..=u8::MAX { | ||
let (quo, rem) = m.mul_into_div_rem(x); | ||
|
||
let q0 = x as u32 * r as u32 / y as u32; | ||
let r0 = x as u32 * r as u32 % y as u32; | ||
assert_eq!( | ||
(quo as u32, rem as u32), | ||
(q0, r0), | ||
"\n\ | ||
{x} * {r} = {xr}\n\ | ||
expected: = {q0} * {y} + {r0}\n\ | ||
returned: = {quo} * {y} + {rem} (== {})\n", | ||
quo as u32 * y as u32 + rem as u32, | ||
xr = x as u32 * r as u32, | ||
); | ||
} | ||
for s in 0..=y.leading_zeros() { | ||
assert_eq!( | ||
m.squared_with_shift(s), | ||
Reducer::new(((r << s) as u32 * r as u32 % y as u32) as u8, y) | ||
); | ||
} | ||
for a in 0..=u8::MAX { | ||
if a.checked_mul(y).is_some() { | ||
let abb = a as u32 * r as u32 * r as u32; | ||
assert_eq!( | ||
m.squared_with_scale(a), | ||
Reducer::new((abb % y as u32) as u8, y) | ||
); | ||
} else { | ||
break; | ||
} | ||
} | ||
for x0 in 0..=u8::MAX { | ||
if m.num == 0 || x0 as u32 * m.rem as u32 % m.num as u32 != 0 { | ||
continue; | ||
} | ||
let y0 = x0 as u32 * m.rem as u32 / m.num as u32; | ||
let Ok(y0) = u8::try_from(y0) else { continue }; | ||
|
||
assert_eq!( | ||
m.with_scaled_num_rem(x0, y0), | ||
Reducer::new((x0 as u32 * m.num as u32 % y as u32) as u8, y) | ||
); | ||
} | ||
} | ||
} | ||
} | ||
} |