diff --git a/examples/subplots_axes_and_figures/align_titles_demo.py b/examples/subplots_axes_and_figures/align_titles_demo.py new file mode 100644 index 000000000000..42c8fe819fbe --- /dev/null +++ b/examples/subplots_axes_and_figures/align_titles_demo.py @@ -0,0 +1,26 @@ +""" +=============== +Aligning Labels +=============== + +Aligning titles using `.Figure.align_titles`. + +Note that the title "Title 1" would normally be much closer to the +figure's axis. +""" +import matplotlib.pyplot as plt +fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", "title": "t"}) +print(axs.shape) +axs[0][0].imshow(plt.np.zeros((3, 5))) +axs[0][1].imshow(plt.np.zeros((5, 3))) +axs[1][0].imshow(plt.np.zeros((1, 2))) +axs[1][1].imshow(plt.np.zeros((2, 1))) +axs[0][0].set_title('t2') +rowspan1 = axs[0][0].get_subplotspec().rowspan +print(rowspan1, rowspan1.start, rowspan1.stop) +rowspan2 = axs[1][1].get_subplotspec().rowspan +print(rowspan2, rowspan2.start, rowspan2.stop) + +fig.align_titles() +plt.show() diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 3be469189a3b..dddd0cf1930d 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -2950,7 +2950,7 @@ def _update_title_position(self, renderer): _log.debug('title position was updated manually, not adjusting') return - titles = (self.title, self._left_title, self._right_title) + titles = [self.title, self._left_title, self._right_title] for title in titles: x, _ = title.get_position() @@ -3008,6 +3008,17 @@ def _update_title_position(self, renderer): x, _ = title.get_position() title.set_position((x, ymax)) + # Align bboxes of grouped axes to highest in group + grouped_axs = self.figure._align_label_groups['title'] \ + .get_siblings(self) + bb_ymax = None + ax_max = None + for ax in grouped_axs: + if bb_ymax is None or ax.bbox.ymax > bb_ymax: + bb_ymax = ax.bbox.ymax + ax_max = ax + self.bbox = ax_max.bbox + # Drawing @martist.allow_rasterization def draw(self, renderer): diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 974e46d6e4b3..025c651edd22 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -173,7 +173,8 @@ def __init__(self, **kwargs): # groupers to keep track of x and y labels we want to align. # see self.align_xlabels and self.align_ylabels and # axis._get_tick_boxes_siblings - self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper()} + self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper(), + "title": cbook.Grouper()} self.figure = self # list of child gridspecs for this figure @@ -1241,6 +1242,60 @@ def align_xlabels(self, axs=None): # grouper for groups of xlabels to align self._align_label_groups['x'].join(ax, axc) + def align_titles(self, axs=None): + """ + Align the titles of subplots in the same subplot column if title + alignment is being done automatically (i.e. the title position is + not manually set). + + Alignment persists for draw events after this is called. + + Parameters + ---------- + axs : list of `~matplotlib.axes.Axes` + Optional list of (or ndarray) `~matplotlib.axes.Axes` + to align the titles. + Default is to align all Axes on the figure. + + See Also + -------- + matplotlib.figure.Figure.align_xlabels + matplotlib.figure.Figure.align_ylabels + matplotlib.figure.Figure.align_labels + + Notes + ----- + This assumes that ``axs`` are from the same `.GridSpec`, so that + their `.SubplotSpec` positions correspond to figure positions. + + Examples + -------- + Example with titles:: + + fig, axs = subplots(1, 2, + subplot_kw={"xlabel": "x", "ylabel": "y", "title": "t"}) + axs[0].imshow(zeros((3, 5))) + axs[1].imshow(zeros((5, 3))) + fig.align_labels() + fig.align_titles() + """ + if axs is None: + axs = self.axes + axs = np.ravel(axs) + axs = [ax for ax in axs if hasattr(ax, 'get_subplotspec')] + locs = ['left', 'center', 'right'] + for ax in axs: + for loc in locs: + if ax.get_title(loc=loc): + _log.debug(' Working on: %s', ax.get_title(loc=loc)) + rowspan = ax.get_subplotspec().rowspan + for axc in axs: + for loc in locs: + rowspanc = axc.get_subplotspec().rowspan + if rowspan.start == rowspanc.start or \ + rowspan.stop == rowspanc.stop: + self._align_label_groups['title'].join(ax, axc) + def align_ylabels(self, axs=None): """ Align the ylabels of subplots in the same subplot column if label @@ -2848,6 +2903,7 @@ def draw(self, renderer): artists = self._get_draw_artists(renderer) try: + renderer.open_group('figure', gid=self.get_gid()) if self.axes and self.get_layout_engine() is not None: try: diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 9f68f63e5236..fb0fed6b311f 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -100,6 +100,38 @@ def test_align_labels_stray_axes(): np.testing.assert_allclose(yn[::2], yn[1::2]) +# TODO add image comparison +@image_comparison(['figure_align_titles'], extensions=['png', 'svg'], + tol=0 if platform.machine() == 'x86_64' else 0.01) +def test_align_titles(): + fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", + "title": "Title"}) + axs[0][0].imshow(plt.np.zeros((5, 3))) + axs[0][1].imshow(plt.np.zeros((3, 5))) + axs[1][0].imshow(plt.np.zeros((2, 1))) + axs[1][1].imshow(plt.np.zeros((1, 2))) + + axs[0][1].set_title('Title2', loc="left") + + fig.align_titles() + + +## TODO add image comparison +@image_comparison(['figure_align_titles_param'], extensions=['png', 'svg'], + tol=0 if platform.machine() == 'x86_64' else 0.01) +def test_align_titles_param(): + fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", + "title": "t"}) + axs[0][0].imshow(plt.np.zeros((3, 5))) + axs[0][1].imshow(plt.np.zeros((5, 3))) + axs[1][0].imshow(plt.np.zeros((2, 1))) + axs[1][1].imshow(plt.np.zeros((1, 2))) + + fig.align_titles([axs[0][0], axs[0][1]]) + + def test_figure_label(): # pyplot figure creation, selection, and closing with label/number/instance plt.close('all') diff --git a/test_align_titles.py b/test_align_titles.py new file mode 100644 index 000000000000..ec2957af27ea --- /dev/null +++ b/test_align_titles.py @@ -0,0 +1,18 @@ +import matplotlib as plt +fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", "title": "t"}) +print(axs.shape) +axs[0][0].imshow(plt.zeros((3, 5))) +axs[0][1].imshow(plt.zeros((5, 3))) +axs[1][0].imshow(plt.zeros((1, 2))) +axs[1][1].imshow(plt.zeros((2, 1))) +axs[0][0].set_title('t2') +rowspan1 = axs[0][0].get_subplotspec().rowspan +print(rowspan1, rowspan1.start, rowspan1.stop) +rowspan2 = axs[1][1].get_subplotspec().rowspan +print(rowspan2, rowspan2.start, rowspan2.stop) + +fig.align_labels() +fig.align_titles() +plt.show() +print("DONE")