Skip to content

Response to Feature Request: draw percentiles in violinplot #8532 #8585

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 77 additions & 5 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@ def _plot_args_replacer(args, data):
"multiple plotting calls instead.")


class ViolinStatFunc:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this should be private

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarification: maybe I should not put this class in this file. I meant for it to be public so that user can easily enter input. They have more than just the callable: e.g. percentile requires 1) an array of data and 2) and int. While the array data is drawn from main input , we need to put the integer argument in this object. Furthermore, each function drawn on the violin has to have a corresponding artist object returned by the function _Axes.violin() so we need an "alias" there as well. While each function input alias should be unique, I tried to be defensive by dealing with duplicating alias as discussed below. In short, this is a wrapper object that specifies everything that a callable to be drawn on the violin should know (apart from the array data input, which is drawn elsewhere). Happy to discuss/modify more if required.

"""
The :class:`ViolinStatFunc` contains:
1) a callable whose first argument is compulsory and is a 1-d list of data
that is used to plot the violin. This first argument is not required to be
specified.
2) an alias for this callable. When violinplot outputs the dictionary of
artists, this alias is used to identify the artist object corresponding to
this callable
3) a list of additional arguments. This list does not contain the
aforementioned compulsory 1-d list of data.
"""
def __init__(self, func_callable, **kargs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

using "kwargs" instead of "kargs" would be more consistent with the rest of the MPL code base

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted

self.func_callable = func_callable
self.alias = kargs.pop('alias', func_callable.__name__)
self.optional_args = kargs.pop('args', [])
if not isinstance(self.optional_args, list):
Copy link
Member

@phobson phobson May 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any iterable should be fine. Or, at the very least, tuples should be valid as well.

So something like:

if np.isscalar(self.optional_args):
    self.optional_args = [self.optional_args]

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noted.

raise ValueError('args has to be a list')


# The axes module contains all the wrappers to plotting functions.
# All the other methods should go in the _AxesBase class.

Expand Down Expand Up @@ -7277,7 +7297,7 @@ def matshow(self, Z, **kwargs):
@_preprocess_data(replace_names=["dataset"], label_namer=None)
def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
showmeans=False, showextrema=True, showmedians=False,
points=100, bw_method=None):
points=100, bw_method=None, statistics_function_list=[]):
"""
Make a violin plot.

Expand Down Expand Up @@ -7324,6 +7344,13 @@ def violinplot(self, dataset, positions=None, vert=True, widths=0.5,
callable, it should take a `GaussianKDE` instance as its only
parameter and return a scalar. If None (default), 'scott' is used.

statistics_function_list: a list of callable or ViolinStatFunc. The
element of this list can be any custom summary statistics to be
displayed on the voilin plot (with one constraint that the first
argument of these function has to be the input data of violin plot
i.e. dataset if dataset is 1-d or an element of dataset if dataset is
2-d)

Returns
-------

Expand Down Expand Up @@ -7369,12 +7396,44 @@ def _kde_method(X, coords):
kde = mlab.GaussianKDE(X, bw_method)
return kde.evaluate(coords)

vpstats = cbook.violin_stats(dataset, _kde_method, points=points)
return self.violin(vpstats, positions=positions, vert=vert,
def _resolve_duplicate_alias(func_obj_list):
unique_alias_set = set()
for func_obj in func_obj_list:
while func_obj.alias in unique_alias_set:
func_obj.alias += 'x'
unique_alias_set.add(func_obj.alias)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this need to return unique_alias_set?

Copy link
Author

@ghost ghost May 11, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No. unique_alias_set keeps track of all alias so far. The actual alias (modified so that they are unique) is in the ViolinStatFunc object itself. I think this defensive programming mechanism helps safe-guarding against user inputs where multiple input functions have repeating aliases. These aliases later on become the keys in the dict of artist object returned by violin function. While never experimented with it myself, I assume these artist object allows user to modify colors, shape and other visual effect of the line drawn on the violin. Thoughts?

Copy link
Member

@phobson phobson May 12, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unique_alias_set isn't used or defined anywhere else, so I'm not sure why we even need the aliases. Why can't the user just pass a list of functions that accept an array as input and returns an array or scalar, and draw lines at each value? For instance, the user would provide:

from functools import partial
stat_fxn = [partial(np.percentile, q=[5, 25, 50, 75, 95])]
# OR
stat_fxns = [np.mean, np.median, partial(np.percentile, q=95)]

## Then somewhere in the new code, we'd do something like:
results = np.flatten([stat(data) for stat in stat_fxns])  # not sure that flatten is the right function

should work.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general, I guess I'm saying I'd like to see an API that let's the users work with familiar functions, rather than having to the learn the nuances of a new class

Copy link
Author

@ghost ghost May 13, 2017

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah. I see where you are coming from. Will remove this new class. I'm only worried about the very final output i.e. the artist dictionary output of _Axes.violin(). Where I'm coming from is that the _Axes.violin() outputs a dictionary of artist, and we need unique dictionary keys to represent the artist corresponding to the custom statistics function.
These artist objects allow user to change color, shape of the line etc... and the dictionary requires a unique key i.e. my current implemented alias. We have three ways of going around this:

  1. (non-user friendly) make user input a list of dictionary where each dictionary has a callable and an alias and constraint that all aliases in the list of dict are unique
  2. (slightly better) make user input the same list of dictionary but alias doesn't need to be unique, we modify the duplicating alias internally (as I did above with the resolve_duplicate_aliases and let them know we modify some aliases.
  3. (probably fits where you are going with) don't ask for alias at all, and we make our own alias internally in the order that user input, but we still have to let them know the aliases.
    Let me know what you think. I'm open for suggestions.


violin_stat_func_obj_list = []
for func in statistics_function_list:
if not isinstance(func, ViolinStatFunc):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if _ViolinStatFunc is private, then we can only deal with callables and simplify this logic.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed above.

if callable(func):
violin_stat_func_obj_list.append(ViolinStatFunc(func))
else:
raise ValueError(
'Optional argument has to be a callable' +
'or a ViolinStatFunc object')
else:
violin_stat_func_obj_list.append(func)

custom_stat_alias_list = [func_obj.alias for func_obj in
violin_stat_func_obj_list]
if len(custom_stat_alias_list) > len(set(custom_stat_alias_list)):
_resolve_duplicate_alias(violin_stat_func_obj_list)
# remake alias list based on updated unique aliases
custom_stat_alias_list = [func_obj.alias for func_obj in
violin_stat_func_obj_list]

vpstats, custom_stat_vals = \
cbook.violin_stats(dataset, _kde_method,
violin_stat_func_obj_list, points=points)

return self.violin(vpstats, custom_stat_vals, custom_stat_alias_list,
positions=positions, vert=vert,
widths=widths, showmeans=showmeans,
showextrema=showextrema, showmedians=showmedians)

def violin(self, vpstats, positions=None, vert=True, widths=0.5,
def violin(self, vpstats, custom_stat_vals, custom_stat_alias_list,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are the custom stat-related values required parameters now? They should be be optional, IMO

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, they should be optional

positions=None, vert=True, widths=0.5,
showmeans=False, showextrema=True, showmedians=False):
"""Drawing function for violin plots.

