From b1651f450c7deebc8031d4ee55b9b1aaf46aeadc Mon Sep 17 00:00:00 2001 From: Jody Klymak Date: Mon, 12 Apr 2021 09:20:39 -0700 Subject: [PATCH] FIX: subfigure indexing error --- lib/matplotlib/figure.py | 4 ++-- lib/matplotlib/tests/test_figure.py | 22 ++++++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index a9d435eb3919..c4800f1a0b49 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -1992,11 +1992,11 @@ def _redo_transform_rel_fig(self, bbox=None): x0 = 0 if not self._subplotspec.is_first_col(): - x0 += np.sum(wr[self._subplotspec.colspan.start - 1]) / np.sum(wr) + x0 += np.sum(wr[:self._subplotspec.colspan.start]) / np.sum(wr) y0 = 0 if not self._subplotspec.is_last_row(): - y0 += 1 - (np.sum(hr[self._subplotspec.rowspan.stop - 1]) / + y0 += 1 - (np.sum(hr[:self._subplotspec.rowspan.stop]) / np.sum(hr)) if self.bbox_relative is None: diff --git a/lib/matplotlib/tests/test_figure.py b/lib/matplotlib/tests/test_figure.py index 15f084fc1516..d139d3b39f60 100644 --- a/lib/matplotlib/tests/test_figure.py +++ b/lib/matplotlib/tests/test_figure.py @@ -967,6 +967,28 @@ def test_subfigure_double(): axsRight = subfigs[1].subplots(2, 2) +def test_subfigure_spanning(): + # test that subfigures get laid out properly... + fig = plt.figure(constrained_layout=True) + gs = fig.add_gridspec(3, 3) + sub_figs = [ + fig.add_subfigure(gs[0, 0]), + fig.add_subfigure(gs[0:2, 1]), + fig.add_subfigure(gs[2, 1:3]), + ] + + w = 640 + h = 480 + np.testing.assert_allclose(sub_figs[0].bbox.min, [0., h * 2/3]) + np.testing.assert_allclose(sub_figs[0].bbox.max, [w / 3, h]) + + np.testing.assert_allclose(sub_figs[1].bbox.min, [w / 3, h / 3]) + np.testing.assert_allclose(sub_figs[1].bbox.max, [w * 2/3, h]) + + np.testing.assert_allclose(sub_figs[2].bbox.min, [w / 3, 0]) + np.testing.assert_allclose(sub_figs[2].bbox.max, [w, h / 3]) + + def test_add_subplot_kwargs(): # fig.add_subplot() always creates new axes, even if axes kwargs differ. fig = plt.figure()