Skip to content

Commit 843db97

Browse files
committed
Shorten the implementation of violin().
Just some standard rewriting.
1 parent 4da4cf9 commit 843db97

File tree

1 file changed

+36
-64
lines changed

1 file changed

+36
-64
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 36 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -8100,18 +8100,18 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,
81008100
- ``cquantiles``: A `~.collections.LineCollection` instance created
81018101
to identify the quantiles values of each of the violin's
81028102
distribution.
8103-
81048103
"""
81058104

81068105
# Statistical quantities to be plotted on the violins
81078106
means = []
81088107
mins = []
81098108
maxes = []
81108109
medians = []
8111-
quantiles = np.asarray([])
8110+
quantiles = []
8111+
8112+
qlens = [] # Number of quantiles in each dataset.
81128113

8113-
# Collections to be returned
8114-
artists = {}
8114+
artists = {} # Collections to be returned
81158115

81168116
N = len(vpstats)
81178117
datashape_message = ("List of violinplot statistics and `{0}` "
@@ -8129,84 +8129,56 @@ def violin(self, vpstats, positions=None, vert=True, widths=0.5,
81298129
elif len(widths) != N:
81308130
raise ValueError(datashape_message.format("widths"))
81318131

8132-
# Calculate ranges for statistics lines
8133-
pmins = -0.25 * np.array(widths) + positions
8134-
pmaxes = 0.25 * np.array(widths) + positions
8132+
# Calculate ranges for statistics lines (shape (2, N)).
8133+
line_ends = [[-0.25], [0.25]] * np.array(widths) + positions
8134+
8135+
# Colors.
8136+
if rcParams['_internal.classic_mode']:
8137+
fillcolor = 'y'
8138+
linecolor = 'r'
8139+
else:
8140+
fillcolor = linecolor = self._get_lines.get_next_color()
81358141

81368142
# Check whether we are rendering vertically or horizontally
81378143
if vert:
81388144
fill = self.fill_betweenx
8139-
perp_lines = self.hlines
8140-
par_lines = self.vlines
8145+
perp_lines = functools.partial(self.hlines, colors=linecolor)
8146+
par_lines = functools.partial(self.vlines, colors=linecolor)
81418147
else:
81428148
fill = self.fill_between
8143-
perp_lines = self.vlines
8144-
par_lines = self.hlines
8145-
8146-
if rcParams['_internal.classic_mode']:
8147-
fillcolor = 'y'
8148-
edgecolor = 'r'
8149-
else:
8150-
fillcolor = edgecolor = self._get_lines.get_next_color()
8149+
perp_lines = functools.partial(self.vlines, colors=linecolor)
8150+
par_lines = functools.partial(self.hlines, colors=linecolor)
81518151

81528152
# Render violins
81538153
bodies = []
81548154
for stats, pos, width in zip(vpstats, positions, widths):
8155-
# The 0.5 factor reflects the fact that we plot from v-p to
8156-
# v+p
8155+
# The 0.5 factor reflects the fact that we plot from v-p to v+p.
81578156
vals = np.array(stats['vals'])
81588157
vals = 0.5 * width * vals / vals.max()
8159-
bodies += [fill(stats['coords'],
8160-
-vals + pos,
8161-
vals + pos,
8162-
facecolor=fillcolor,
8163-
alpha=0.3)]
8158+
bodies += [fill(stats['coords'], -vals + pos, vals + pos,
8159+
facecolor=fillcolor, alpha=0.3)]
81648160
means.append(stats['mean'])
81658161
mins.append(stats['min'])
81668162
maxes.append(stats['max'])
81678163
medians.append(stats['median'])
8168-
q = stats.get('quantiles')
8169-
if q is not None:
8170-
# If exist key quantiles, assume it's a list of floats
8171-
quantiles = np.concatenate((quantiles, q))
8164+
q = stats.get('quantiles') # a list of floats, or None
8165+
if q is None:
8166+
q = []
8167+
quantiles.extend(q)
8168+
qlens.append(len(q))
81728169
artists['bodies'] = bodies
81738170

8174-
# Render means
8175-
if showmeans:
8176-
artists['cmeans'] = perp_lines(means, pmins, pmaxes,
8177-
colors=edgecolor)
8178-
8179-
# Render extrema
8180-
if showextrema:
8181-
artists['cmaxes'] = perp_lines(maxes, pmins, pmaxes,
8182-
colors=edgecolor)
8183-
artists['cmins'] = perp_lines(mins, pmins, pmaxes,
8184-
colors=edgecolor)
8185-
artists['cbars'] = par_lines(positions, mins, maxes,
8186-
colors=edgecolor)
8187-
8188-
# Render medians
8189-
if showmedians:
8190-
artists['cmedians'] = perp_lines(medians,
8191-
pmins,
8192-
pmaxes,
8193-
colors=edgecolor)
8194-
8195-
# Render quantile values
8196-
if quantiles.size > 0:
8197-
# Recalculate ranges for statistics lines for quantiles.
8198-
# ppmins are the left end of quantiles lines
8199-
ppmins = np.asarray([])
8200-
# pmaxes are the right end of quantiles lines
8201-
ppmaxs = np.asarray([])
8202-
for stats, cmin, cmax in zip(vpstats, pmins, pmaxes):
8203-
q = stats.get('quantiles')
8204-
if q is not None:
8205-
ppmins = np.concatenate((ppmins, [cmin] * np.size(q)))
8206-
ppmaxs = np.concatenate((ppmaxs, [cmax] * np.size(q)))
8207-
# Start rendering
8208-
artists['cquantiles'] = perp_lines(quantiles, ppmins, ppmaxs,
8209-
colors=edgecolor)
8171+
if showmeans: # Render means
8172+
artists['cmeans'] = perp_lines(means, *line_ends)
8173+
if showextrema: # Render extrema
8174+
artists['cmaxes'] = perp_lines(maxes, *line_ends)
8175+
artists['cmins'] = perp_lines(mins, *line_ends)
8176+
artists['cbars'] = par_lines(positions, mins, maxes)
8177+
if showmedians: # Render medians
8178+
artists['cmedians'] = perp_lines(medians, *line_ends)
8179+
if quantiles: # Render quantiles: each width is repeated qlen times.
8180+
idxs = [idx for idx, qlen in enumerate(qlens) for _ in range(qlen)]
8181+
artists['cquantiles'] = perp_lines(quantiles, *line_ends[:, idxs])
82108182

82118183
return artists
82128184

0 commit comments

Comments
 (0)