Skip to content

Commit d8ff7f0

Browse files
Add kde function and tests to RustPython statistics module
Function and tests are copied as-is from CPython 3.13. The following is the result of diffing the output of `python -I whats_left.py` with and without the kde function added:

```sh
❯ diff with-kde.txt without-kde.txt
1464a1465
> statistics.kde
1465a1467
> statistics.pi
```
1 parent 96f47a4 commit d8ff7f0

File tree

2 files changed

+386
-0
lines changed

2 files changed

+386
-0
lines changed

Lib/statistics.py

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@
136136
from itertools import groupby, repeat
137137
from bisect import bisect_left, bisect_right
138138
from math import hypot, sqrt, fabs, exp, erf, tau, log, fsum
139+
from math import pi, cos, sin, cosh, atan
139140
from operator import itemgetter
140141
from collections import Counter, namedtuple
141142

@@ -601,6 +602,218 @@ def multimode(data):
601602
return list(map(itemgetter(0), mode_items))
602603

603604

605+
def kde(data, h, kernel='normal', *, cumulative=False):
606+
"""Kernel Density Estimation: Create a continuous probability density
607+
function or cumulative distribution function from discrete samples.
608+
609+
The basic idea is to smooth the data using a kernel function
610+
to help draw inferences about a population from a sample.
611+
612+
The degree of smoothing is controlled by the scaling parameter h
613+
which is called the bandwidth. Smaller values emphasize local
614+
features while larger values give smoother results.
615+
616+
The kernel determines the relative weights of the sample data
617+
points. Generally, the choice of kernel shape does not matter
618+
as much as the more influential bandwidth smoothing parameter.
619+
620+
Kernels that give some weight to every sample point:
621+
622+
normal (gauss)
623+
logistic
624+
sigmoid
625+
626+
Kernels that only give weight to sample points within
627+
the bandwidth:
628+
629+
rectangular (uniform)
630+
triangular
631+
parabolic (epanechnikov)
632+
quartic (biweight)
633+
triweight
634+
cosine
635+
636+
If *cumulative* is true, will return a cumulative distribution function.
637+
638+
A StatisticsError will be raised if the data sequence is empty.
639+
640+
Example
641+
-------
642+
643+
Given a sample of six data points, construct a continuous
644+
function that estimates the underlying probability density:
645+
646+
>>> sample = [-2.1, -1.3, -0.4, 1.9, 5.1, 6.2]
647+
>>> f_hat = kde(sample, h=1.5)
648+
649+
Compute the area under the curve:
650+
651+
>>> area = sum(f_hat(x) for x in range(-20, 20))
652+
>>> round(area, 4)
653+
1.0
654+
655+
Plot the estimated probability density function at
656+
evenly spaced points from -6 to 10:
657+
658+
>>> for x in range(-6, 11):
659+
... density = f_hat(x)
660+
... plot = ' ' * int(density * 400) + 'x'
661+
... print(f'{x:2}: {density:.3f} {plot}')
662+
...
663+
-6: 0.002 x
664+
-5: 0.009 x
665+
-4: 0.031 x
666+
-3: 0.070 x
667+
-2: 0.111 x
668+
-1: 0.125 x
669+
0: 0.110 x
670+
1: 0.086 x
671+
2: 0.068 x
672+
3: 0.059 x
673+
4: 0.066 x
674+
5: 0.082 x
675+
6: 0.082 x
676+
7: 0.058 x
677+
8: 0.028 x
678+
9: 0.009 x
679+
10: 0.002 x
680+
681+
Estimate P(4.5 < X <= 7.5), the probability that a new sample value
682+
will be between 4.5 and 7.5:
683+
684+
>>> cdf = kde(sample, h=1.5, cumulative=True)
685+
>>> round(cdf(7.5) - cdf(4.5), 2)
686+
0.22
687+
688+
References
689+
----------
690+
691+
Kernel density estimation and its application:
692+
https://www.itm-conferences.org/articles/itmconf/pdf/2018/08/itmconf_sam2018_00037.pdf
693+
694+
Kernel functions in common use:
695+
https://en.wikipedia.org/wiki/Kernel_(statistics)#kernel_functions_in_common_use
696+
697+
Interactive graphical demonstration and exploration:
698+
https://demonstrations.wolfram.com/KernelDensityEstimation/
699+
700+
Kernel estimation of cumulative distribution function of a random variable with bounded support
701+
https://www.econstor.eu/bitstream/10419/207829/1/10.21307_stattrans-2016-037.pdf
702+
703+
"""
704+
705+
n = len(data)
706+
if not n:
707+
raise StatisticsError('Empty data sequence')
708+
709+
if not isinstance(data[0], (int, float)):
710+
raise TypeError('Data sequence must contain ints or floats')
711+
712+
if h <= 0.0:
713+
raise StatisticsError(f'Bandwidth h must be positive, not {h=!r}')
714+
715+
match kernel:
716+
717+
case 'normal' | 'gauss':
718+
sqrt2pi = sqrt(2 * pi)
719+
sqrt2 = sqrt(2)
720+
K = lambda t: exp(-1/2 * t * t) / sqrt2pi
721+
W = lambda t: 1/2 * (1.0 + erf(t / sqrt2))
722+
support = None
723+
724+
case 'logistic':
725+
# 1.0 / (exp(t) + 2.0 + exp(-t))
726+
K = lambda t: 1/2 / (1.0 + cosh(t))
727+
W = lambda t: 1.0 - 1.0 / (exp(t) + 1.0)
728+
support = None
729+
730+
case 'sigmoid':
731+
# (2/pi) / (exp(t) + exp(-t))
732+
c1 = 1 / pi
733+
c2 = 2 / pi
734+
K = lambda t: c1 / cosh(t)
735+
W = lambda t: c2 * atan(exp(t))
736+
support = None
737+
738+
case 'rectangular' | 'uniform':
739+
K = lambda t: 1/2
740+
W = lambda t: 1/2 * t + 1/2
741+
support = 1.0
742+
743+
case 'triangular':
744+
K = lambda t: 1.0 - abs(t)
745+
W = lambda t: t*t * (1/2 if t < 0.0 else -1/2) + t + 1/2
746+
support = 1.0
747+
748+
case 'parabolic' | 'epanechnikov':
749+
K = lambda t: 3/4 * (1.0 - t * t)
750+
W = lambda t: -1/4 * t**3 + 3/4 * t + 1/2
751+
support = 1.0
752+
753+
case 'quartic' | 'biweight':
754+
K = lambda t: 15/16 * (1.0 - t * t) ** 2
755+
W = lambda t: 3/16 * t**5 - 5/8 * t**3 + 15/16 * t + 1/2
756+
support = 1.0
757+
758+
case 'triweight':
759+
K = lambda t: 35/32 * (1.0 - t * t) ** 3
760+
W = lambda t: 35/32 * (-1/7*t**7 + 3/5*t**5 - t**3 + t) + 1/2
761+
support = 1.0
762+
763+
case 'cosine':
764+
c1 = pi / 4
765+
c2 = pi / 2
766+
K = lambda t: c1 * cos(c2 * t)
767+
W = lambda t: 1/2 * sin(c2 * t) + 1/2
768+
support = 1.0
769+
770+
case _:
771+
raise StatisticsError(f'Unknown kernel name: {kernel!r}')
772+
773+
if support is None:
774+
775+
def pdf(x):
776+
n = len(data)
777+
return sum(K((x - x_i) / h) for x_i in data) / (n * h)
778+
779+
def cdf(x):
780+
n = len(data)
781+
return sum(W((x - x_i) / h) for x_i in data) / n
782+
783+
else:
784+
785+
sample = sorted(data)
786+
bandwidth = h * support
787+
788+
def pdf(x):
789+
nonlocal n, sample
790+
if len(data) != n:
791+
sample = sorted(data)
792+
n = len(data)
793+
i = bisect_left(sample, x - bandwidth)
794+
j = bisect_right(sample, x + bandwidth)
795+
supported = sample[i : j]
796+
return sum(K((x - x_i) / h) for x_i in supported) / (n * h)
797+
798+
def cdf(x):
799+
nonlocal n, sample
800+
if len(data) != n:
801+
sample = sorted(data)
802+
n = len(data)
803+
i = bisect_left(sample, x - bandwidth)
804+
j = bisect_right(sample, x + bandwidth)
805+
supported = sample[i : j]
806+
return sum((W((x - x_i) / h) for x_i in supported), i) / n
807+
808+
if cumulative:
809+
cdf.__doc__ = f'CDF estimate with {h=!r} and {kernel=!r}'
810+
return cdf
811+
812+
else:
813+
pdf.__doc__ = f'PDF estimate with {h=!r} and {kernel=!r}'
814+
return pdf
815+
816+
604817
# Notes on methods for computing quantiles
605818
# ----------------------------------------
606819
#

0 commit comments

Comments
 (0)