Skip to content

bpo-46257: Convert statistics._ss() to a single pass algorithm #30403

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 40 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
40 commits
Select commit Hold shift + click to select a range
bbd2da9
Merge pull request #1 from python/master
rhettinger Mar 16, 2021
74bdf1b
Merge branch 'master' of github.com:python/cpython
rhettinger Mar 22, 2021
6c53f1a
Merge branch 'master' of github.com:python/cpython
rhettinger Mar 22, 2021
a487c4f
.
rhettinger Mar 24, 2021
eb56423
.
rhettinger Mar 25, 2021
cc7ba06
.
rhettinger Mar 26, 2021
d024dd0
.
rhettinger Apr 22, 2021
b10f912
merge
rhettinger May 5, 2021
fb6744d
merge
rhettinger May 6, 2021
7f21a1c
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 15, 2021
7da42d4
Merge branch 'main' of github.com:rhettinger/cpython
rhettinger Aug 25, 2021
e31757b
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 31, 2021
f058a6f
Merge branch 'main' of github.com:python/cpython
rhettinger Aug 31, 2021
1fc29bd
Merge branch 'main' of github.com:python/cpython
rhettinger Sep 4, 2021
e5c0184
Merge branch 'main' of github.com:python/cpython
rhettinger Oct 30, 2021
3c86ec1
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 9, 2021
96675e4
Merge branch 'main' of github.com:rhettinger/cpython
rhettinger Nov 9, 2021
de558c6
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 9, 2021
418a07f
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 14, 2021
ea23a8b
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 21, 2021
ba248b7
Merge branch 'main' of github.com:python/cpython
rhettinger Nov 27, 2021
9bc1df1
Merge branch 'main' of github.com:python/cpython
rhettinger Dec 1, 2021
d4466ba
Merge branch 'main' of github.com:python/cpython
rhettinger Dec 1, 2021
a89f02e
Merge branch 'main' of github.com:python/cpython
rhettinger Dec 8, 2021
aae9a5f
Merge branch 'main' of github.com:python/cpython
rhettinger Dec 10, 2021
7ba634b
Merge branch 'main' of github.com:python/cpython
rhettinger Jan 1, 2022
0b54723
Add doctest and improve readability for move_to_end() example.
rhettinger Jan 3, 2022
6ce943f
Single pass sum of squares
rhettinger Jan 4, 2022
c8e2de7
Use len() to get the count
rhettinger Jan 4, 2022
45d83da
Avoid converting iterators to lists
rhettinger Jan 4, 2022
b1a89be
Neaten-up
rhettinger Jan 4, 2022
b4d2797
Add blurb
rhettinger Jan 4, 2022
712f648
Avoid touching collections.rst
rhettinger Jan 4, 2022
dc98276
Accumulate unsquared denominators
rhettinger Jan 4, 2022
6b2e8ca
Update Lib/statistics.py
rhettinger Jan 4, 2022
ae382ff
Make mean() single pass over iterators
rhettinger Jan 4, 2022
2e03c7a
Merge branch 'statistics_fast_ss' of github.com:rhettinger/cpython in…
rhettinger Jan 4, 2022
bbe6558
Update blurb to cover mean().
rhettinger Jan 4, 2022
326bce8
Move _ss() into the private utilities section.
rhettinger Jan 4, 2022
208abcd
Use defaultdict() instead of boundmethod
rhettinger Jan 5, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 43 additions & 57 deletions Lib/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@
from bisect import bisect_left, bisect_right
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
from operator import mul
from collections import Counter, namedtuple
from collections import Counter, namedtuple, defaultdict

_SQRT2 = sqrt(2.0)

Expand Down Expand Up @@ -202,6 +202,43 @@ def _sum(data):
return (T, total, count)


def _ss(data, c=None):
"""Return ``(T, total, count)`` for the square deviations of data.

``total`` is the sum of square deviations and ``count`` is the number of
data points. ``T`` is the type reported by ``_sum`` (when ``c`` is given)
or built up with ``_coerce`` from the types of the inputs.

If ``c`` is None, ``data`` is traversed once: exact sums of the values and
of their squares are accumulated, and the sum of square deviations about
the mean is derived from those two sums. Empty data gives a ``total`` of
``Fraction(0)`` and a ``count`` of 0. Otherwise, deviations are
calculated from ``c`` as given. Use the second case with care, as it can
lead to garbage results.
"""
if c is not None:
# Deviations about the caller-supplied centre: square each (x - c)
# and let _sum total the squares.
T, total, count = _sum((d := x - c) * d for x in data)
return (T, total, count)
count = 0
# Running totals keyed by denominator d: sx_partials collects the
# numerators n, sxx_partials collects the squared numerators n * n.
sx_partials = defaultdict(int)
sxx_partials = defaultdict(int)
T = int
# Values are grouped by type() so that _coerce runs once per group
# (assuming groupby is itertools.groupby, a group is a run of
# consecutive same-typed values).
for typ, values in groupby(data, type):
T = _coerce(T, typ) # or raise TypeError
for n, d in map(_exact_ratio, values):
count += 1
sx_partials[d] += n
sxx_partials[d] += n * n
if not count:
total = Fraction(0)
elif None in sx_partials:
# The sum will be a NAN or INF. We can ignore all the finite
# partials, and just look at this special one.
# NOTE(review): a None denominator presumably comes from _exact_ratio
# for a non-finite value -- confirm against that helper.
total = sx_partials[None]
assert not _isfinite(total)
else:
# Rebuild the exact sums: each partial is n/d for the values and
# n/(d*d) for the squares.
sx = sum(Fraction(n, d) for d, n in sx_partials.items())
sxx = sum(Fraction(n, d*d) for d, n in sxx_partials.items())
# This formula has poor numeric properties for floats,
# but with fractions it is exact.
# Algebraically it is sxx - sx*sx/count, i.e. the sum of
# (x - mean)**2 with mean = sx/count.
total = (count * sxx - sx * sx) / count
return (T, total, count)