Expand Down Expand Up @@ -7511,7 +7570,9 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,

# Render violins
bodies = []
for stats, pos, width in zip(vpstats, positions, widths):
custom_vals = {}
for stats, pos, width, stat_val_dict in zip(vpstats, positions,
widths, custom_stat_vals):
# The 0.5 factor reflects the fact that we plot from v-p to
# v+p
vals = np.array(stats['vals'])
Expand All @@ -7525,6 +7586,11 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,
mins.append(stats['min'])
maxes.append(stats['max'])
medians.append(stats['median'])
for alias in custom_stat_alias_list:
if alias not in custom_vals:
custom_vals[alias] = []
custom_vals[alias].append(stat_val_dict[alias])

artists['bodies'] = bodies

# Render means
Expand All @@ -7547,6 +7613,12 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,
pmins,
pmaxes,
colors=edgecolor)
# Render custom statistics
for alias in custom_stat_alias_list:
artists['custom_' + alias] = perp_lines(custom_vals[alias],
pmins,
pmaxes,
colors=edgecolor)

return artists

Expand Down
46 changes: 28 additions & 18 deletions lib/matplotlib/cbook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2001,13 +2001,13 @@ def _reshape_2D(X, name):
raise ValueError("{} must have 2 or fewer dimensions".format(name))


