Skip to content

Commit a39f46a

Browse files
authored
bpo-45876: Correctly rounded stdev() and pstdev() for the Decimal case (GH-29828)
1 parent 8a45ca5 commit a39f46a

File tree

2 files changed

+112
-22
lines changed

2 files changed

+112
-22
lines changed

Lib/statistics.py

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@
137137
from itertools import groupby, repeat
138138
from bisect import bisect_left, bisect_right
139139
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
140-
from operator import itemgetter, mul
140+
from operator import mul
141141
from collections import Counter, namedtuple
142142

143143
_SQRT2 = sqrt(2.0)
@@ -248,6 +248,28 @@ def _exact_ratio(x):
248248
249249
x is expected to be an int, Fraction, Decimal or float.
250250
"""
251+
252+
# XXX We should revisit whether using fractions to accumulate exact
253+
# ratios is the right way to go.
254+
255+
# The integer ratios for binary floats can have numerators or
256+
# denominators with over 300 decimal digits. The problem is more
257+
# acute with decimal floats where the the default decimal context
258+
# supports a huge range of exponents from Emin=-999999 to
259+
# Emax=999999. When expanded with as_integer_ratio(), numbers like
260+
# Decimal('3.14E+5000') and Decimal('3.14E-5000') have large
261+
# numerators or denominators that will slow computation.
262+
263+
# When the integer ratios are accumulated as fractions, the size
264+
# grows to cover the full range from the smallest magnitude to the
265+
# largest. For example, Fraction(3.14E+300) + Fraction(3.14E-300),
266+
# has a 616 digit numerator. Likewise,
267+
# Fraction(Decimal('3.14E+5000')) + Fraction(Decimal('3.14E-5000'))
268+
# has 10,003 digit numerator.
269+
270+
# This doesn't seem to have been problem in practice, but it is a
271+
# potential pitfall.
272+
251273
try:
252274
return x.as_integer_ratio()
253275
except AttributeError:
@@ -305,28 +327,60 @@ def _fail_neg(values, errmsg='negative value'):
305327
raise StatisticsError(errmsg)
306328
yield x
307329

308-
def _isqrt_frac_rto(n: int, m: int) -> float:
330+
331+
def _integer_sqrt_of_frac_rto(n: int, m: int) -> int:
309332
"""Square root of n/m, rounded to the nearest integer using round-to-odd."""
310333
# Reference: https://www.lri.fr/~melquion/doc/05-imacs17_1-expose.pdf
311334
a = math.isqrt(n // m)
312335
return a | (a*a*m != n)
313336

314-
# For 53 bit precision floats, the _sqrt_frac() shift is 109.
315-
_sqrt_shift: int = 2 * sys.float_info.mant_dig + 3
316337

317-
def _sqrt_frac(n: int, m: int) -> float:
338+
# For 53 bit precision floats, the bit width used in
339+
# _float_sqrt_of_frac() is 109.
340+
_sqrt_bit_width: int = 2 * sys.float_info.mant_dig + 3
341+
342+
343+
def _float_sqrt_of_frac(n: int, m: int) -> float:
318344
"""Square root of n/m as a float, correctly rounded."""
319345
# See principle and proof sketch at: https://bugs.python.org/msg407078
320-
q = (n.bit_length() - m.bit_length() - _sqrt_shift) // 2
346+
q = (n.bit_length() - m.bit_length() - _sqrt_bit_width) // 2
321347
if q >= 0:
322-
numerator = _isqrt_frac_rto(n, m << 2 * q) << q
348+
numerator = _integer_sqrt_of_frac_rto(n, m << 2 * q) << q
323349
denominator = 1
324350
else:
325-
numerator = _isqrt_frac_rto(n << -2 * q, m)
351+
numerator = _integer_sqrt_of_frac_rto(n << -2 * q, m)
326352
denominator = 1 << -q
327353
return numerator / denominator # Convert to float
328354

329355

356+
def _decimal_sqrt_of_frac(n: int, m: int) -> Decimal:
357+
"""Square root of n/m as a Decimal, correctly rounded."""
358+
# Premise: For decimal, computing (n/m).sqrt() can be off
359+
# by 1 ulp from the correctly rounded result.
360+
# Method: Check the result, moving up or down a step if needed.
361+
if n <= 0:
362+
if not n:
363+
return Decimal('0.0')
364+
n, m = -n, -m
365+
366+
root = (Decimal(n) / Decimal(m)).sqrt()
367+
nr, dr = root.as_integer_ratio()
368+
369+
plus = root.next_plus()
370+
np, dp = plus.as_integer_ratio()
371+
# test: n / m > ((root + plus) / 2) ** 2
372+
if 4 * n * (dr*dp)**2 > m * (dr*np + dp*nr)**2:
373+
return plus
374+
375+
minus = root.next_minus()
376+
nm, dm = minus.as_integer_ratio()
377+
# test: n / m < ((root + minus) / 2) ** 2
378+
if 4 * n * (dr*dm)**2 < m * (dr*nm + dm*nr)**2:
379+
return minus
380+
381+
return root
382+
383+
330384
# === Measures of central tendency (averages) ===
331385

332386
def mean(data):
@@ -869,7 +923,7 @@ def stdev(data, xbar=None):
869923
if hasattr(T, 'sqrt'):
870924
var = _convert(mss, T)
871925
return var.sqrt()
872-
return _sqrt_frac(mss.numerator, mss.denominator)
926+
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
873927

874928

875929
def pstdev(data, mu=None):
@@ -888,10 +942,9 @@ def pstdev(data, mu=None):
888942
raise StatisticsError('pstdev requires at least one data point')
889943
T, ss = _ss(data, mu)
890944
mss = ss / n
891-
if hasattr(T, 'sqrt'):
892-
var = _convert(mss, T)
893-
return var.sqrt()
894-
return _sqrt_frac(mss.numerator, mss.denominator)
945+
if issubclass(T, Decimal):
946+
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
947+
return _float_sqrt_of_frac(mss.numerator, mss.denominator)
895948

896949

897950
# === Statistics for relations between two inputs ===

Lib/test/test_statistics.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2164,9 +2164,9 @@ def test_center_not_at_mean(self):
21642164

21652165
class TestSqrtHelpers(unittest.TestCase):
21662166

2167-
def test_isqrt_frac_rto(self):
2167+
def test_integer_sqrt_of_frac_rto(self):
21682168
for n, m in itertools.product(range(100), range(1, 1000)):
2169-
r = statistics._isqrt_frac_rto(n, m)
2169+
r = statistics._integer_sqrt_of_frac_rto(n, m)
21702170
self.assertIsInstance(r, int)
21712171
if r*r*m == n:
21722172
# Root is exact
@@ -2177,7 +2177,7 @@ def test_isqrt_frac_rto(self):
21772177
self.assertTrue(m * (r - 1)**2 < n < m * (r + 1)**2)
21782178

21792179
@requires_IEEE_754
2180-
def test_sqrt_frac(self):
2180+
def test_float_sqrt_of_frac(self):
21812181

21822182
def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
21832183
if not x:
@@ -2204,22 +2204,59 @@ def is_root_correctly_rounded(x: Fraction, root: float) -> bool:
22042204
denonimator: int = randrange(10 ** randrange(50)) + 1
22052205
with self.subTest(numerator=numerator, denonimator=denonimator):
22062206
x: Fraction = Fraction(numerator, denonimator)
2207-
root: float = statistics._sqrt_frac(numerator, denonimator)
2207+
root: float = statistics._float_sqrt_of_frac(numerator, denonimator)
22082208
self.assertTrue(is_root_correctly_rounded(x, root))
22092209

22102210
# Verify that corner cases and error handling match math.sqrt()
2211-
self.assertEqual(statistics._sqrt_frac(0, 1), 0.0)
2211+
self.assertEqual(statistics._float_sqrt_of_frac(0, 1), 0.0)
22122212
with self.assertRaises(ValueError):
2213-
statistics._sqrt_frac(-1, 1)
2213+
statistics._float_sqrt_of_frac(-1, 1)
22142214
with self.assertRaises(ValueError):
2215-
statistics._sqrt_frac(1, -1)
2215+
statistics._float_sqrt_of_frac(1, -1)
22162216

22172217
# Error handling for zero denominator matches that for Fraction(1, 0)
22182218
with self.assertRaises(ZeroDivisionError):
2219-
statistics._sqrt_frac(1, 0)
2219+
statistics._float_sqrt_of_frac(1, 0)
22202220

22212221
# The result is well defined if both inputs are negative
2222-
self.assertAlmostEqual(statistics._sqrt_frac(-2, -1), math.sqrt(2.0))
2222+
self.assertEqual(statistics._float_sqrt_of_frac(-2, -1), statistics._float_sqrt_of_frac(2, 1))
2223+
2224+
def test_decimal_sqrt_of_frac(self):
2225+
root: Decimal
2226+
numerator: int
2227+
denominator: int
2228+
2229+
for root, numerator, denominator in [
2230+
(Decimal('0.4481904599041192673635338663'), 200874688349065940678243576378, 1000000000000000000000000000000), # No adj
2231+
(Decimal('0.7924949131383786609961759598'), 628048187350206338833590574929, 1000000000000000000000000000000), # Adj up
2232+
(Decimal('0.8500554152289934068192208727'), 722594208960136395984391238251, 1000000000000000000000000000000), # Adj down
2233+
]:
2234+
with decimal.localcontext(decimal.DefaultContext):
2235+
self.assertEqual(statistics._decimal_sqrt_of_frac(numerator, denominator), root)
2236+
2237+
# Confirm expected root with a quad precision decimal computation
2238+
with decimal.localcontext(decimal.DefaultContext) as ctx:
2239+
ctx.prec *= 4
2240+
high_prec_ratio = Decimal(numerator) / Decimal(denominator)
2241+
ctx.rounding = decimal.ROUND_05UP
2242+
high_prec_root = high_prec_ratio.sqrt()
2243+
with decimal.localcontext(decimal.DefaultContext):
2244+
target_root = +high_prec_root
2245+
self.assertEqual(root, target_root)
2246+
2247+
# Verify that corner cases and error handling match Decimal.sqrt()
2248+
self.assertEqual(statistics._decimal_sqrt_of_frac(0, 1), 0.0)
2249+
with self.assertRaises(decimal.InvalidOperation):
2250+
statistics._decimal_sqrt_of_frac(-1, 1)
2251+
with self.assertRaises(decimal.InvalidOperation):
2252+
statistics._decimal_sqrt_of_frac(1, -1)
2253+
2254+
# Error handling for zero denominator matches that for Fraction(1, 0)
2255+
with self.assertRaises(ZeroDivisionError):
2256+
statistics._decimal_sqrt_of_frac(1, 0)
2257+
2258+
# The result is well defined if both inputs are negative
2259+
self.assertEqual(statistics._decimal_sqrt_of_frac(-2, -1), statistics._decimal_sqrt_of_frac(2, 1))
22232260

22242261

22252262
class TestStdev(VarianceStdevMixin, NumericTestCase):

0 commit comments

Comments
 (0)