Skip to content

Commit 835771b

Browse files
authored
Merge pull request RustPython#4252 from devonhollowood/accurate-hypot
Implement a more accurate hypot()
2 parents 426f9f6 + 96bf177 commit 835771b

File tree

2 files changed

+66
-17
lines changed

2 files changed

+66
-17
lines changed

Lib/test/test_math.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -833,8 +833,6 @@ def testHypot(self):
833833
@requires_IEEE_754
834834
@unittest.skipIf(HAVE_DOUBLE_ROUNDING,
835835
"hypot() loses accuracy on machines with double rounding")
836-
# TODO: RUSTPYTHON
837-
@unittest.expectedFailure
838836
def testHypotAccuracy(self):
839837
# Verify improved accuracy in cases that were known to be inaccurate.
840838
#

stdlib/src/math.rs

Lines changed: 66 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod math {
77
function::{ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
88
identifier, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
99
};
10+
use itertools::Itertools;
1011
use num_bigint::BigInt;
1112
use num_rational::Ratio;
1213
use num_traits::{One, Signed, ToPrimitive, Zero};
@@ -283,24 +284,73 @@ mod math {
283284
if has_nan {
284285
return f64::NAN;
285286
}
286-
vector_norm(&coordinates, max)
287+
coordinates.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
288+
vector_norm(&coordinates)
287289
}
288290

289-
fn vector_norm(v: &[f64], max: f64) -> f64 {
290-
if max == 0.0 || v.len() <= 1 {
291+
/// Implementation of accurate hypotenuse algorithm from Borges 2019.
292+
/// See https://arxiv.org/abs/1904.09481.
293+
/// This assumes that its arguments are positive finite and have been scaled to avoid overflow
294+
/// and underflow.
295+
fn accurate_hypot(max: f64, min: f64) -> f64 {
296+
if min <= max * (f64::EPSILON / 2.0).sqrt() {
291297
return max;
292298
}
293-
let mut csum = 1.0;
294-
let mut frac = 0.0;
295-
for &f in v {
296-
let f = f / max;
297-
let f = f * f;
298-
let old = csum;
299-
csum += f;
300-
// this seemingly redundant operation is to reduce float rounding errors/inaccuracy
301-
frac += (old - csum) + f;
302-
}
303-
max * f64::sqrt(csum - 1.0 + frac)
299+
let hypot = max.mul_add(max, min * min).sqrt();
300+
let hypot_sq = hypot * hypot;
301+
let max_sq = max * max;
302+
let correction = (-min).mul_add(min, hypot_sq - max_sq) + hypot.mul_add(hypot, -hypot_sq)
303+
- max.mul_add(max, -max_sq);
304+
hypot - correction / (2.0 * hypot)
305+
}
306+
307+
/// Calculates the norm of the vector given by `v`.
308+
/// `v` is assumed to be a list of non-negative finite floats, sorted in descending order.
309+
fn vector_norm(v: &[f64]) -> f64 {
310+
// Drop zeros from the vector.
311+
let zero_count = v.iter().rev().cloned().take_while(|x| *x == 0.0).count();
312+
let v = &v[..v.len() - zero_count];
313+
if v.is_empty() {
314+
return 0.0;
315+
}
316+
if v.len() == 1 {
317+
return v[0];
318+
}
319+
// Calculate scaling to avoid overflow / underflow.
320+
let max = *v.first().unwrap();
321+
let min = *v.last().unwrap();
322+
let scale = if max > (f64::MAX / v.len() as f64).sqrt() {
323+
max
324+
} else if min < f64::MIN_POSITIVE.sqrt() {
325+
// ^ This can be an `else if`, because if the max is near f64::MAX and the min is near
326+
// f64::MIN_POSITIVE, then the min is relatively unimportant and will be effectively
327+
// ignored.
328+
min
329+
} else {
330+
1.0
331+
};
332+
let mut norm = v
333+
.iter()
334+
.copied()
335+
.map(|x| x / scale)
336+
.reduce(accurate_hypot)
337+
.unwrap_or_default();
338+
if v.len() > 2 {
339+
// For larger lists of numbers, we can accumulate a rounding error, so a correction is
340+
// needed, similar to that in `accurate_hypot()`.
341+
// First, we estimate [sum of squares - norm^2], then we add the first-order
342+
// approximation of the square root of that to `norm`.
343+
let correction = v
344+
.iter()
345+
.copied()
346+
.map(|x| (x / scale).powi(2))
347+
.chain(std::iter::once(-norm * norm))
348+
// Pairwise summation of floats gives less rounding error than a naive sum.
349+
.tree_fold1(std::ops::Add::add)
350+
.expect("expected at least 1 element");
351+
norm = norm + correction / (2.0 * norm);
352+
}
353+
norm * scale
304354
}
305355

306356
#[pyfunction]
@@ -339,7 +389,8 @@ mod math {
339389
if has_nan {
340390
return Ok(f64::NAN);
341391
}
342-
Ok(vector_norm(&diffs, max))
392+
diffs.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
393+
Ok(vector_norm(&diffs))
343394
}
344395

345396
#[pyfunction]

0 commit comments

Comments
 (0)