Skip to content

Commit 08db92d

Browse files
committed
Updated example as per suggestions in issue #7251
1 parent e794622 commit 08db92d

File tree

1 file changed

+47
-60
lines changed

1 file changed

+47
-60
lines changed

examples/statistics/customized_violin_demo.py

Lines changed: 47 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -18,87 +18,74 @@
1818
import numpy as np
1919

2020

21-
# functions to calculate percentiles and adjacent values
22-
def percentile(vals, p):
23-
N = len(vals)
24-
n = p*(N+1)
25-
k = int(n)
26-
d = n-k
27-
if k <= 0:
28-
return vals[0]
29-
if k >= N:
30-
return vals[N-1]
31-
return vals[k-1] + d*(vals[k] - vals[k-1])
32-
33-
3421
def adjacent_values(vals):
35-
q1 = percentile(vals, 0.25)
36-
q3 = percentile(vals, 0.75)
37-
iqr = q3 - q1 # inter-quartile range
38-
22+
q1, q3 = np.percentile(vals, [25, 75])
23+
# inter-quartile range iqr
24+
iqr = q3 - q1
3925
# upper adjacent values
4026
uav = q3 + iqr * 1.5
41-
if uav > vals[-1]:
42-
uav = vals[-1]
43-
if uav < q3:
44-
uav = q3
45-
27+
uav = np.clip(uav, q3, vals[-1])
4628
# lower adjacent values
4729
lav = q1 - iqr * 1.5
48-
if lav < vals[0]:
49-
lav = vals[0]
50-
if lav > q1:
51-
lav = q1
30+
lav = np.clip(lav, q1, vals[0])
5231
return [lav, uav]
5332

5433

34+
def set_axis_style(ax, labels):
35+
ax.get_xaxis().set_tick_params(direction='out')
36+
ax.xaxis.set_ticks_position('bottom')
37+
ax.set_xticks(np.arange(1, len(labels) + 1))
38+
ax.set_xticklabels(labels)
39+
ax.set_xlim(0.25, len(labels) + 0.75)
40+
ax.set_xlabel('Sample name')
41+
42+
5543
# create test data
5644
np.random.seed(123)
57-
dat = [np.random.normal(0, std, 100) for std in range(1, 5)]
58-
lab = ['A', 'B', 'C', 'D'] # labels
59-
med = [] # medians
60-
iqr = [] # inter-quantile ranges
61-
avs = [] # upper and lower adjacent values
62-
for arr in dat:
63-
sarr = sorted(arr)
64-
med.append(percentile(sarr, 0.5))
65-
iqr.append([percentile(sarr, 0.25), percentile(sarr, 0.75)])
66-
avs.append(adjacent_values(sarr))
67-
68-
# plot the violins
69-
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4),
70-
sharey=True)
71-
_ = ax1.violinplot(dat)
72-
parts = ax2.violinplot(dat, showmeans=False, showmedians=False,
73-
showextrema=False)
45+
dat = [sorted(np.random.normal(0, std, 100)) for std in range(1, 5)]
46+
47+
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4), sharey=True)
7448

49+
# plot the default violin
7550
ax1.set_title('Default violin plot')
76-
ax2.set_title('Customized violin plot')
51+
ax1.set_ylabel('Observed values')
52+
ax1.violinplot(dat)
7753

78-
# plot whiskers as thin lines, quartiles as fat lines,
79-
# and medians as points
80-
for i in range(len(med)):
81-
# whiskers
82-
ax2.plot([i + 1, i + 1], avs[i], '-', color='black', linewidth=1)
83-
ax2.plot([i + 1, i + 1], iqr[i], '-', color='black', linewidth=5)
84-
ax2.plot(i + 1, med[i], 'o', color='white',
85-
markersize=6, markeredgecolor='none')
54+
# customized violin
55+
ax2.set_title('Customized violin plot')
56+
parts = ax2.violinplot(
57+
dat, showmeans=False, showmedians=False,
58+
showextrema=False)
8659

8760
# customize colors
8861
for pc in parts['bodies']:
8962
pc.set_facecolor('#D43F3A')
9063
pc.set_edgecolor('black')
9164
pc.set_alpha(1)
9265

93-
ax1.set_ylabel('Observed values')
66+
# medians
67+
med = [np.percentile(sarr, 50) for sarr in dat]
68+
# inter-quartile ranges
69+
iqr = [[np.percentile(sarr, 25), np.percentile(sarr, 75)] for sarr in dat]
70+
# upper and lower adjacent values
71+
avs = [adjacent_values(sarr) for sarr in dat]
72+
73+
# plot whiskers as thin lines, quartiles as fat lines,
74+
# and medians as points
75+
for i, median in enumerate(med):
76+
# whiskers
77+
ax2.plot([i + 1, i + 1], avs[i], '-', color='black', linewidth=1)
78+
# quartiles
79+
ax2.plot([i + 1, i + 1], iqr[i], '-', color='black', linewidth=5)
80+
# medians
81+
ax2.plot(
82+
i + 1, median, 'o', color='white',
83+
markersize=6, markeredgecolor='none')
84+
85+
# set style for the axes
86+
labels = ['A', 'B', 'C', 'D'] # labels
9487
for ax in [ax1, ax2]:
95-
ax.get_xaxis().set_tick_params(direction='out')
96-
ax.xaxis.set_ticks_position('bottom')
97-
ax.set_xticks(np.arange(1, len(lab) + 1))
98-
ax.set_xticklabels(lab)
99-
ax.set_xlim(0.25, len(lab) + 0.75)
100-
ax.set_xlabel('Sample name')
88+
set_axis_style(ax, labels)
10189

10290
plt.subplots_adjust(bottom=0.15, wspace=0.05)
103-
10491
plt.show()

0 commit comments

Comments
 (0)