Skip to content

Commit c3c41a3

Browse files
committed
Merge pull request #914 from braingram/master
more flexible axis sharing for pyplot.subplots
2 parents d17dbff + f1db8e8 commit c3c41a3

File tree

3 files changed

+206
-29
lines changed

3 files changed

+206
-29
lines changed

examples/pylab_examples/subplots_demo.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010
import numpy as np
1111

1212
# Simple data to display in various forms
13-
x = np.linspace(0, 2*np.pi, 400)
14-
y = np.sin(x**2)
13+
x = np.linspace(0, 2 * np.pi, 400)
14+
y = np.sin(x ** 2)
1515

1616
plt.close('all')
1717

@@ -37,25 +37,33 @@
3737
ax1.plot(x, y)
3838
ax1.set_title('Sharing both axes')
3939
ax2.scatter(x, y)
40-
ax3.scatter(x, 2*y**2-1,color='r')
40+
ax3.scatter(x, 2 * y ** 2 - 1, color='r')
4141
# Fine-tune figure; make subplots close to each other and hide x ticks for
4242
# all but bottom plot.
4343
f.subplots_adjust(hspace=0)
4444
plt.setp([a.get_xticklabels() for a in f.axes[:-1]], visible=False)
4545

46+
# row and column sharing
47+
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex='col', sharey='row')
48+
ax1.plot(x, y)
49+
ax1.set_title('Sharing x per column, y per row')
50+
ax2.scatter(x, y)
51+
ax3.scatter(x, 2 * y ** 2 - 1, color='r')
52+
ax4.plot(x, 2 * y ** 2 - 1, color='r')
53+
4654
# Four axes, returned as a 2-d array
4755
f, axarr = plt.subplots(2, 2)
48-
axarr[0,0].plot(x, y)
49-
axarr[0,0].set_title('Axis [0,0]')
50-
axarr[0,1].scatter(x, y)
51-
axarr[0,1].set_title('Axis [0,1]')
52-
axarr[1,0].plot(x, y**2)
53-
axarr[1,0].set_title('Axis [1,0]')
54-
axarr[1,1].scatter(x, y**2)
55-
axarr[1,1].set_title('Axis [1,1]')
56+
axarr[0, 0].plot(x, y)
57+
axarr[0, 0].set_title('Axis [0,0]')
58+
axarr[0, 1].scatter(x, y)
59+
axarr[0, 1].set_title('Axis [0,1]')
60+
axarr[1, 0].plot(x, y ** 2)
61+
axarr[1, 0].set_title('Axis [1,0]')
62+
axarr[1, 1].scatter(x, y ** 2)
63+
axarr[1, 1].set_title('Axis [1,1]')
5664
# Fine-tune figure; hide x ticks for top plots and y ticks for right plots
57-
plt.setp([a.get_xticklabels() for a in axarr[0,:]], visible=False)
58-
plt.setp([a.get_yticklabels() for a in axarr[:,1]], visible=False)
65+
plt.setp([a.get_xticklabels() for a in axarr[0, :]], visible=False)
66+
plt.setp([a.get_yticklabels() for a in axarr[:, 1]], visible=False)
5967

6068
# Four polar axes
6169
plt.subplots(2, 2, subplot_kw=dict(polar=True))

lib/matplotlib/pyplot.py

