diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index df28052e9a57..83d0e4d990c1 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -2990,6 +2990,16 @@ def _update_title_position(self, renderer): x, _ = title.get_position() title.set_position((x, ymax)) + # Align bboxes of grouped axes to highest in group + grouped_axs = self.figure._align_label_groups['title'].get_siblings(self) + bb_ymax = None + ax_max = None + for ax in grouped_axs: + if bb_ymax is None or ax.bbox.ymax > bb_ymax: + bb_ymax = ax.bbox.ymax + ax_max = ax + self.bbox = ax_max.bbox + # Drawing @martist.allow_rasterization def draw(self, renderer): diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index dedd3ecc652f..6c9d5cb89a67 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -183,10 +183,11 @@ def __init__(self, **kwargs): self._supxlabel = None self._supylabel = None - # groupers to keep track of x and y labels we want to align. - # see self.align_xlabels and self.align_ylabels and + # groupers to keep track of x, y and title labels we want to align. + # see self.align_xlabels, self.align_ylabels, self.align_titles and # axis._get_tick_boxes_siblings - self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper()} + self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper(), + "title": cbook.Grouper()} self.figure = self self._localaxes = [] # track all axes @@ -1327,6 +1328,68 @@ def subplots_adjust(self, left=None, bottom=None, right=None, top=None, ax._set_position(ax.get_subplotspec().get_position(self)) self.stale = True + def align_titles(self, axs=None): + """ + Align the titles of subplots in the same subplot column if title + alignment is being done automatically (i.e. the title position is + not manually set). + + Alignment persists for draw events after this is called. + + If the title is on the top, + it is aligned with titles on Axes with the same top-most row. + + Parameters + ---------- + axs : list of `~matplotlib.axes.Axes` + Optional list of (or `~numpy.ndarray`) `~matplotlib.axes.Axes` + to align the titles. + Default is to align all titles on the figure. + + See Also + -------- + matplotlib.figure.Figure.align_ylabels + matplotlib.figure.Figure.align_xlabels + matplotlib.figure.Figure.align_labels + + Notes + ----- + This assumes that ``axs`` are from the same `.GridSpec`, so that + their `.SubplotSpec` positions correspond to figure positions. + + Examples + -------- + Example with aligned titles on multiple rows:: + + fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", + "title": "Title"}) + axs[0][0].imshow(plt.np.zeros((5, 3))) + axs[0][1].imshow(plt.np.zeros((3, 5))) + axs[1][0].imshow(plt.np.zeros((2, 1))) + axs[1][1].imshow(plt.np.zeros((1, 2))) + + axs[0][1].set_title('Title2', loc="left") + fig.align_titles() + """ + if axs is None: + axs = self.axes + axs = [ax for ax in np.ravel(axs) if ax.get_subplotspec() is not None] + locs = ['left', 'center', 'right'] + for ax in axs: + for loc in locs: + if ax.get_title(loc=loc): + rowspan = ax.get_subplotspec().rowspan + for axc in axs: + rowspanc = axc.get_subplotspec().rowspan + if rowspan.start == rowspanc.start or \ + rowspan.stop == rowspanc.stop: + self._align_label_groups['title'].join(ax, axc) + + # Fixes the issue that the bbox is too small to fit the aligned + # title when saving the figure + self.canvas.draw_idle() + def align_xlabels(self, axs=None): """ Align the xlabels of subplots in the same subplot column if label diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index d8b06003cd00..0781cc316500 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -11,6 +11,7 @@ import numpy as np import pytest from PIL import Image +from matplotlib.testing.compare import compare_images import matplotlib as mpl from matplotlib import gridspec @@ -102,6 +103,37 @@ def test_align_labels_stray_axes(): np.testing.assert_allclose(yn[::2], yn[1::2]) +## TODO add image comparison +@image_comparison(['figure_align_titles'], extensions=['png', 'svg'], + tol=0 if platform.machine() == 'x86_64' else 0.01, style='mpl20') +def test_align_titles(): + fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", + "title": "Title"}, layout="constrained") + axs[0][0].imshow(plt.np.zeros((5, 3))) + axs[0][1].imshow(plt.np.zeros((3, 5))) + axs[1][0].imshow(plt.np.zeros((2, 1))) + axs[1][1].imshow(plt.np.zeros((1, 2))) + + axs[0][1].set_title('Title2', loc="left") + fig.align_titles() + + +## TODO add image comparison +@image_comparison(['figure_align_titles_param'], extensions=['png', 'svg'], + tol=0 if platform.machine() == 'x86_64' else 0.01, style='mpl20') +def test_align_titles_param(): + fig, axs = plt.subplots(2, 2, + subplot_kw={"xlabel": "x", "ylabel": "", + "title": "t"}, layout="constrained") + axs[0][0].imshow(plt.np.zeros((3, 5))) + axs[0][1].imshow(plt.np.zeros((5, 3))) + axs[1][0].imshow(plt.np.zeros((2, 1))) + axs[1][1].imshow(plt.np.zeros((1, 2))) + + fig.align_titles([axs[0][0], axs[0][1]]) + + def test_figure_label(): # pyplot figure creation, selection, and closing with label/number/instance plt.close('all')