Skip to content

Commit 19aac39

Browse files
authored
Merge pull request #19467 from anntzer/v
Shorten the implementation of violin().
2 parents 81f0624 + f450425 commit 19aac39

File tree

1 file changed

+36
-64
lines changed

1 file changed

+36
-64
lines changed

lib/matplotlib/axes/_axes.py

+36-64
Original file line numberDiff line numberDiff line change
@@ -8106,18 +8106,18 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,
81068106
- ``cquantiles``: A `~.collections.LineCollection` instance created
81078107
to identify the quantiles values of each of the violin's
81088108
distribution.
8109-
81108109
"""
81118110

81128111
# Statistical quantities to be plotted on the violins
81138112
means = []
81148113
mins = []
81158114
maxes = []
81168115
medians = []
8117-
quantiles = np.asarray([])
8116+
quantiles = []
8117+
8118+
qlens = [] # Number of quantiles in each dataset.
81188119

8119-
# Collections to be returned
8120-
artists = {}
8120+
artists = {} # Collections to be returned
81218121

81228122
N = len(vpstats)
81238123
datashape_message = ("List of violinplot statistics and `{0}` "
@@ -8135,84 +8135,56 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,
81358135
elif len(widths) != N:
81368136
raise ValueError(datashape_message.format("widths"))
81378137

8138-
# Calculate ranges for statistics lines
8139-
pmins = -0.25 * np.array(widths) + positions
8140-
pmaxes = 0.25 * np.array(widths) + positions
8138+
# Calculate ranges for statistics lines (shape (2, N)).
8139+
line_ends = [[-0.25], [0.25]] * np.array(widths) + positions
8140+
8141+
# Colors.
8142+
if rcParams['_internal.classic_mode']:
8143+
fillcolor = 'y'
8144+
linecolor = 'r'
8145+
else:
8146+
fillcolor = linecolor = self._get_lines.get_next_color()
81418147

81428148
# Check whether we are rendering vertically or horizontally
81438149
if vert:
81448150
fill = self.fill_betweenx
8145-
perp_lines = self.hlines
8146-
par_lines = self.vlines
8151+
perp_lines = functools.partial(self.hlines, colors=linecolor)
8152+
par_lines = functools.partial(self.vlines, colors=linecolor)
81478153
else:
81488154
fill = self.fill_between
8149-
perp_lines = self.vlines
8150-
par_lines = self.hlines
8151-
8152-
if rcParams['_internal.classic_mode']:
8153-
fillcolor = 'y'
8154-
edgecolor = 'r'
8155-
else:
8156-
fillcolor = edgecolor = self._get_lines.get_next_color()
8155+
perp_lines = functools.partial(self.vlines, colors=linecolor)
8156+
par_lines = functools.partial(self.hlines, colors=linecolor)
81578157

81588158
# Render violins
81598159
bodies = []
81608160
for stats, pos, width in zip(vpstats, positions, widths):
8161-
# The 0.5 factor reflects the fact that we plot from v-p to
8162-
# v+p
8161+
# The 0.5 factor reflects the fact that we plot from v-p to v+p.
81638162
vals = np.array(stats['vals'])
81648163
vals = 0.5 * width * vals / vals.max()
8165-
bodies += [fill(stats['coords'],
8166-
-vals + pos,
8167-
vals + pos,
8168-
facecolor=fillcolor,
8169-
alpha=0.3)]
8164+
bodies += [fill(stats['coords'], -vals + pos, vals + pos,
8165+
facecolor=fillcolor, alpha=0.3)]
81708166
means.append(stats['mean'])
81718167
mins.append(stats['min'])
81728168
maxes.append(stats['max'])
81738169
medians.append(stats['median'])
8174-
q = stats.get('quantiles')
8175-
if q is not None:
8176-
# If exist key quantiles, assume it's a list of floats
8177-
quantiles = np.concatenate((quantiles, q))
8170+
q = stats.get('quantiles') # a list of floats, or None
8171+
if q is None:
8172+
q = []
8173+
quantiles.extend(q)
8174+
qlens.append(len(q))
81788175
artists['bodies'] = bodies
81798176

8180-
# Render means
8181-
if showmeans:
8182-
artists['cmeans'] = perp_lines(means, pmins, pmaxes,
8183-
colors=edgecolor)
8184-
8185-
# Render extrema
8186-
if showextrema:
8187-
artists['cmaxes'] = perp_lines(maxes, pmins, pmaxes,
8188-
colors=edgecolor)
8189-
artists['cmins'] = perp_lines(mins, pmins, pmaxes,
8190-
colors=edgecolor)
8191-
artists['cbars'] = par_lines(positions, mins, maxes,
8192-
colors=edgecolor)
8193-
8194-
# Render medians
8195-
if showmedians:
8196-
artists['cmedians'] = perp_lines(medians,
8197-
pmins,
8198-
pmaxes,
8199-
colors=edgecolor)
8200-
8201-
# Render quantile values
8202-
if quantiles.size > 0:
8203-
# Recalculate ranges for statistics lines for quantiles.
8204-
# ppmins are the left end of quantiles lines
8205-
ppmins = np.asarray([])
8206-
# pmaxes are the right end of quantiles lines
8207-
ppmaxs = np.asarray([])
8208-
for stats, cmin, cmax in zip(vpstats, pmins, pmaxes):
8209-
q = stats.get('quantiles')
8210-
if q is not None:
8211-
ppmins = np.concatenate((ppmins, [cmin] * np.size(q)))
8212-
ppmaxs = np.concatenate((ppmaxs, [cmax] * np.size(q)))
8213-
# Start rendering
8214-
artists['cquantiles'] = perp_lines(quantiles, ppmins, ppmaxs,
8215-
colors=edgecolor)
8177+
if showmeans: # Render means
8178+
artists['cmeans'] = perp_lines(means, *line_ends)
8179+
if showextrema: # Render extrema
8180+
artists['cmaxes'] = perp_lines(maxes, *line_ends)
8181+
artists['cmins'] = perp_lines(mins, *line_ends)
8182+
artists['cbars'] = par_lines(positions, mins, maxes)
8183+
if showmedians: # Render medians
8184+
artists['cmedians'] = perp_lines(medians, *line_ends)
8185+
if quantiles: # Render quantiles: each width is repeated qlen times.
8186+
artists['cquantiles'] = perp_lines(
8187+
quantiles, *np.repeat(line_ends, qlens, axis=1))
82168188

82178189
return artists
82188190

0 commit comments

Comments
 (0)