Skip to content

Commit 1c123db

Browse files
committed
MNT: Refactor default violin KDE estimator
Move the default KDE estimator from a private definition in `violinplot()` into `violin_stats()`. This makes `violin_stats()` easier to test and debug, because callers no longer have to provide a KDE method explicitly. It also simplifies the logic: `violinplot()` is now just `violin_stats()` followed by `violin()`.
1 parent dfc888c commit 1c123db

File tree

2 files changed

+33
-18
lines changed

2 files changed

+33
-18
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8878,18 +8878,8 @@ def violinplot(self, dataset, positions=None, vert=None,
88788878
.Axes.violin : Draw a violin from pre-computed statistics.
88798879
boxplot : Draw a box and whisker plot.
88808880
"""
8881-
8882-
def _kde_method(X, coords):
8883-
# Unpack in case of e.g. Pandas or xarray object
8884-
X = cbook._unpack_to_numpy(X)
8885-
# fallback gracefully if the vector contains only one value
8886-
if np.all(X[0] == X):
8887-
return (X[0] == coords).astype(float)
8888-
kde = mlab.GaussianKDE(X, bw_method)
8889-
return kde.evaluate(coords)
8890-
8891-
vpstats = cbook.violin_stats(dataset, _kde_method, points=points,
8892-
quantiles=quantiles)
8881+
vpstats = cbook.violin_stats(dataset, ("GaussianKDE", bw_method),
8882+
points=points, quantiles=quantiles)
88938883
return self.violin(vpstats, positions=positions, vert=vert,
88948884
orientation=orientation, widths=widths,
88958885
showmeans=showmeans, showextrema=showextrema,

lib/matplotlib/cbook.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from numpy import VisibleDeprecationWarning
3030

3131
import matplotlib
32-
from matplotlib import _api, _c_internal_utils
32+
from matplotlib import _api, _c_internal_utils, mlab
3333

3434

3535
class _ExceptionInfo:
@@ -1430,7 +1430,7 @@ def _reshape_2D(X, name):
14301430
return result
14311431

14321432

1433-
def violin_stats(X, method, points=100, quantiles=None):
1433+
def violin_stats(X, method=("GaussianKDE", "scott"), points=100, quantiles=None):
14341434
"""
14351435
Return a list of dictionaries of data which can be used to draw a series
14361436
of violin plots.
@@ -1449,11 +1449,23 @@ def violin_stats(X, method, points=100, quantiles=None):
14491449
Sample data that will be used to produce the gaussian kernel density
14501450
estimates. Must have 2 or fewer dimensions.
14511451
1452-
method : callable
1452+
method : (name, bw_method) or callable
14531453
The method used to calculate the kernel density estimate for each
1454-
column of data. When called via ``method(v, coords)``, it should
1455-
return a vector of the values of the KDE evaluated at the values
1456-
specified in coords.
1454+
column of data. Valid values:
1455+
1456+
- a tuple of the form ``(name, bw_method)`` where *name* currently must
1457+
always be ``"GaussianKDE"`` and *bw_method* is the method used to
1458+
calculate the estimator bandwidth. Supported values are 'scott',
1459+
'silverman', a float, or a callable. If a float, this will be used
1460+
directly as `!kde.factor`. If a callable, it should take a
1461+
`matplotlib.mlab.GaussianKDE` instance as its only parameter and
1462+
return a float.
1463+
1464+
- a callable with the signature ::
1465+
1466+
def method(data: ndarray, coords: ndarray) -> ndarray
1467+
1468+
It should return the KDE of *data* evaluated at *coords*.
14571469
14581470
points : int, default: 100
14591471
Defines the number of points to evaluate each of the gaussian kernel
@@ -1481,6 +1493,19 @@ def violin_stats(X, method, points=100, quantiles=None):
14811493
- max: The maximum value for this column of data.
14821494
- quantiles: The quantile values for this column of data.
14831495
"""
1496+
if isinstance(method, tuple):
1497+
name, bw_method = method
1498+
if name != "GaussianKDE":
1499+
raise ValueError(f"Unknown method {name!r} for violin_stats")
1500+
1501+
def _kde_method(x, coords):
1502+
# fallback gracefully if the vector contains only one value
1503+
if np.all(x[0] == x):
1504+
return (x[0] == coords).astype(float)
1505+
kde = mlab.GaussianKDE(x, bw_method)
1506+
return kde.evaluate(coords)
1507+
1508+
method = _kde_method
14841509

14851510
# List of dictionaries describing each of the violins.
14861511
vpstats = []

0 commit comments

Comments
 (0)