Skip to content

[ENH]: add Figure.align_titles() functionality #22793

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 14 commits into from
26 changes: 26 additions & 0 deletions examples/subplots_axes_and_figures/align_titles_demo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
===============
Aligning Labels
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Aligning Labels
Aligning Titles

?

===============

Aligning titles using `.Figure.align_titles`.

Note that the title "Title 1" would normally be much closer to the
figure's axis.
"""
import matplotlib.pyplot as plt
fig, axs = plt.subplots(2, 2,
subplot_kw={"xlabel": "x", "ylabel": "", "title": "t"})
print(axs.shape)
axs[0][0].imshow(plt.np.zeros((3, 5)))
axs[0][1].imshow(plt.np.zeros((5, 3)))
axs[1][0].imshow(plt.np.zeros((1, 2)))
axs[1][1].imshow(plt.np.zeros((2, 1)))
axs[0][0].set_title('t2')
rowspan1 = axs[0][0].get_subplotspec().rowspan
print(rowspan1, rowspan1.start, rowspan1.stop)
rowspan2 = axs[1][1].get_subplotspec().rowspan
print(rowspan2, rowspan2.start, rowspan2.stop)

fig.align_titles()
plt.show()
13 changes: 12 additions & 1 deletion lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2950,7 +2950,7 @@ def _update_title_position(self, renderer):
_log.debug('title position was updated manually, not adjusting')
return

titles = (self.title, self._left_title, self._right_title)
titles = [self.title, self._left_title, self._right_title]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this required? It doesn't look like you are actually modifying titles.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
titles = [self.title, self._left_title, self._right_title]
titles = (self.title, self._left_title, self._right_title)


for title in titles:
x, _ = title.get_position()
Expand Down Expand Up @@ -3008,6 +3008,17 @@ def _update_title_position(self, renderer):
x, _ = title.get_position()
title.set_position((x, ymax))

# Align bboxes of grouped axes to highest in group
grouped_axs = self.figure._align_label_groups['title'] \
.get_siblings(self)
bb_ymax = None
ax_max = None
for ax in grouped_axs:
if bb_ymax is None or ax.bbox.ymax > bb_ymax:
bb_ymax = ax.bbox.ymax
ax_max = ax
self.bbox = ax_max.bbox

# Drawing
@martist.allow_rasterization
def draw(self, renderer):
Expand Down
58 changes: 57 additions & 1 deletion lib/matplotlib/figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,8 @@ def __init__(self, **kwargs):
# groupers to keep track of x and y labels we want to align.
# see self.align_xlabels and self.align_ylabels and
# axis._get_tick_boxes_siblings
self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper()}
self._align_label_groups = {"x": cbook.Grouper(), "y": cbook.Grouper(),
"title": cbook.Grouper()}

self.figure = self
# list of child gridspecs for this figure
Expand Down Expand Up @@ -1241,6 +1242,60 @@ def align_xlabels(self, axs=None):
# grouper for groups of xlabels to align
self._align_label_groups['x'].join(ax, axc)

def align_titles(self, axs=None):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if we want to add align_titles to align_labels? Though this would be a breaking change of align_labels.

"""
Align the titles of subplots in the same subplot column if title
alignment is being done automatically (i.e. the title position is
not manually set).

Alignment persists for draw events after this is called.

Parameters
----------
axs : list of `~matplotlib.axes.Axes`
Optional list of (or ndarray) `~matplotlib.axes.Axes`
to align the titles.
Default is to align all Axes on the figure.

See Also
--------
matplotlib.figure.Figure.align_xlabels
matplotlib.figure.Figure.align_ylabels
matplotlib.figure.Figure.align_labels

Notes
-----
This assumes that ``axs`` are from the same `.GridSpec`, so that
their `.SubplotSpec` positions correspond to figure positions.

Examples
--------
Example with titles::

fig, axs = subplots(1, 2,
subplot_kw={"xlabel": "x", "ylabel": "y", "title": "t"})
axs[0].imshow(zeros((3, 5)))
axs[1].imshow(zeros((5, 3)))
fig.align_labels()
fig.align_titles()
"""
if axs is None:
axs = self.axes
axs = np.ravel(axs)
axs = [ax for ax in axs if hasattr(ax, 'get_subplotspec')]
locs = ['left', 'center', 'right']
for ax in axs:
for loc in locs:
if ax.get_title(loc=loc):
_log.debug(' Working on: %s', ax.get_title(loc=loc))
rowspan = ax.get_subplotspec().rowspan
for axc in axs:
for loc in locs:
rowspanc = axc.get_subplotspec().rowspan
if rowspan.start == rowspanc.start or \
rowspan.stop == rowspanc.stop:
self._align_label_groups['title'].join(ax, axc)

def align_ylabels(self, axs=None):
"""
Align the ylabels of subplots in the same subplot column if label
Expand Down Expand Up @@ -2848,6 +2903,7 @@ def draw(self, renderer):

artists = self._get_draw_artists(renderer)
try:

renderer.open_group('figure', gid=self.get_gid())
if self.axes and self.get_layout_engine() is not None:
try:
Expand Down
32 changes: 32 additions & 0 deletions lib/matplotlib/tests/test_figure.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,38 @@ def test_align_labels_stray_axes():
np.testing.assert_allclose(yn[::2], yn[1::2])


# TODO add image comparison
@image_comparison(['figure_align_titles'], extensions=['png', 'svg'],
tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_align_titles():
fig, axs = plt.subplots(2, 2,
subplot_kw={"xlabel": "x", "ylabel": "",
"title": "Title"})
axs[0][0].imshow(plt.np.zeros((5, 3)))
axs[0][1].imshow(plt.np.zeros((3, 5)))
axs[1][0].imshow(plt.np.zeros((2, 1)))
axs[1][1].imshow(plt.np.zeros((1, 2)))

axs[0][1].set_title('Title2', loc="left")

fig.align_titles()


## TODO add image comparison
@image_comparison(['figure_align_titles_param'], extensions=['png', 'svg'],
tol=0 if platform.machine() == 'x86_64' else 0.01)
def test_align_titles_param():
fig, axs = plt.subplots(2, 2,
subplot_kw={"xlabel": "x", "ylabel": "",
"title": "t"})
axs[0][0].imshow(plt.np.zeros((3, 5)))
axs[0][1].imshow(plt.np.zeros((5, 3)))
axs[1][0].imshow(plt.np.zeros((2, 1)))
axs[1][1].imshow(plt.np.zeros((1, 2)))

fig.align_titles([axs[0][0], axs[0][1]])


def test_figure_label():
# pyplot figure creation, selection, and closing with label/number/instance
plt.close('all')
Expand Down
18 changes: 18 additions & 0 deletions test_align_titles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import matplotlib as plt
fig, axs = plt.subplots(2, 2,
subplot_kw={"xlabel": "x", "ylabel": "", "title": "t"})
print(axs.shape)
axs[0][0].imshow(plt.zeros((3, 5)))
axs[0][1].imshow(plt.zeros((5, 3)))
axs[1][0].imshow(plt.zeros((1, 2)))
axs[1][1].imshow(plt.zeros((2, 1)))
axs[0][0].set_title('t2')
rowspan1 = axs[0][0].get_subplotspec().rowspan
print(rowspan1, rowspan1.start, rowspan1.stop)
rowspan2 = axs[1][1].get_subplotspec().rowspan
print(rowspan2, rowspan2.start, rowspan2.stop)

fig.align_labels()
fig.align_titles()
plt.show()
print("DONE")