def _isfinite(x):
try:
return x.is_finite() # Likely a Decimal.
Expand Down Expand Up @@ -399,13 +436,9 @@ def mean(data):

If ``data`` is empty, StatisticsError will be raised.
"""
if iter(data) is data:
data = list(data)
n = len(data)
T, total, n = _sum(data)
if n < 1:
raise StatisticsError('mean requires at least one data point')
T, total, count = _sum(data)
assert count == n
return _convert(total / n, T)


Expand Down Expand Up @@ -776,41 +809,6 @@ def quantiles(data, *, n=4, method='exclusive'):

# See http://mathworld.wolfram.com/Variance.html
# http://mathworld.wolfram.com/SampleVariance.html
# http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
#
# Under no circumstances use the so-called "computational formula for
# variance", as that is only suitable for hand calculations with a small
# amount of low-precision data. It has terrible numeric properties.
#
# See a comparison of three computational methods here:
# http://www.johndcook.com/blog/2008/09/26/comparing-three-methods-of-computing-standard-deviation/

def _ss(data, c=None):
"""Return ``(T, total)`` where total is the sum of square deviations.

If ``c`` is None, the mean is calculated in one pass, and the deviations
from the mean are calculated in a second pass. Otherwise, deviations are
calculated from ``c`` as given. Use the second case with care, as it can
lead to garbage results.

When ``c`` is None, ``data`` is handed to ``_sum`` and then iterated
again here, so it presumably has to be a re-iterable sequence rather
than a one-shot iterator.
"""
if c is not None:
# Deviations about the caller-supplied centre; the count reported by
# _sum is not returned.
T, total, count = _sum((d := x - c) * d for x in data)
return (T, total)
T, total, count = _sum(data)
# The mean as an integer ratio mean_n / mean_d.
mean_n, mean_d = (total / count).as_integer_ratio()
partials = Counter()
for n, d in map(_exact_ratio, data):
# x - mean = n/d - mean_n/mean_d = diff_n / diff_d, by
# cross-multiplication.
diff_n = n * mean_d - d * mean_n
diff_d = d * mean_d
# Squared deviations are accumulated per squared denominator.
partials[diff_d * diff_d] += diff_n * diff_n
if None in partials:
# The sum will be a NAN or INF. We can ignore all the finite
# partials, and just look at this special one.
total = partials[None]
assert not _isfinite(total)
else:
total = sum(Fraction(n, d) for d, n in partials.items())
return (T, total)


def variance(data, xbar=None):
Expand Down Expand Up @@ -851,12 +849,9 @@ def variance(data, xbar=None):
Fraction(67, 108)

"""
if iter(data) is data:
data = list(data)
n = len(data)
T, ss, n = _ss(data, xbar)
if n < 2:
raise StatisticsError('variance requires at least two data points')
T, ss = _ss(data, xbar)
return _convert(ss / (n - 1), T)


Expand Down Expand Up @@ -895,12 +890,9 @@ def pvariance(data, mu=None):
Fraction(13, 72)

"""
if iter(data) is data:
data = list(data)
n = len(data)
T, ss, n = _ss(data, mu)
if n < 1:
raise StatisticsError('pvariance requires at least one data point')
T, ss = _ss(data, mu)
return _convert(ss / n, T)


Expand All @@ -913,12 +905,9 @@ def stdev(data, xbar=None):
1.0810874155219827

"""
if iter(data) is data:
data = list(data)
n = len(data)
T, ss, n = _ss(data, xbar)
if n < 2:
raise StatisticsError('stdev requires at least two data points')
T, ss = _ss(data, xbar)
mss = ss / (n - 1)
if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
Expand All @@ -934,12 +923,9 @@ def pstdev(data, mu=None):
0.986893273527251

"""
if iter(data) is data:
data = list(data)
n = len(data)
T, ss, n = _ss(data, mu)
if n < 1:
raise StatisticsError('pstdev requires at least one data point')
T, ss = _ss(data, mu)
mss = ss / n
if issubclass(T, Decimal):
return _decimal_sqrt_of_frac(mss.numerator, mss.denominator)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Optimized the mean, variance, and stdev functions in the statistics module.
If the input is an iterator, it is consumed in a single pass rather than
eating memory by conversion to a list. The single-pass algorithm is about
twice as fast as the previous two-pass code.