@@ -7,6 +7,7 @@ mod math {
         function::{ArgIntoFloat, ArgIterable, Either, OptionalArg, PosArgs},
         identifier, PyObject, PyObjectRef, PyRef, PyResult, VirtualMachine,
     };
+    use itertools::Itertools;
     use num_bigint::BigInt;
     use num_rational::Ratio;
     use num_traits::{One, Signed, ToPrimitive, Zero};
@@ -283,24 +284,73 @@ mod math {
         if has_nan {
             return f64::NAN;
         }
-        vector_norm(&coordinates, max)
+        coordinates.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
+        vector_norm(&coordinates)
     }

-    fn vector_norm(v: &[f64], max: f64) -> f64 {
-        if max == 0.0 || v.len() <= 1 {
+    /// Implementation of accurate hypotenuse algorithm from Borges 2019.
+    /// See https://arxiv.org/abs/1904.09481.
+    /// This assumes that its arguments are positive finite and have been scaled to avoid overflow
+    /// and underflow.
+    fn accurate_hypot(max: f64, min: f64) -> f64 {
+        if min <= max * (f64::EPSILON / 2.0).sqrt() {
             return max;
         }
-        let mut csum = 1.0;
-        let mut frac = 0.0;
-        for &f in v {
-            let f = f / max;
-            let f = f * f;
-            let old = csum;
-            csum += f;
-            // this seemingly redundant operation is to reduce float rounding errors/inaccuracy
-            frac += (old - csum) + f;
-        }
-        max * f64::sqrt(csum - 1.0 + frac)
+        let hypot = max.mul_add(max, min * min).sqrt();
+        let hypot_sq = hypot * hypot;
+        let max_sq = max * max;
+        let correction = (-min).mul_add(min, hypot_sq - max_sq) + hypot.mul_add(hypot, -hypot_sq)
+            - max.mul_add(max, -max_sq);
+        hypot - correction / (2.0 * hypot)
+    }
+
+    /// Calculates the norm of the vector given by `v`.
+    /// `v` is assumed to be a list of non-negative finite floats, sorted in descending order.
+    fn vector_norm(v: &[f64]) -> f64 {
+        // Drop zeros from the vector.
+        let zero_count = v.iter().rev().cloned().take_while(|x| *x == 0.0).count();
+        let v = &v[..v.len() - zero_count];
+        if v.is_empty() {
+            return 0.0;
+        }
+        if v.len() == 1 {
+            return v[0];
+        }
+        // Calculate scaling to avoid overflow / underflow.
+        let max = *v.first().unwrap();
+        let min = *v.last().unwrap();
+        let scale = if max > (f64::MAX / v.len() as f64).sqrt() {
+            max
+        } else if min < f64::MIN_POSITIVE.sqrt() {
+            // ^ This can be an `else if`, because if the max is near f64::MAX and the min is near
+            // f64::MIN_POSITIVE, then the min is relatively unimportant and will be effectively
+            // ignored.
+            min
+        } else {
+            1.0
+        };
+        let mut norm = v
+            .iter()
+            .copied()
+            .map(|x| x / scale)
+            .reduce(accurate_hypot)
+            .unwrap_or_default();
+        if v.len() > 2 {
+            // For larger lists of numbers, we can accumulate a rounding error, so a correction is
+            // needed, similar to that in `accurate_hypot()`.
+            // First, we estimate [sum of squares - norm^2], then we add the first-order
+            // approximation of the square root of that to `norm`.
+            let correction = v
+                .iter()
+                .copied()
+                .map(|x| (x / scale).powi(2))
+                .chain(std::iter::once(-norm * norm))
+                // Pairwise summation of floats gives less rounding error than a naive sum.
+                .tree_fold1(std::ops::Add::add)
+                .expect("expected at least 1 element");
+            norm = norm + correction / (2.0 * norm);
+        }
+        norm * scale
     }

     #[pyfunction]
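Aside (commentary, not part of the diff): the new `itertools::Itertools` import exists solely for the `tree_fold1` call in the correction term above. Below is a minimal standalone sketch of why a tree-shaped (pairwise) reduction is preferred over a naive left-to-right sum; the input values are made up for illustration, and the `itertools` crate is assumed to be available as the diff already requires.

```rust
use itertools::Itertools;

fn main() {
    // A million equal terms: the running total quickly dwarfs each new term,
    // so a left-to-right sum rounds off low-order bits at every step.
    let xs: Vec<f64> = (0..1_000_000).map(|_| 0.1_f64).collect();

    let naive: f64 = xs.iter().sum();
    let pairwise = xs
        .iter()
        .copied()
        // Tree-shaped reduction: operands stay close in magnitude, so the
        // accumulated rounding error grows roughly like O(log n) instead of O(n).
        .tree_fold1(std::ops::Add::add)
        .expect("non-empty input");

    // The pairwise total is typically much closer to the true sum (~100000).
    println!("naive:    {naive:.10}");
    println!("pairwise: {pairwise:.10}");
}
```

This mirrors the diff's use of `tree_fold1(std::ops::Add::add)` to combine the squared, scaled terms together with `-norm * norm` when estimating the first-order correction.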
@@ -339,7 +389,8 @@ mod math {
         if has_nan {
             return Ok(f64::NAN);
         }
-        Ok(vector_norm(&diffs, max))
+        diffs.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
+        Ok(vector_norm(&diffs))
     }

     #[pyfunction]
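One more note outside the diff itself: both call sites now sort the magnitudes in descending order before calling `vector_norm`, because the rewritten function assumes that ordering. Since `f64` has no `Ord` implementation, the diff uses `f64::total_cmp` (a total order over floats) and reverses it. A tiny sketch of that convention, with made-up numbers:

```rust
fn main() {
    // Magnitudes as hypot/dist would see them after taking absolute values.
    let mut coordinates = vec![3.0_f64, 12.0, 0.0, 4.0];

    // Same pattern as the diff: `total_cmp` gives a total order on f64 and
    // `.reverse()` flips it so the largest value comes first.
    coordinates.sort_unstable_by(|x, y| x.total_cmp(y).reverse());
    assert_eq!(coordinates, [12.0, 4.0, 3.0, 0.0]);

    // With this ordering, any zeros sit at the end, which is what lets
    // `vector_norm` strip them with a `take_while` over the reversed slice.
    println!("{coordinates:?}");
}
```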