@@ -793,15 +793,27 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
793
793
*ncols* : int
794
794
Number of columns of the subplot grid. Defaults to 1.
795
795
796
- *sharex* : bool
796
+ *sharex* : string or bool
797
797
If *True*, the X axis will be shared amongst all subplots. If
798
798
*True* and you have multiple rows, the x tick labels on all but
799
799
the last row of plots will have visible set to *False*
800
-
801
- *sharey* : bool
800
+ If a string must be one of "row", "col", "all", or "none".
801
+ "all" has the same effect as *True*, "none" has the same effect
802
+ as *False*.
803
+ If "row", each subplot row will share a X axis.
804
+ If "col", each subplot column will share a X axis and the x tick
805
+ labels on all but the last row will have visible set to *False*.
806
+
807
+ *sharey* : string or bool
802
808
If *True*, the Y axis will be shared amongst all subplots. If
803
809
*True* and you have multiple columns, the y tick labels on all but
804
810
the first column of plots will have visible set to *False*
811
+ If a string must be one of "row", "col", "all", or "none".
812
+ "all" has the same effect as *True*, "none" has the same effect
813
+ as *False*.
814
+ If "row", each subplot row will share a Y axis.
815
+ If "col", each subplot column will share a Y axis and the y tick
816
+ labels on all but the last row will have visible set to *False*.
805
817
806
818
*squeeze* : bool
807
819
If *True*, extra dimensions are squeezed out from the
@@ -859,7 +871,36 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
859
871
860
872
# Four polar axes
861
873
plt.subplots(2, 2, subplot_kw=dict(polar=True))
874
+
875
+ # Share a X axis with each column of subplots
876
+ plt.subplots(2, 2, sharex='col')
877
+
878
+ # Share a Y axis with each row of subplots
879
+ plt.subplots(2, 2, sharey='row')
880
+
881
+ # Share a X and Y axis with all subplots
882
+ plt.subplots(2, 2, sharex='all', sharey='all')
883
+ # same as
884
+ plt.subplots(2, 2, sharex=True, sharey=True)
862
885
"""
886
+ # for backwards compatability
887
+ if isinstance (sharex , bool ):
888
+ if sharex :
889
+ sharex = "all"
890
+ else :
891
+ sharex = "none"
892
+ if isinstance (sharey , bool ):
893
+ if sharey :
894
+ sharey = "all"
895
+ else :
896
+ sharey = "none"
897
+ share_values = ["all" , "row" , "col" , "none" ]
898
+ if sharex not in share_values :
899
+ raise ValueError ("sharex [%s] must be one of %s" % \
900
+ (sharex , share_values ))
901
+ if sharey not in share_values :
902
+ raise ValueError ("sharey [%s] must be one of %s" % \
903
+ (sharey , share_values ))
863
904
864
905
if subplot_kw is None :
865
906
subplot_kw = {}
@@ -873,34 +914,52 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
873
914
874
915
# Create first subplot separately, so we can share it if requested
875
916
ax0 = fig .add_subplot (nrows , ncols , 1 , ** subplot_kw )
876
- if sharex :
877
- subplot_kw ['sharex' ] = ax0
878
- if sharey :
879
- subplot_kw ['sharey' ] = ax0
917
+ # if sharex:
918
+ # subplot_kw['sharex'] = ax0
919
+ # if sharey:
920
+ # subplot_kw['sharey'] = ax0
880
921
axarr [0 ] = ax0
881
922
923
+ r , c = np .mgrid [:nrows , :ncols ]
924
+ r = r .flatten () * ncols
925
+ c = c .flatten ()
926
+ lookup = {
927
+ "none" : np .arange (nplots ),
928
+ "all" : np .zeros (nplots , dtype = int ),
929
+ "row" : r ,
930
+ "col" : c ,
931
+ }
932
+ sxs = lookup [sharex ]
933
+ sys = lookup [sharey ]
934
+
882
935
# Note off-by-one counting because add_subplot uses the MATLAB 1-based
883
936
# convention.
884
937
for i in range (1 , nplots ):
885
- axarr [i ] = fig .add_subplot (nrows , ncols , i + 1 , ** subplot_kw )
886
-
887
-
938
+ if sxs [i ] == i :
939
+ subplot_kw ['sharex' ] = None
940
+ else :
941
+ subplot_kw ['sharex' ] = axarr [sxs [i ]]
942
+ if sys [i ] == i :
943
+ subplot_kw ['sharey' ] = None
944
+ else :
945
+ subplot_kw ['sharey' ] = axarr [sys [i ]]
946
+ axarr [i ] = fig .add_subplot (nrows , ncols , i + 1 , ** subplot_kw )
888
947
889
948
# returned axis array will be always 2-d, even if nrows=ncols=1
890
949
axarr = axarr .reshape (nrows , ncols )
891
950
892
-
893
951
# turn off redundant tick labeling
894
- if sharex and nrows > 1 :
952
+ if sharex in ["col" , "all" ] and nrows > 1 :
953
+ #if sharex and nrows>1:
895
954
# turn off all but the bottom row
896
- for ax in axarr [:- 1 ,:].flat :
955
+ for ax in axarr [:- 1 , :].flat :
897
956
for label in ax .get_xticklabels ():
898
957
label .set_visible (False )
899
958
900
-
901
- if sharey and ncols > 1 :
959
+ if sharey in [ "row" , "all" ] and ncols > 1 :
960
+ # if sharey and ncols>1:
902
961
# turn off all but the first column
903
- for ax in axarr [:,1 :].flat :
962
+ for ax in axarr [:, 1 :].flat :
904
963
for label in ax .get_yticklabels ():
905
964
label .set_visible (False )
906
965
0 commit comments