+75-16
Original file line numberDiff line numberDiff line change
@@ -793,15 +793,27 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
793793
*ncols* : int
794794
Number of columns of the subplot grid. Defaults to 1.
795795
796-
*sharex* : bool
796+
*sharex* : string or bool
797797
If *True*, the X axis will be shared amongst all subplots. If
798798
*True* and you have multiple rows, the x tick labels on all but
799799
the last row of plots will have visible set to *False*
800-
801-
*sharey* : bool
800+
If a string must be one of "row", "col", "all", or "none".
801+
"all" has the same effect as *True*, "none" has the same effect
802+
as *False*.
803+
If "row", each subplot row will share a X axis.
804+
If "col", each subplot column will share a X axis and the x tick
805+
labels on all but the last row will have visible set to *False*.
806+
807+
*sharey* : string or bool
802808
If *True*, the Y axis will be shared amongst all subplots. If
803809
*True* and you have multiple columns, the y tick labels on all but
804810
the first column of plots will have visible set to *False*
811+
If a string must be one of "row", "col", "all", or "none".
812+
"all" has the same effect as *True*, "none" has the same effect
813+
as *False*.
814+
If "row", each subplot row will share a Y axis.
815+
If "col", each subplot column will share a Y axis and the y tick
816+
labels on all but the last row will have visible set to *False*.
805817
806818
*squeeze* : bool
807819
If *True*, extra dimensions are squeezed out from the
@@ -859,7 +871,36 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
859871
860872
# Four polar axes
861873
plt.subplots(2, 2, subplot_kw=dict(polar=True))
874+
875+
# Share a X axis with each column of subplots
876+
plt.subplots(2, 2, sharex='col')
877+
878+
# Share a Y axis with each row of subplots
879+
plt.subplots(2, 2, sharey='row')
880+
881+
# Share a X and Y axis with all subplots
882+
plt.subplots(2, 2, sharex='all', sharey='all')
883+
# same as
884+
plt.subplots(2, 2, sharex=True, sharey=True)
862885
"""
886+
# for backwards compatability
887+
if isinstance(sharex, bool):
888+
if sharex:
889+
sharex = "all"
890+
else:
891+
sharex = "none"
892+
if isinstance(sharey, bool):
893+
if sharey:
894+
sharey = "all"
895+
else:
896+
sharey = "none"
897+
share_values = ["all", "row", "col", "none"]
898+
if sharex not in share_values:
899+
raise ValueError("sharex [%s] must be one of %s" % \
900+
(sharex, share_values))
901+
if sharey not in share_values:
902+
raise ValueError("sharey [%s] must be one of %s" % \
903+
(sharey, share_values))
863904

864905
if subplot_kw is None:
865906
subplot_kw = {}
@@ -873,34 +914,52 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
873914

874915
# Create first subplot separately, so we can share it if requested
875916
ax0 = fig.add_subplot(nrows, ncols, 1, **subplot_kw)
876-
if sharex:
877-
subplot_kw['sharex'] = ax0
878-
if sharey:
879-
subplot_kw['sharey'] = ax0
917+
#if sharex:
918+
# subplot_kw['sharex'] = ax0
919+
#if sharey:
920+
# subplot_kw['sharey'] = ax0
880921
axarr[0] = ax0
881922

923+
r, c = np.mgrid[:nrows, :ncols]
924+
r = r.flatten() * ncols
925+
c = c.flatten()
926+
lookup = {
927+
"none": np.arange(nplots),
928+
"all": np.zeros(nplots, dtype=int),
929+
"row": r,
930+
"col": c,
931+
}
932+
sxs = lookup[sharex]
933+
sys = lookup[sharey]
934+
882935
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
883936
# convention.
884937
for i in range(1, nplots):
885-
axarr[i] = fig.add_subplot(nrows, ncols, i+1, **subplot_kw)
886-
887-
938+
if sxs[i] == i:
939+
subplot_kw['sharex'] = None
940+
else:
941+
subplot_kw['sharex'] = axarr[sxs[i]]
942+
if sys[i] == i:
943+
subplot_kw['sharey'] = None
944+
else:
945+
subplot_kw['sharey'] = axarr[sys[i]]
946+
axarr[i] = fig.add_subplot(nrows, ncols, i + 1, **subplot_kw)
888947

889948
# returned axis array will be always 2-d, even if nrows=ncols=1
890949
axarr = axarr.reshape(nrows, ncols)
891950

892-
893951
# turn off redundant tick labeling
894-
if sharex and nrows>1:
952+
if sharex in ["col", "all"] and nrows > 1:
953+
#if sharex and nrows>1:
895954
# turn off all but the bottom row
896-
for ax in axarr[:-1,:].flat:
955+
for ax in axarr[:-1, :].flat:
897956
for label in ax.get_xticklabels():
898957
label.set_visible(False)
899958

900-
901-
if sharey and ncols>1:
959+
if sharey in ["row", "all"] and ncols > 1:
960+
#if sharey and ncols>1:
902961
# turn off all but the first column
903-
for ax in axarr[:,1:].flat:
962+
for ax in axarr[:, 1:].flat:
904963
for label in ax.get_yticklabels():
905964
label.set_visible(False)
906965

lib/matplotlib/tests/test_subplots.py

+110
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import numpy
2+
import matplotlib.pyplot as plt
3+
4+
from nose.tools import assert_raises
5+
6+
7+
def check_shared(results, f, axs):
8+
"""
9+
results is a 4 x 4 x 2 matrix of boolean values where
10+
if [i, j, 0] == True, X axis for subplots i and j should be shared
11+
if [i, j, 1] == False, Y axis for subplots i and j should not be shared
12+
"""
13+
shared_str = ['x', 'y']
14+
shared = [axs[0]._shared_x_axes, axs[0]._shared_y_axes]
15+
#shared = {
16+
# 'x': a1._shared_x_axes,
17+
# 'y': a1._shared_y_axes,
18+
# }
19+
tostr = lambda r: "not " if r else ""
20+
for i1 in xrange(len(axs)):
21+
for i2 in xrange(i1 + 1, len(axs)):
22+
for i3 in xrange(len(shared)):
23+
assert shared[i3].joined(axs[i1], axs[i2]) == \
24+
results[i1, i2, i3], \
25+
"axes %i and %i incorrectly %ssharing %s axis" % \
26+
(i1, i2, tostr(results[i1, i2, i3]), shared_str[i3])
27+
28+
29+
def check_visible(result, f, axs):
30+
tostr = lambda v: "invisible" if v else "visible"
31+
for (ax, vx, vy) in zip(axs, result['x'], result['y']):
32+
for l in ax.get_xticklabels():
33+
assert l.get_visible() == vx, \
34+
"X axis was incorrectly %s" % (tostr(vx))
35+
for l in ax.get_yticklabels():
36+
assert l.get_visible() == vy, \
37+
"Y axis was incorrectly %s" % (tostr(vy))
38+
39+
40+
def test_shared():
41+
rdim = (4, 4, 2)
42+
share = {
43+
'all': numpy.ones(rdim[:2], dtype=bool),
44+
'none': numpy.zeros(rdim[:2], dtype=bool),
45+
'row': numpy.array([
46+
[False, True, False, False],
47+
[True, False, False, False],
48+
[False, False, False, True],
49+
[False, False, True, False]]),
50+
'col': numpy.array([
51+
[False, False, True, False],
52+
[False, False, False, True],
53+
[True, False, False, False],
54+
[False, True, False, False]]),
55+
}
56+
visible = {
57+
'x': {
58+
'all': [False, False, True, True],
59+
'col': [False, False, True, True],
60+
'row': [True] * 4,
61+
'none': [True] * 4,
62+
False: [True] * 4,
63+
True: [False, False, True, True],
64+
},
65+
'y': {
66+
'all': [True, False, True, False],
67+
'col': [True] * 4,
68+
'row': [True, False, True, False],
69+
'none': [True] * 4,
70+
False: [True] * 4,
71+
True: [True, False, True, False],
72+
},
73+
}
74+
share[False] = share['none']
75+
share[True] = share['all']
76+
77+
# test default
78+
f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2)
79+
axs = [a1, a2, a3, a4]
80+
check_shared(numpy.dstack((share['none'], share['none'])), \
81+
f, axs)
82+
plt.close(f)
83+
84+
# test all option combinations
85+
ops = [False, True, 'all', 'none', 'row', 'col']
86+
for xo in ops:
87+
for yo in ops:
88+
f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex=xo, sharey=yo)
89+
axs = [a1, a2, a3, a4]
90+
check_shared(numpy.dstack((share[xo], share[yo])), \
91+
f, axs)
92+
check_visible(dict(x=visible['x'][xo], y=visible['y'][yo]), \
93+
f, axs)
94+
plt.close(f)
95+
96+
97+
def test_exceptions():
98+
# TODO should this test more options?
99+
with assert_raises(ValueError):
100+
plt.subplots(2, 2, sharex='blah')
101+
plt.subplots(2, 2, sharey='blah')
102+
103+
104+
def test_subplots():
105+
# things to test
106+
# - are axes actually shared?
107+
# - are tickmarks correctly hidden?
108+
test_shared()
109+
# - are exceptions thrown correctly
110+
test_exceptions()

0 commit comments

Comments
 (0)