statistics.covariance, statistics.correlation and statistics.linear_regression could accept decimal.Decimal inputs. #95130

@ghost

Description

Bug report

From the statistics documentation:

Unless explicitly noted, these functions support int, float, Decimal and Fraction.

However, this does not hold for the functions added in 3.10 under 'Statistics for relations between two inputs': covariance, correlation, and linear_regression.

from decimal import Decimal
import statistics

x = [Decimal(number) for number in [1, 2, 3, 4, 5, 6, 7, 8, 9]]
y = [Decimal(number) for number in [1, 2, 3, 1, 2, 3, 1, 2, 3]]

statistics.covariance(x, y)
statistics.correlation(x, y)
statistics.linear_regression(x, y)

All three raise TypeError: unsupported operand type(s) for -: 'decimal.Decimal' and 'float' because some intermediate values are computed as float, which cannot be combined with Decimal in arithmetic operations.

(int and Fraction can be combined with float in arithmetic operations, so these cases do not raise a TypeError, but the result is always float, which may or may not be desirable.)
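The difference comes down to how each numeric type mixes with float. A minimal sketch that checks this directly, independent of the statistics module (the helper name type_after_float_subtraction is made up for illustration):

```python
from decimal import Decimal
from fractions import Fraction


def type_after_float_subtraction(value):
    """Hypothetical helper: return the type of `value - 0.5`, or None on TypeError."""
    try:
        return type(value - 0.5)
    except TypeError:
        # Decimal refuses mixed arithmetic with float.
        return None


for sample in (3, Fraction(1, 2), Decimal(3)):
    print(type(sample).__name__, "->", type_after_float_subtraction(sample))
```

int and Fraction silently become float, while Decimal fails, which matches the behaviour described above.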

Expected fix
I see three ways to address this:

  • change documentation to explicitly make correlation et al. float-only.
  • convert all inputs to float and always return float, like geometric_mean.
  • rewrite to return the same data type as inputs, like mean.

The first seems to go against the design rationale of the statistics module, but apparently no one had a problem with float-only until now, so this may be fine. Otherwise, the second would probably be the easiest to implement.

For the last option, the original implementation (#16813) almost worked for Decimal. By changing covariance to use statistics._sum instead of math.fsum and converting the result, all three would return Decimal for Decimal inputs (correlation returns float for Fractions because there's not really a sensible way to handle Fraction square roots):

def covariance(x, y, /):
    """ doc... """
    n = len(x)
    if len(y) != n:
        raise StatisticsError('covariance requires that both inputs have same number of data points')
    if n < 2:
        raise StatisticsError('covariance requires at least two data points')
    xbar = mean(x)
    ybar = mean(y)
    T, total, _ = _sum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    return _convert(total / (n - 1), T)

However, some optimizations (#26135) have since been made based on the float-only assumption, and these would need to be reevaluated.
Applying the same fsum <-> _sum replacement to the optimized versions leads to a ~10x slowdown on my machine, which leaves them only ~3x faster than the original implementation. There is probably a faster way to do it, and I haven't investigated potential accuracy differences; this just shows that fully handling Decimal will have a significant but perhaps acceptable performance cost.
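To get a rough feel for that cost on another machine, a comparison along these lines could be used. This is a sketch under assumptions: cov_fsum and cov_sum are simplified stand-ins written for this example, not the optimized code from #26135, and _sum and _convert are private helpers of the statistics module that may change between Python versions.

```python
from decimal import Decimal
from math import fsum
from statistics import mean, _sum, _convert  # private helpers, may change
from timeit import timeit


def cov_fsum(x, y):
    """Float-only stand-in: accumulates with math.fsum."""
    n = len(x)
    xbar = fsum(x) / n
    ybar = fsum(y) / n
    return fsum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y)) / (n - 1)


def cov_sum(x, y):
    """Type-preserving stand-in: accumulates with statistics._sum."""
    n = len(x)
    xbar = mean(x)
    ybar = mean(y)
    T, total, _ = _sum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    return _convert(total / (n - 1), T)


def compare(size=1_000, repeats=20):
    """Return (seconds for cov_fsum, seconds for cov_sum) on the same float data."""
    xs = [float(i) for i in range(size)]
    ys = [float(i % 3) for i in range(size)]
    t_fsum = timeit(lambda: cov_fsum(xs, ys), number=repeats)
    t_sum = timeit(lambda: cov_sum(xs, ys), number=repeats)
    return t_fsum, t_sum


if __name__ == "__main__":
    t_fsum, t_sum = compare()
    print(f"fsum: {t_fsum:.4f}s  _sum: {t_sum:.4f}s  ratio: {t_sum / t_fsum:.1f}x")
```

The ratio will vary with data size, Python version, and hardware, so it should not be read as confirming or refuting the figures quoted above.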

from collections import namedtuple
from statistics import StatisticsError, _sum, _coerce, _convert
from math import sqrt

def _sqrt(x, /):
    """ special-case square root for Decimal. Note Fraction and int are still converted to float. """
    try:
        return x.sqrt()
    except AttributeError:
        return sqrt(x)

def covariance(x, y, /):
    """ doc... """
    xT, xtotal, xn = _sum(x)
    yT, ytotal, yn = _sum(y)
    if xn != yn:
        raise StatisticsError('covariance requires that both inputs have same number of data points')
    if xn < 2:
        raise StatisticsError('covariance requires at least two data points')
    T = _coerce(xT, yT)
    xbar = _convert(xtotal / xn, T)
    ybar = _convert(ytotal / xn, T)
    _, sxy, _ = _sum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    
    return _convert(sxy / (xn - 1), T)


def correlation(x, y, /):
    """ doc... """
    xT, xtotal, xn = _sum(x)
    yT, ytotal, yn = _sum(y)
    if xn != yn:
        raise StatisticsError('correlation requires that both inputs have same number of data points')
    if xn < 2:
        raise StatisticsError('correlation requires at least two data points')
    T = _coerce(xT, yT)
    xbar = _convert(xtotal / xn, T)
    ybar = _convert(ytotal / xn, T)
    _, sxy, _ = _sum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
    _, sxx, _ = _sum((d := xi - xbar) * d for xi in x)
    _, syy, _ = _sum((d := yi - ybar) * d for yi in y)
    try:
        return _convert(sxy / _sqrt(sxx * syy), T)
    except ZeroDivisionError:
        raise StatisticsError('at least one of the inputs is constant')


LinearRegression = namedtuple('LinearRegression', ('slope', 'intercept'))


def linear_regression(x, y, /, *, proportional=False):
    """ doc... """
    xT, xtotal, xn = _sum(x)
    yT, ytotal, yn = _sum(y)
    if xn != yn:
        raise StatisticsError('linear regression requires that both inputs have same number of data points')
    if xn < 2:
        raise StatisticsError('linear regression requires at least two data points')
    T = _coerce(xT, yT)

    if proportional:
        _, sxy, _ = _sum(xi * yi for xi, yi in zip(x, y))
        _, sxx, _ = _sum(xi * xi for xi in x)
    else:
        xbar = _convert(xtotal / xn, T)
        ybar = _convert(ytotal / xn, T)
        _, sxy, _ = _sum((xi - xbar) * (yi - ybar) for xi, yi in zip(x, y))
        _, sxx, _ = _sum((d := xi - xbar) * d for xi in x)
    try:
        slope = _convert(sxy / sxx, T)   # equivalent to:  covariance(x, y) / variance(x)
    except ZeroDivisionError:
        raise StatisticsError('x is constant')
    intercept = _convert(0.0 if proportional else ybar - slope * xbar, T)
    return LinearRegression(slope=slope, intercept=intercept)

Bonus Enhancement
If rewriting, maybe make covariance, correlation, and linear_regression accept iterators too, as all other statistics functions do even when no single-pass algorithm is available, e.g. harmonic_mean:

def harmonic_mean(data, weights=None):
    if iter(data) is data:
        data = list(data)
    ...

Metadata

Labels

docs (Documentation in the Doc dir), stdlib (Python modules in the Lib dir), type-feature (A feature request or enhancement)