def violin_stats(X, method, points=100):
def violin_stats(X, method, custom_stat_func_obj, points=100):
"""
Returns a list of dictionaries of data which can be used to draw a series
of violin plots. See the `Returns` section below to view the required keys
of the dictionary. Users can skip this function and pass a user-defined set
of dictionaries to the `axes.vplot` method instead of using MPL to do the
calculations.
of the dictionary. Users can skip this function and pass a user-defined
set of dictionaries to the `axes.vplot` method instead of using MPL to do
the calculations.

Parameters
----------
Expand All @@ -2021,35 +2021,41 @@ def violin_stats(X, method, points=100):
return a vector of the values of the KDE evaluated at the values
specified in coords.

custom_stat_func_obj : a list of ViolinStatFunc object, each containing a
custom statistics to be drawn on the violin plot

points : scalar, default = 100
Defines the number of points to evaluate each of the gaussian kernel
density estimates at.
Defines the number of points to evaluate each of the gaussian kernel
density estimates at.

Returns
-------

A list of dictionaries containing the results for each column of data.
The dictionaries contain at least the following:
Two lists of dictionaries containing the results for each column of data.
The first list of dictionaries contain at least the following:

- coords: A list of scalars containing the coordinates this particular
kernel density estimate was evaluated at.
- vals: A list of scalars containing the values of the kernel density
estimate at each of the coordinates given in `coords`.
- mean: The mean value for this column of data.
- median: The median value for this column of data.
- min: The minimum value for this column of data.
- max: The maximum value for this column of data.
- coords: A list of scalars containing the coordinates this particular
kernel density estimate was evaluated at.
- vals: A list of scalars containing the values of the kernel density
estimate at each of the coordinates given in `coords`.
- mean: The mean value for this column of data.
- median: The median value for this column of data.
- min: The minimum value for this column of data.
- max: The maximum value for this column of data.
The second list of dictionaries contains the results for each column of
data computed from the custom each of the statistics function
"""

# List of dictionaries describing each of the violins.
vpstats = []
custom_stat_vals = []

# Want X to be a list of data sequences
X = _reshape_2D(X, "X")

for x in X:
# Dictionary of results for this distribution
stats = {}
stats2 = {}

# Calculate basic stats for the distribution
min_val = np.min(x)
Expand All @@ -2065,11 +2071,15 @@ def violin_stats(X, method, points=100):
stats['median'] = np.median(x)
stats['min'] = min_val
stats['max'] = max_val
for func_obj in custom_stat_func_obj:
stats2[func_obj.alias] = \
func_obj.func_callable(x, *func_obj.optional_args)

# Append to output
vpstats.append(stats)
custom_stat_vals.append(stats2)

return vpstats
return vpstats, custom_stat_vals


class _NestedClassGetter(object):
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
82 changes: 78 additions & 4 deletions lib/matplotlib/tests/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2329,7 +2329,8 @@ def test_vert_violinplot_baseline():
showmedians=0, data=data)


