diff --git a/lib/matplotlib/figure.py b/lib/matplotlib/figure.py index 5bbb786984d2..447f508194a1 100644 --- a/lib/matplotlib/figure.py +++ b/lib/matplotlib/figure.py @@ -2426,9 +2426,12 @@ def __init__(self, if isinstance(tight_layout, dict): self.get_layout_engine().set(**tight_layout) elif constrained_layout is not None: - self.set_layout_engine(layout='constrained') if isinstance(constrained_layout, dict): + self.set_layout_engine(layout='constrained') self.get_layout_engine().set(**constrained_layout) + elif constrained_layout: + self.set_layout_engine(layout='constrained') + else: # everything is None, so use default: self.set_layout_engine(layout=layout) diff --git a/lib/matplotlib/tests/test_constrainedlayout.py b/lib/matplotlib/tests/test_constrainedlayout.py index 35eb850fcd70..64906b74c3ff 100644 --- a/lib/matplotlib/tests/test_constrainedlayout.py +++ b/lib/matplotlib/tests/test_constrainedlayout.py @@ -656,3 +656,14 @@ def test_compressed1(): pos = axs[1, 2].get_position() np.testing.assert_allclose(pos.x1, 0.8618, atol=1e-3) np.testing.assert_allclose(pos.y0, 0.1934, atol=1e-3) + + +@pytest.mark.parametrize('arg, state', [ + (True, True), + (False, False), + ({}, True), + ({'rect': None}, True) +]) +def test_set_constrained_layout(arg, state): + fig, ax = plt.subplots(constrained_layout=arg) + assert fig.get_constrained_layout() is state