Skip to content

Implement a more accurate hypot() #4252

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions Lib/test/test_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -833,8 +833,6 @@ def testHypot(self):
@requires_IEEE_754
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
"hypot() loses accuracy on machines with double rounding")
# TODO: RUSTPYTHON
@unittest.expectedFailure
def testHypotAccuracy(self):
# Verify improved accuracy in cases that were known to be inaccurate.
#
Expand Down
81 changes: 66 additions & 15 deletions stdlib/src/math.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod math {
function::{ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
identifier, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
};
use itertools::Itertools;
use num_bigint::BigInt;
use num_rational::Ratio;
use num_traits::{One, Signed, ToPrimitive, Zero};
Expand Down Expand Up @@ -283,24 +284,73 @@ mod math {
if has_nan {
return f64::NAN;
}
vector_norm(&coordinates, max)
coordinates.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
vector_norm(&coordinates)
}

fn vector_norm(v: &[f64], max: f64) -> f64 {
if max == 0.0 || v.len() <= 1 {
/// Implementation of accurate hypotenuse algorithm from Borges 2019.
/// See https://arxiv.org/abs/1904.09481.
/// This assumes that its arguments are positive finite and have been scaled to avoid overflow
/// and underflow.
fn accurate_hypot(max: f64, min: f64) -> f64 {
if min <= max * (f64::EPSILON / 2.0).sqrt() {
return max;
}
let mut csum = 1.0;
let mut frac = 0.0;
for &f in v {
let f = f / max;
let f = f * f;
let old = csum;
csum += f;
// this seemingly redundant operation is to reduce float rounding errors/inaccuracy
frac += (old - csum) + f;
}
max * f64::sqrt(csum - 1.0 + frac)
let hypot = max.mul_add(max, min * min).sqrt();
let hypot_sq = hypot * hypot;
let max_sq = max * max;
let correction = (-min).mul_add(min, hypot_sq - max_sq) + hypot.mul_add(hypot, -hypot_sq)
- max.mul_add(max, -max_sq);
hypot - correction / (2.0 * hypot)
}

/// Calculates the norm of the vector given by `v`.
/// `v` is assumed to be a list of non-negative finite floats, sorted in descending order.
fn vector_norm(v: &[f64]) -> f64 {
// Drop zeros from the vector.
let zero_count = v.iter().rev().cloned().take_while(|x| *x == 0.0).count();
let v = &v[..v.len() - zero_count];
if v.is_empty() {
return 0.0;
}
if v.len() == 1 {
return v[0];
}
// Calculate scaling to avoid overflow / underflow.
let max = *v.first().unwrap();
let min = *v.last().unwrap();
let scale = if max > (f64::MAX / v.len() as f64).sqrt() {
max
} else if min < f64::MIN_POSITIVE.sqrt() {
// ^ This can be an `else if`, because if the max is near f64::MAX and the min is near
// f64::MIN_POSITIVE, then the min is relatively unimportant and will be effectively
// ignored.
min
} else {
1.0
};
let mut norm = v
.iter()
.copied()
.map(|x| x / scale)
.reduce(accurate_hypot)
.unwrap_or_default();
if v.len() > 2 {
// For larger lists of numbers, we can accumulate a rounding error, so a correction is
// needed, similar to that in `accurate_hypot()`.
// First, we estimate [sum of squares - norm^2], then we add the first-order
// approximation of the square root of that to `norm`.
let correction = v
.iter()
.copied()
.map(|x| (x / scale).powi(2))
.chain(std::iter::once(-norm * norm))
// Pairwise summation of floats gives less rounding error than a naive sum.
.tree_fold1(std::ops::Add::add)
.expect("expected at least 1 element");
norm = norm + correction / (2.0 * norm);
}
norm * scale
}

#[pyfunction]
Expand Down Expand Up @@ -339,7 +389,8 @@ mod math {
if has_nan {
return Ok(f64::NAN);
}
Ok(vector_norm(&diffs, max))
diffs.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
Ok(vector_norm(&diffs))
}

#[pyfunction]
Expand Down