
[Bug]: Having issues getting Figure.legend to work with constrained layout and SubFigure #29418


Closed
busFred opened this issue Jan 7, 2025 · 12 comments
Labels
Community support Users in need of help.

Comments

@busFred

busFred commented Jan 7, 2025

Bug summary

I created a plot with fig: Figure = plt.figure(layout="constrained") and added several rows of SubFigures, each potentially with a different number of columns, using fig.subfigures. Since all the Axes in this parent figure have lines sharing the same line styles, I added a single global legend with fig.legend. When I save the figure with fig.savefig, the legend just gets cropped off, regardless of how I set things.

Code for reproduction

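import math
import os
from typing import Sequence

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from matplotlib.ticker import FormatStrFormatter

# results, N_COLS, STRAT_EST_TO_PLOT_KWARGS, cfg_p, out_p and alpha are defined
# elsewhere in the original script (not shown in this report).


def plot_reward():
    # A single constrained-layout figure (a second, unused plt.figure() call was removed).
    fig: Figure = plt.figure(layout="constrained")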
    n_rows: int = math.ceil(len(results) / N_COLS)
    fig.set_figheight(1.7 * n_rows)
    fig.supxlabel("observations")
    fig.supylabel("held-out rewards")
    subfigs: Sequence[SubFigure] = fig.subfigures(nrows=n_rows, ncols=1)  # type:ignore
    label_set = list()
    line_set = list()
    results_l = [(k, v) for k, v in results.items()]
    end_idx: int = 0
    for _fig in subfigs:
        start_idx: int = end_idx
        end_idx: int = min(end_idx + N_COLS, len(results_l))
        _results = results_l[start_idx:end_idx]
        axs: Sequence[Axes] = _fig.subplots(
            nrows=1, ncols=len(_results), squeeze=False
        )[0]
        for ax, (name, strat_est_to_metrics_d) in zip(axs, _results):
            ax.set_title(name)
            ax.set_box_aspect(1.0)
            ax.set_xlim(0.0, 8500)
            ax.set_xticks([500, 4000, 8000])
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            for strat_est, metrics in strat_est_to_metrics_d.items():
                if len(metrics) > 1:
                    pass
                else:
                    _metrics_d: pd.DataFrame = metrics[0][
                        ["val/reward", "train/n_obs"]
                    ].dropna()
                    ax.plot(
                        _metrics_d["train/n_obs"],
                        _metrics_d["val/reward"],
                        **STRAT_EST_TO_PLOT_KWARGS[strat_est],
                    )
            _lines, _labels = ax.get_legend_handles_labels()
            for _line, _label in zip(_lines, _labels):
                if _label not in label_set:
                    line_set.append(_line)
                    label_set.append(_label)
    fig.legend(
        line_set,
        label_set,
        loc="lower center",
        borderaxespad=-1.3,
        # bbox_to_anchor=(),
        frameon=True,
        ncol=3,
    )
    print(fig.get_default_bbox_extra_artists())
    fn: str = os.path.splitext(cfg_p.split("/")[-1])[0]
    plt.savefig(
        os.path.join(out_p, f"{fn}_alpha{alpha}_reward.png"),
        dpi=720,
        bbox_extra_artists=fig.get_default_bbox_extra_artists(),
        bbox_inches="tight",
    )
    plt.savefig(
        os.path.join(out_p, f"{fn}_alpha{alpha}_reward.svg"),
        bbox_extra_artists=fig.get_default_bbox_extra_artists(),
        bbox_inches="tight",
    )
    plt.show()
    plt.close()


plot_reward()

Actual outcome

[Image: lrc_alpha0 9_reward]

Expected outcome

Both rows of the legend should be saved, but the second row is always cropped off.

Additional information

The output of print(fig.get_default_bbox_extra_artists()) (below) does seem to contain the legend. For whatever reason, bbox_inches='tight' isn't taking it into account properly.

[Text(0.5, 0.01, 'observations'), Text(0.02, 0.5, 'held-out rewards'), <matplotlib.legend.Legend object at 0x78abd7a28df0>, <matplotlib.figure.SubFigure object at 0x78abd7a28550>, <matplotlib.figure.SubFigure object at 0x78abd7a2bdf0>, <matplotlib.spines.Spine object at 0x78abd7a2ad10>, <matplotlib.spines.Spine object at 0x78abd7a28970>, <matplotlib.spines.Spine object at 0x78abd7a285e0>, <matplotlib.spines.Spine object at 0x78abd7a29de0>, <matplotlib.patches.Rectangle object at 0x78abcc2b4fa0>, <matplotlib.spines.Spine object at 0x78abcc2b4d90>, <matplotlib.spines.Spine object at 0x78abcc2b4940>, <matplotlib.spines.Spine object at 0x78abcc2b4730>, <matplotlib.spines.Spine object at 0x78abcc2b45e0>, <matplotlib.patches.Rectangle object at 0x78abd7fc9c90>, <matplotlib.spines.Spine object at 0x78abd7fc80d0>, <matplotlib.spines.Spine object at 0x78abd7fca170>, <matplotlib.spines.Spine object at 0x78abd7fc95a0>, <matplotlib.spines.Spine object at 0x78abd7fc9690>, <matplotlib.patches.Rectangle object at 0x78abd7a748b0>, <matplotlib.spines.Spine object at 0x78abccd386d0>, <matplotlib.spines.Spine object at 0x78abccd3b3a0>, <matplotlib.spines.Spine object at 0x78abccd39d50>, <matplotlib.spines.Spine object at 0x78abccd3b430>, <matplotlib.patches.Rectangle object at 0x78abcc231bd0>, <matplotlib.spines.Spine object at 0x78abcc230b20>, <matplotlib.spines.Spine object at 0x78abcc230400>, <matplotlib.spines.Spine object at 0x78abcc2310c0>, <matplotlib.spines.Spine object at 0x78abcc230f40>, <matplotlib.patches.Rectangle object at 0x78abccd730d0>, <matplotlib.spines.Spine object at 0x78abcc3b2710>, <matplotlib.spines.Spine object at 0x78abcc3b27a0>, <matplotlib.spines.Spine object at 0x78abcc3b2230>, <matplotlib.spines.Spine object at 0x78abcc3b3e50>, <matplotlib.patches.Rectangle object at 0x78abcc3b1180>]
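One way to check this kind of claim programmatically is to compare the legend's drawn extent with fig.get_tightbbox(), which is what savefig uses (plus pad_inches) when bbox_inches='tight'. The sketch below uses a hypothetical helper, legend_within_tight_bbox (not a Matplotlib API), on a simple stand-in figure rather than the reporter's data; running the same check on a misbehaving figure would show whether the legend actually falls inside the computed bounding box.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def legend_within_tight_bbox(fig, legend, tol=0.01):
    """Hypothetical helper: True if `legend` lies inside fig.get_tightbbox().

    Both boxes are compared in inches; `tol` allows for rounding.
    """
    fig.canvas.draw()  # make sure layout and legend placement are final
    tight = fig.get_tightbbox()  # union of artist extents, in inches
    # Legend extent is in display pixels; convert it to inches.
    leg = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
    return (
        leg.x0 >= tight.x0 - tol
        and leg.y0 >= tight.y0 - tol
        and leg.x1 <= tight.x1 + tol
        and leg.y1 <= tight.y1 + tol
    )


fig, ax = plt.subplots()
x = np.linspace(0, 2 * np.pi, 50)
ax.plot(x, np.sin(x), label="sin")
lg = fig.legend(loc="lower center")

print(lg in fig.get_default_bbox_extra_artists())
print(legend_within_tight_bbox(fig, lg))
```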

Operating system

Linux Mint 22

Matplotlib Version

3.10.0

Matplotlib Backend

print(matplotlib.get_backend())

Python version

3.10.14

Jupyter version

No response

Installation

pip

@busFred
Author

busFred commented Jan 7, 2025

Potentially related to #20736, but I don't see how that was fixed.

@jklymak
Member

jklymak commented Jan 7, 2025

Can you make a minimal, self-contained example? I'm also not clear on why you are using subfigures for this layout instead of just two rows of subplots.
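For reference, one possible way to get jagged rows without subfigures (not what the reporter used) is a single GridSpec with two grid columns per panel, so a short last row can be shifted by half a panel to center it. A sketch, assuming 5 panels with at most 3 per row:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

n_plots, n_cols = 5, 3  # assumed: 5 panels, at most 3 per row
n_rows = -(-n_plots // n_cols)  # ceiling division -> 2

fig = plt.figure(layout="constrained")
# Two grid columns per panel, so a short row can be offset by half a panel.
gs = fig.add_gridspec(n_rows, 2 * n_cols)
axes = []
for i in range(n_plots):
    row, col = divmod(i, n_cols)
    in_row = min(n_cols, n_plots - row * n_cols)  # panels in this row
    start = (n_cols - in_row) + 2 * col  # left offset centers a short row
    axes.append(fig.add_subplot(gs[row, start:start + 2]))

x = np.linspace(0, 2 * np.pi, 50)
for k, ax in enumerate(axes):
    ax.plot(x, np.sin(x + k))
    ax.set_title(f"subplot {k}")
```

Here the last row's two panels occupy grid columns 1-2 and 3-4, leaving one empty half-slot on each side.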

@rcomer rcomer added the status: needs clarification Issues that need more information to resolve. label Jan 7, 2025
@busFred
Author

busFred commented Jan 9, 2025

@jklymak The attached file contains the Python script plus the pickled Python dictionary and pandas DataFrame required to generate the plot. I used SubFigure to accommodate jagged subplots, i.e. the first few rows all have 3 columns, while the last row has only one or two columns that should be evenly spaced and centered. Here, the legend should have two rows, but the second row is cropped off.

[Image: alpha0 0_ts_accuracy]

matplotlib_demo.tar.gz
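The jagged-row bookkeeping in the reporter's code (the start_idx/end_idx slicing) amounts to chunking the results into rows of at most N_COLS entries. A minimal sketch with a hypothetical helper, split_into_rows; each chunk would then get its own subfigure and its own subplots(ncols=len(chunk)) call:

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def split_into_rows(items: Sequence[T], n_cols: int) -> List[Sequence[T]]:
    """Hypothetical helper: chunk items into consecutive rows of at most
    n_cols entries; only the last row can be shorter."""
    if n_cols < 1:
        raise ValueError("n_cols must be at least 1")
    return [items[i:i + n_cols] for i in range(0, len(items), n_cols)]


names = [f"subplot {i}" for i in range(5)]
rows = split_into_rows(names, 3)
print([len(row) for row in rows])  # [3, 2]
```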

@tacaswell
Member

I hope the exact data does not matter? I (and, I expect, a fair number of other mpl devs) am not comfortable loading pickle files from an issue (see the warning at https://docs.python.org/3/library/pickle.html. TL;DR: using pickle to talk between two processes you control is fine; loading a pickle written by a process you do not control is not).
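If the data really are plain numeric arrays, one pickle-free way to share them is NumPy's .npz format, loaded with allow_pickle=False (the default), which refuses to unpickle object arrays. A sketch with stand-in arrays; the names n_obs and reward are illustrative, not taken from the attached archive:

```python
import io

import numpy as np

# Stand-in data; the real arrays in the issue are not known.
n_obs = np.arange(0, 9000, 1000)
reward = np.random.default_rng(0).standard_normal(9)

buf = io.BytesIO()  # a real file path works the same way
np.savez(buf, n_obs=n_obs, reward=reward)
buf.seek(0)

# allow_pickle=False makes loading fail instead of unpickling object arrays.
with np.load(buf, allow_pickle=False) as data:
    loaded_n_obs = data["n_obs"]
    loaded_reward = data["reward"]
```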

Can you please reduce this to a minimal example (plotting random numbers or trig functions should show the effect) that we can copy-paste-run with no changes?

I note that you are using bbox_inches='tight', which tries to "shrink-wrap" the saved figure, so I suspect that constrained_layout is a red herring.
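To see the shrink-wrapping concretely, one can save the same figure with and without bbox_inches='tight' and compare the pixel dimensions of the resulting PNGs. The sketch below reads width and height from the PNG's IHDR header; png_size is a hypothetical helper, and the figure is a stand-in, not the reporter's:

```python
import io
import struct

import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def png_size(fig, **savefig_kwargs):
    """Hypothetical helper: save `fig` to an in-memory PNG and return
    its (width, height) in pixels."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", **savefig_kwargs)
    # PNG = 8-byte signature, then the IHDR chunk (4-byte length, 4-byte type)
    # whose data starts with big-endian uint32 width and height: bytes 16..24.
    width, height = struct.unpack(">II", buf.getvalue()[16:24])
    return width, height


fig, ax = plt.subplots(figsize=(6.4, 4.8))
x = np.linspace(0, 2 * np.pi, 50)
ax.plot(x, np.cos(x), label="cos")
fig.legend(loc="lower center")

print(png_size(fig, dpi=100))  # full canvas: 6.4 in x 4.8 in at 100 dpi
print(png_size(fig, dpi=100, bbox_inches="tight"))  # cropped to the artists
```

With bbox_inches='tight', the output size is set by the artists' extents (plus pad_inches) rather than by the figure size, so anything the tight bounding box misses is cut off.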

@jklymak
Member

jklymak commented Jan 9, 2025

Yes, to add to @tacaswell's advice, you can probably save us quite a bit of time by minimizing the example as much as possible, e.g. removing everything except exactly what causes the problem. Please see http://www.sscce.org for why this is worthwhile.

@busFred
Author

busFred commented Jan 10, 2025

Code that uses randomly generated data, does not load any pickle, and can be run immediately after copy-pasting is at the end of this comment.

Based on suggestions from @tacaswell, here are the possible combinations of figure layout and savefig bbox. It seems that constrained layout in the figure plays a significant role in keeping the subfigures from overlapping each other, while bbox_inches="tight" during savefig seems crucial for keeping fig.supxlabel and the legend from overlapping each other.

  • plt.figure(layout='constrained') + fig.savefig()
    [Image: alpha0 0_ts_reward_constrained]
  • plt.figure(layout='constrained') + fig.savefig(bbox_inches="tight")
    [Image: alpha0 0_ts_reward_constrained_tight]
  • plt.figure() + fig.savefig()
    [Image: alpha0 0_ts_reward]
  • plt.figure() + fig.savefig(bbox_inches="tight")
    [Image: alpha0 0_ts_reward_tight]
# %%
import math
import os
from typing import Sequence

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from matplotlib.ticker import FormatStrFormatter

# %%
alpha: float = 0.0
N_COLS: int = 3
SIMPLE_STRAT = "ts"


# %%
out_p = os.path.join("outputs")
os.makedirs(out_p, exist_ok=True)


# %%
STRATS = ["mtspm", "ts", "ei", "revi", "random", "modiste", "spanner"]
ESTIMATORS = ["xgb", "structure", "knn", "uknn", ""]


# %%
STRAT_EST_TO_PLOT_KWARGS = {
    f"{SIMPLE_STRAT}_xgb": {
        "label": "XGB-Bootstrap",
        "color": "blue",
        "marker": "x",
    },
    f"{SIMPLE_STRAT}_structure": {
        "label": "Mimic-Bootstrap",
        "color": "red",
        "marker": "v",
    },
    "spanner": {
        "label": "SpannerIGW",
        "color": "gold",
        "marker": "h",
    },
    "modiste_knn": {
        "label": "Modiste-KNN",
        "color": "darkgreen",
        "marker": "+",
    },
    "modiste_uknn": {
        "label": "Modiste-UKNN",
        "color": "lime",
        "marker": "o",
    },
}

# %%
results = {
    f"subplot {idx}": {
        "ts_xgb": [
            pd.DataFrame(
                {
                    "train/n_obs": np.arange(0, 9000, 1000),
                    "val/reward": np.random.randn(9),
                }
            )
        ],
        "ts_structure": [
            pd.DataFrame(
                {
                    "train/n_obs": np.arange(0, 9000, 1000),
                    "val/reward": np.random.randn(9),
                }
            )
        ],
        "spanner": [
            pd.DataFrame(
                {
                    "train/n_obs": np.arange(0, 9000, 1000),
                    "val/reward": np.random.randn(9),
                }
            )
        ],
        "modiste_knn": [
            pd.DataFrame(
                {
                    "train/n_obs": np.arange(0, 9000, 1000),
                    "val/reward": np.random.randn(9),
                }
            )
        ],
        "modiste_uknn": [
            pd.DataFrame(
                {
                    "train/n_obs": np.arange(0, 9000, 1000),
                    "val/reward": np.random.randn(9),
                }
            )
        ],
    }
    for idx in range(5)
}


# %%
def make_plot(key: str, supylabel: str, fn_post: str):
    fig: Figure = plt.figure(layout="constrained")
    n_rows: int = math.ceil(len(results) / N_COLS)
    fig.set_figheight(1.7 * n_rows)
    xl = fig.supxlabel("observations")
    yl = fig.supylabel(supylabel)
    subfigs: Sequence[SubFigure] = fig.subfigures(nrows=n_rows, ncols=1)  # type:ignore
    label_set = list()
    line_set = list()
    results_l = [(k, v) for k, v in results.items()]
    end_idx: int = 0
    for _fig in subfigs:
        start_idx: int = end_idx
        end_idx: int = min(end_idx + N_COLS, len(results_l))
        _results = results_l[start_idx:end_idx]
        axs: Sequence[Axes] = _fig.subplots(
            nrows=1, ncols=len(_results), squeeze=False
        )[0]
        for ax, (name, strat_est_to_metrics_d) in zip(axs, _results):
            ax.set_title(name)
            ax.set_box_aspect(1.0)
            ax.set_xlim(0.0, 8500)
            ax.set_xticks([500, 4000, 8000])
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            for strat_est, metrics in strat_est_to_metrics_d.items():
                _metrics_d: pd.DataFrame = metrics[0][[key, "train/n_obs"]].dropna()
                ax.plot(
                    _metrics_d["train/n_obs"],
                    _metrics_d[key],
                    **STRAT_EST_TO_PLOT_KWARGS[strat_est],
                )
            _lines, _labels = ax.get_legend_handles_labels()
            for _line, _label in zip(_lines, _labels):
                if _label not in label_set:
                    line_set.append(_line)
                    label_set.append(_label)
    lg = fig.legend(
        line_set,
        label_set,
        loc="lower center",
        borderaxespad=-1.3,
        # bbox_to_anchor=(),
        frameon=True,
        ncol=3,
    )
    plt.savefig(
        os.path.join(
            out_p, f"alpha{alpha}_{SIMPLE_STRAT}_{fn_post}_constrained_tight.png"
        ),
        dpi=720,
        bbox_extra_artists=[xl, yl, lg],
        bbox_inches="tight",
    )
    plt.savefig(
        os.path.join(out_p, f"alpha{alpha}_{SIMPLE_STRAT}_{fn_post}_constrained.png"),
        bbox_extra_artists=[xl, yl, lg],
    )
    plt.show()
    plt.close()


make_plot("val/reward", "held-out reward", "reward")


# %%
def make_plot_not_constrained(key: str, supylabel: str, fn_post: str):
    fig: Figure = plt.figure()
    n_rows: int = math.ceil(len(results) / N_COLS)
    fig.set_figheight(1.7 * n_rows)
    xl = fig.supxlabel("observations")
    yl = fig.supylabel(supylabel)
    subfigs: Sequence[SubFigure] = fig.subfigures(nrows=n_rows, ncols=1)  # type:ignore
    label_set = list()
    line_set = list()
    results_l = [(k, v) for k, v in results.items()]
    end_idx: int = 0
    for _fig in subfigs:
        start_idx: int = end_idx
        end_idx: int = min(end_idx + N_COLS, len(results_l))
        _results = results_l[start_idx:end_idx]
        axs: Sequence[Axes] = _fig.subplots(
            nrows=1, ncols=len(_results), squeeze=False
        )[0]
        for ax, (name, strat_est_to_metrics_d) in zip(axs, _results):
            ax.set_title(name)
            ax.set_box_aspect(1.0)
            ax.set_xlim(0.0, 8500)
            ax.set_xticks([500, 4000, 8000])
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            for strat_est, metrics in strat_est_to_metrics_d.items():
                _metrics_d: pd.DataFrame = metrics[0][[key, "train/n_obs"]].dropna()
                ax.plot(
                    _metrics_d["train/n_obs"],
                    _metrics_d[key],
                    **STRAT_EST_TO_PLOT_KWARGS[strat_est],
                )
            _lines, _labels = ax.get_legend_handles_labels()
            for _line, _label in zip(_lines, _labels):
                if _label not in label_set:
                    line_set.append(_line)
                    label_set.append(_label)
    lg = fig.legend(
        line_set,
        label_set,
        loc="lower center",
        borderaxespad=-1.3,
        # bbox_to_anchor=(),
        frameon=True,
        ncol=3,
    )
    plt.savefig(
        os.path.join(out_p, f"alpha{alpha}_{SIMPLE_STRAT}_{fn_post}_tight.png"),
        dpi=720,
        bbox_extra_artists=[xl, yl, lg],
        bbox_inches="tight",
    )
    plt.savefig(
        os.path.join(out_p, f"alpha{alpha}_{SIMPLE_STRAT}_{fn_post}.png"),
        bbox_extra_artists=[xl, yl, lg],
    )
    plt.show()
    plt.close()


make_plot_not_constrained("val/reward", "held-out reward", "reward")

# %%

@jklymak
Member

jklymak commented Jan 10, 2025

fig.legend doesn't participate in constrained_layout at all, so this is basically working as expected. This can be trivially reproduced with

import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots(layout='constrained')
# fig.get_layout_engine().set(rect=[0, 0.1, 1, 0.9])
ax.plot(np.random.rand(10), label='boo')
fig.legend(loc='lower center')

plt.show()

If you want to save some space at the bottom for a legend, you can uncomment the fig.get_layout_engine().set(...) line and manually adjust the 0.1 and 0.9.
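Rather than hand-tuning 0.1 and 0.9, one could measure the legend first and derive the rect from its height. A sketch, assuming a single figure legend at 'lower center'; reserve_space_for_legend is a hypothetical helper, not part of Matplotlib, and the legend_pad value is arbitrary:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def reserve_space_for_legend(fig, legend, legend_pad=0.01):
    """Hypothetical helper: shrink the constrained-layout rect so the Axes
    sit above a legend placed at 'lower center'. Returns the new bottom."""
    fig.canvas.draw()  # place the legend so its extent is known
    renderer = fig.canvas.get_renderer()
    to_fig = fig.transFigure.inverted()  # display pixels -> figure fraction
    top = legend.get_window_extent(renderer).transformed(to_fig).y1
    bottom = top + legend_pad
    # rect is (left, bottom, width, height) in figure fractions.
    fig.get_layout_engine().set(rect=[0, bottom, 1, 1 - bottom])
    fig.canvas.draw()  # re-run constrained layout inside the new rect
    return bottom


rng = np.random.default_rng(0)
fig, ax = plt.subplots(layout="constrained")
ax.plot(rng.random(10), label="boo")
lg = fig.legend(loc="lower center")
bottom = reserve_space_for_legend(fig, lg)
```

Because a figure legend without the 'outside' prefix is not moved by constrained layout, the measured legend position stays valid after the second draw.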

Automagically taking fig.legend into account for constrained_layout would be pretty complicated, and while I wouldn't object if someone wanted to do so in a clean way, I don't think it would be good cost/benefit.

@jakelevi1996

jakelevi1996 commented Jan 23, 2025

@jklymak I seem to be able to fix your last example simply by replacing loc='lower center' with loc='outside lower center' !?

import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(0)

fig, ax = plt.subplots(layout='constrained')
ax.plot(rng.random(10), label='boo')
fig.legend(loc='outside lower center')

fig.savefig(".temp.png")

[Image: the resulting .temp.png]

Version information: Python 3.10.12, matplotlib 3.9.2, Ubuntu 22.04.1 LTS
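To confirm the difference without eyeballing images, one can compare the legend's extent with the Axes' tight bounding box (which includes tick labels) after a draw, for both loc values. A sketch, where legend_clear_of_axes is a hypothetical helper and the 1-pixel tolerance is an arbitrary allowance for rounding:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def legend_clear_of_axes(fig, ax, legend, tol=1.0):
    """Hypothetical check: True if the legend is on the canvas and entirely
    below the Axes and its decorations (all extents in display pixels)."""
    fig.canvas.draw()  # constrained layout and legend placement happen at draw time
    renderer = fig.canvas.get_renderer()
    leg = legend.get_window_extent(renderer)
    ax_box = ax.get_tightbbox(renderer)
    on_canvas = leg.y0 >= fig.bbox.y0 - tol
    return on_canvas and leg.y1 <= ax_box.y0 + tol


rng = np.random.default_rng(0)
clear = {}
for loc in ("lower center", "outside lower center"):
    fig, ax = plt.subplots(layout="constrained")
    ax.plot(rng.random(10), label="boo")
    lg = fig.legend(loc=loc)
    clear[loc] = legend_clear_of_axes(fig, ax, lg)
    plt.close(fig)

print(clear)
```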

@jakelevi1996

@busFred TLDR: do use loc='outside lower center', don't use borderaxespad=-1.3.


Here is an example including "jagged subplots" (which you said was your use case), which works fine for me:

import numpy as np
import matplotlib.pyplot as plt
rng = np.random.default_rng(0)

fig = plt.figure(layout='constrained')
subfigs = fig.subfigures(nrows=2, ncols=1)
ax1 = subfigs[0].subplots(nrows=1, ncols=3)
ax2 = subfigs[1].subplots(nrows=1, ncols=2)

ax1[0].plot(rng.random(10), label='boo')
ax2[1].plot(rng.random(10), label='boo 2', c="r")

fig.legend(loc='outside lower center', ncols=2)
# fig.legend(loc='outside lower center', ncols=2, borderaxespad=-1.3)

fig.savefig(".temp.png")

[Image: the resulting .temp.png]

If I include the borderaxespad=-1.3 keyword argument to fig.legend that you included in your original example, then the legend moves partly out of the canvas:

[Image: the same figure with the legend partly outside the canvas]
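A quick way to verify the effect of a negative borderaxespad is to check whether the legend's drawn extent stays within the figure canvas. In the sketch below, legend_on_canvas is a hypothetical helper; 0.5 is Matplotlib's default legend.borderaxespad, and -1.3 is the value from the original report:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np


def legend_on_canvas(fig, legend):
    """Hypothetical check: True if the legend's drawn extent lies entirely
    within the figure canvas (display pixels)."""
    fig.canvas.draw()
    leg = legend.get_window_extent(fig.canvas.get_renderer())
    canvas = fig.bbox
    return (canvas.x0 <= leg.x0 and leg.x1 <= canvas.x1
            and canvas.y0 <= leg.y0 and leg.y1 <= canvas.y1)


rng = np.random.default_rng(0)
on_canvas = {}
for pad in (0.5, -1.3):  # rcParams default vs. the value from the report
    fig, ax = plt.subplots(layout="constrained")
    ax.plot(rng.random(10), label="boo")
    lg = fig.legend(loc="outside lower center", borderaxespad=pad)
    on_canvas[pad] = legend_on_canvas(fig, lg)
    plt.close(fig)

print(on_canvas)
```

borderaxespad is measured in font-size units from the figure edge, so a negative value anchors the legend below the canvas, where no layout setting can bring it back.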

@jakelevi1996

This is documented in the Figure legends section of the Legend guide:

By using constrained layout and specifying "outside" at the beginning of the loc keyword argument, the legend is drawn outside the Axes on the (sub)figure.

See also matplotlib.figure.Figure.legend:

If a figure is using the constrained layout manager, the string codes of the loc keyword argument can get better layout behaviour using the prefix 'outside'. There is ambiguity at the corners, so 'outside upper right' will make space for the legend above the rest of the axes in the layout, and 'outside right upper' will make space on the right side of the layout. In addition to the values of loc listed above, we have 'outside right upper', 'outside right lower', 'outside left upper', and 'outside left lower'. See Legend guide for more details.
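The corner ambiguity described above can be checked directly: with 'outside right upper' the legend should end up to the right of the Axes, and with 'outside upper right' above them. A sketch, where legend_and_axes_extents is a hypothetical helper and the 1-pixel tolerance is arbitrary:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs anywhere
import matplotlib.pyplot as plt
import numpy as np

TOL = 1.0  # pixels, allowance for rounding


def legend_and_axes_extents(loc):
    """Hypothetical helper: draw a one-Axes constrained figure with a figure
    legend at `loc`; return (legend_bbox, axes_tight_bbox) in display pixels."""
    fig, ax = plt.subplots(layout="constrained")
    ax.plot(np.linspace(0, 1, 10), label="line")
    legend = fig.legend(loc=loc)
    fig.canvas.draw()
    renderer = fig.canvas.get_renderer()
    extents = legend.get_window_extent(renderer), ax.get_tightbbox(renderer)
    plt.close(fig)
    return extents


leg, ax_box = legend_and_axes_extents("outside right upper")
print("right of the Axes:", leg.x0 >= ax_box.x1 - TOL)

leg, ax_box = legend_and_axes_extents("outside upper right")
print("above the Axes:", leg.y0 >= ax_box.y1 - TOL)
```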

@jklymak
Member

jklymak commented Jan 23, 2025

Ooops, yes I forgot about that! Thanks @jakelevi1996 !

@rcomer
Member

rcomer commented May 12, 2025

This one has been quiet for a while, but it looks like the solution was to use loc='outside lower center', so closing here.

@rcomer rcomer closed this as completed May 12, 2025
@rcomer rcomer added Community support Users in need of help. and removed status: needs clarification Issues that need more information to resolve. labels May 12, 2025