@image_comparison(baseline_images=['violinplot_vert_showmeans'],
@image_comparison(baseline_images=['violinplot_vert_showmeans',
'violinplot_vert_showmeans'],
extensions=['png'])
def test_vert_violinplot_showmeans():
ax = plt.axes()
Expand All @@ -2338,6 +2339,11 @@ def test_vert_violinplot_showmeans():
data = [np.random.normal(size=100) for i in range(4)]
ax.violinplot(data, positions=range(4), showmeans=1, showextrema=0,
showmedians=0)
fig, ax = plt.subplots()
ax = plt.axes()
ax.violinplot(data, statistics_function_list=[np.mean],
positions=range(4), showmeans=0, showextrema=0,
showmedians=0)


@image_comparison(baseline_images=['violinplot_vert_showextrema'],
Expand All @@ -2351,7 +2357,8 @@ def test_vert_violinplot_showextrema():
showmedians=0)


@image_comparison(baseline_images=['violinplot_vert_showmedians'],
@image_comparison(baseline_images=['violinplot_vert_showmedians',
'violinplot_vert_showmedians'],
extensions=['png'])
def test_vert_violinplot_showmedians():
ax = plt.axes()
Expand All @@ -2360,6 +2367,11 @@ def test_vert_violinplot_showmedians():
data = [np.random.normal(size=100) for i in range(4)]
ax.violinplot(data, positions=range(4), showmeans=0, showextrema=0,
showmedians=1)
fig, ax = plt.subplots()
ax = plt.axes()
ax.violinplot(data, statistics_function_list=[np.median],
positions=range(4), showmeans=0, showextrema=0,
showmedians=0)


@image_comparison(baseline_images=['violinplot_vert_showall'],
Expand Down Expand Up @@ -2395,6 +2407,31 @@ def test_vert_violinplot_custompoints_200():
showmedians=0, points=200)


@image_comparison(baseline_images=['violinplot_vert_showstdev',
'violinplot_vert_show80_20percentiles'],
extensions=['png'])
def test_vert_violinplot_showcustomstat():
ax = plt.axes()
# First 9 digits of frac(sqrt(31))
np.random.seed(567764362)
data = [np.random.normal(size=100) for i in range(4)]
func_list = [lambda x: np.mean(x) + np.std(x),
lambda x: np.mean(x) - np.std(x)]
ax.violinplot(data, statistics_function_list=func_list,
positions=range(4), showmeans=0,
showextrema=0, showmedians=0)
fig, ax = plt.subplots()
ax = plt.axes()
from matplotlib.axes._axes import ViolinStatFunc
percentile95 = ViolinStatFunc(np.percentile,
alias='95 percentile', args=[80])
percentile5 = ViolinStatFunc(np.percentile,
alias='5 percentile', args=[20])
ax.violinplot(data, statistics_function_list=[percentile95, percentile5],
positions=range(4), showmeans=0,
showextrema=0, showmedians=0)


@image_comparison(baseline_images=['violinplot_horiz_baseline'],
extensions=['png'])
def test_horiz_violinplot_baseline():
Expand All @@ -2406,7 +2443,8 @@ def test_horiz_violinplot_baseline():
showextrema=0, showmedians=0)


@image_comparison(baseline_images=['violinplot_horiz_showmedians'],
@image_comparison(baseline_images=['violinplot_horiz_showmedians',
'violinplot_horiz_showmedians'],
extensions=['png'])
def test_horiz_violinplot_showmedians():
ax = plt.axes()
Expand All @@ -2415,9 +2453,15 @@ def test_horiz_violinplot_showmedians():
data = [np.random.normal(size=100) for i in range(4)]
ax.violinplot(data, positions=range(4), vert=False, showmeans=0,
showextrema=0, showmedians=1)
fig, ax = plt.subplots()
ax = plt.axes()
ax.violinplot(data, statistics_function_list=[np.median],
positions=range(4), vert=False,
showmeans=0, showextrema=0, showmedians=0)


@image_comparison(baseline_images=['violinplot_horiz_showmeans'],
@image_comparison(baseline_images=['violinplot_horiz_showmeans',
'violinplot_horiz_showmeans'],
extensions=['png'])
def test_horiz_violinplot_showmeans():
ax = plt.axes()
Expand All @@ -2426,6 +2470,11 @@ def test_horiz_violinplot_showmeans():
data = [np.random.normal(size=100) for i in range(4)]
ax.violinplot(data, positions=range(4), vert=False, showmeans=1,
showextrema=0, showmedians=0)
fig, ax = plt.subplots()
ax = plt.axes()
ax.violinplot(data, statistics_function_list=[np.mean],
positions=range(4), vert=False,
showmeans=0, showmedians=0, showextrema=0)


@image_comparison(baseline_images=['violinplot_horiz_showextrema'],
Expand Down Expand Up @@ -2472,6 +2521,31 @@ def test_horiz_violinplot_custompoints_200():
showextrema=0, showmedians=0, points=200)


@image_comparison(baseline_images=['violinplot_horiz_showstdev',
'violinplot_horiz_show95_5percentiles'],
extensions=['png'])
def test_horiz_violinplot_showcustomstat():
ax = plt.axes()
# First 9 digits of frac(sqrt(31))
np.random.seed(567764362)
data = [np.random.normal(size=100) for i in range(4)]
func_list = [lambda x: np.mean(x) + np.std(x),
lambda x: np.mean(x) - np.std(x)]
ax.violinplot(data, statistics_function_list=func_list,
positions=range(4), vert=False, showmeans=0,
showextrema=0, showmedians=0)
fig, ax = plt.subplots()
ax = plt.axes()
from matplotlib.axes._axes import ViolinStatFunc
percentile95 = ViolinStatFunc(np.percentile,
alias='95 percentile', args=[95])
percentile5 = ViolinStatFunc(np.percentile,
alias='5 percentile', args=[5])
ax.violinplot(data, statistics_function_list=[percentile95, percentile5],
positions=range(4), vert=False, showmeans=0,
showextrema=0, showmedians=0)


def test_violinplot_bad_positions():
ax = plt.axes()
# First 9 digits of frac(sqrt(47))
Expand Down