Closed
Description
Bug summary
I created a plot with fig: Figure = plt.figure(layout="constrained") and added several rows of SubFigure using fig.subfigures, with a potentially unequal number of columns per row. Since the lines in every box of this parent figure share the same line styles, I added a single global legend with fig.legend. When I save the figure with fig.savefig, the legend box gets cropped off no matter how I configure it.
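The setup described above can be sketched without the reporter's data. This is a minimal sketch under stated assumptions: the curves, the "strategy N" labels, and the helper name build_figure are hypothetical, the Agg backend is used so no display is needed, and the figure is saved to an in-memory buffer instead of a file. Four series with ncol=3 give a legend with two rows, as in the report.

```python
import io

import matplotlib

matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import numpy as np


def build_figure(n_rows=2, n_cols=3, n_series=4):
    """Hypothetical helper: constrained-layout figure, one SubFigure per row, one shared legend."""
    fig = plt.figure(layout="constrained")
    fig.set_figheight(1.7 * n_rows)
    fig.supxlabel("observations")
    fig.supylabel("held-out rewards")
    # squeeze=False keeps a 2D array even when n_rows == 1; column 0 holds each row's SubFigure
    subfigs = fig.subfigures(nrows=n_rows, ncols=1, squeeze=False)[:, 0]
    x = np.linspace(500, 8000, 20)
    for subfig in subfigs:
        axs = subfig.subplots(nrows=1, ncols=n_cols, squeeze=False)[0]
        for ax in axs:
            ax.set_box_aspect(1.0)
            for i in range(n_series):
                # Synthetic curves; every axes reuses the same labels
                ax.plot(x, np.log(x) * (i + 1), label=f"strategy {i}")
    # All axes share the same labels, so one axes' handles are enough here
    handles, labels = axs[0].get_legend_handles_labels()
    legend = fig.legend(handles, labels, loc="lower center",
                        borderaxespad=-1.3, frameon=True, ncol=3)
    return fig, legend


fig, legend = build_figure()
buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
png_bytes = buf.getvalue()  # inspect or write out to check whether the legend's second row survives
plt.close(fig)
```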
Code for reproduction
import math
import os
from typing import Sequence

import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.axes import Axes
from matplotlib.figure import Figure, SubFigure
from matplotlib.ticker import FormatStrFormatter

# results, N_COLS, STRAT_EST_TO_PLOT_KWARGS, cfg_p, out_p and alpha are
# defined elsewhere in the original script.


def plot_reward():
    fig: Figure = plt.figure(layout="constrained")
    n_rows: int = math.ceil(len(results) / N_COLS)
    fig.set_figheight(1.7 * n_rows)
    fig.supxlabel("observations")
    fig.supylabel("held-out rewards")
    # One SubFigure per row; squeeze=False keeps this iterable even when n_rows == 1
    subfigs: Sequence[SubFigure] = fig.subfigures(nrows=n_rows, ncols=1, squeeze=False)[:, 0]  # type:ignore
    label_set = list()
    line_set = list()
    results_l = [(k, v) for k, v in results.items()]
    end_idx: int = 0
    for _fig in subfigs:
        # Each row takes the next N_COLS results; the last row may have fewer
        start_idx: int = end_idx
        end_idx: int = min(end_idx + N_COLS, len(results_l))
        _results = results_l[start_idx:end_idx]
        axs: Sequence[Axes] = _fig.subplots(
            nrows=1, ncols=len(_results), squeeze=False
        )[0]
        for ax, (name, strat_est_to_metrics_d) in zip(axs, _results):
            ax.set_title(name)
            ax.set_box_aspect(1.0)
            ax.set_xlim(0.0, 8500)
            ax.set_xticks([500, 4000, 8000])
            ax.yaxis.set_major_formatter(FormatStrFormatter("%.2f"))
            for strat_est, metrics in strat_est_to_metrics_d.items():
                if len(metrics) > 1:
                    pass
                else:
                    _metrics_d: pd.DataFrame = metrics[0][
                        ["val/reward", "train/n_obs"]
                    ].dropna()
                    ax.plot(
                        _metrics_d["train/n_obs"],
                        _metrics_d["val/reward"],
                        **STRAT_EST_TO_PLOT_KWARGS[strat_est],
                    )
            # Collect each label once for the shared figure legend
            _lines, _labels = ax.get_legend_handles_labels()
            for _line, _label in zip(_lines, _labels):
                if _label not in label_set:
                    line_set.append(_line)
                    label_set.append(_label)
    fig.legend(
        line_set,
        label_set,
        loc="lower center",
        borderaxespad=-1.3,
        # bbox_to_anchor=(),
        frameon=True,
        ncol=3,
    )
    print(fig.get_default_bbox_extra_artists())
    fn: str = os.path.splitext(cfg_p.split("/")[-1])[0]
    fig.savefig(
        os.path.join(out_p, f"{fn}_alpha{alpha}_reward.png"),
        dpi=720,
        bbox_extra_artists=fig.get_default_bbox_extra_artists(),
        bbox_inches="tight",
    )
    fig.savefig(
        os.path.join(out_p, f"{fn}_alpha{alpha}_reward.svg"),
        bbox_extra_artists=fig.get_default_bbox_extra_artists(),
        bbox_inches="tight",
    )
    plt.show()
    plt.close(fig)


plot_reward()
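The row-splitting and legend-deduplication steps in the reproduction do not depend on matplotlib. The sketch below isolates them with plain lists so the "unequal number of columns" behavior is easy to see; split_rows and dedupe_legend_entries are hypothetical helper names, and strings stand in for line handles.

```python
import math
from typing import List, Sequence, Tuple, TypeVar

T = TypeVar("T")


def split_rows(items: Sequence[T], n_cols: int) -> List[List[T]]:
    """Split items into ceil(len(items) / n_cols) rows; only the last row may be shorter."""
    n_rows = math.ceil(len(items) / n_cols)
    return [list(items[r * n_cols:(r + 1) * n_cols]) for r in range(n_rows)]


def dedupe_legend_entries(pairs: Sequence[Tuple[T, str]]) -> Tuple[List[T], List[str]]:
    """Keep the first handle seen for each label, preserving first-seen order."""
    handles: List[T] = []
    labels: List[str] = []
    for handle, label in pairs:
        if label not in labels:
            handles.append(handle)
            labels.append(label)
    return handles, labels


rows = split_rows(list("abcde"), 3)
entries = dedupe_legend_entries([("h1", "A"), ("h2", "B"), ("h3", "A")])
```

With five results and three columns, the second SubFigure row gets only two axes, which is why each row calls subplots with ncols=len(_results).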
Actual outcome
Expected outcome
Both rows of the legend box should be saved, but the second row is always cropped off.
Additional information
The result of print(fig.get_default_bbox_extra_artists()) seems to contain the legend box, so for some reason bbox_inches='tight' isn't taking it into account properly:
[Text(0.5, 0.01, 'observations'), Text(0.02, 0.5, 'held-out rewards'), <matplotlib.legend.Legend object at 0x78abd7a28df0>, <matplotlib.figure.SubFigure object at 0x78abd7a28550>, <matplotlib.figure.SubFigure object at 0x78abd7a2bdf0>, <matplotlib.spines.Spine object at 0x78abd7a2ad10>, <matplotlib.spines.Spine object at 0x78abd7a28970>, <matplotlib.spines.Spine object at 0x78abd7a285e0>, <matplotlib.spines.Spine object at 0x78abd7a29de0>, <matplotlib.patches.Rectangle object at 0x78abcc2b4fa0>, <matplotlib.spines.Spine object at 0x78abcc2b4d90>, <matplotlib.spines.Spine object at 0x78abcc2b4940>, <matplotlib.spines.Spine object at 0x78abcc2b4730>, <matplotlib.spines.Spine object at 0x78abcc2b45e0>, <matplotlib.patches.Rectangle object at 0x78abd7fc9c90>, <matplotlib.spines.Spine object at 0x78abd7fc80d0>, <matplotlib.spines.Spine object at 0x78abd7fca170>, <matplotlib.spines.Spine object at 0x78abd7fc95a0>, <matplotlib.spines.Spine object at 0x78abd7fc9690>, <matplotlib.patches.Rectangle object at 0x78abd7a748b0>, <matplotlib.spines.Spine object at 0x78abccd386d0>, <matplotlib.spines.Spine object at 0x78abccd3b3a0>, <matplotlib.spines.Spine object at 0x78abccd39d50>, <matplotlib.spines.Spine object at 0x78abccd3b430>, <matplotlib.patches.Rectangle object at 0x78abcc231bd0>, <matplotlib.spines.Spine object at 0x78abcc230b20>, <matplotlib.spines.Spine object at 0x78abcc230400>, <matplotlib.spines.Spine object at 0x78abcc2310c0>, <matplotlib.spines.Spine object at 0x78abcc230f40>, <matplotlib.patches.Rectangle object at 0x78abccd730d0>, <matplotlib.spines.Spine object at 0x78abcc3b2710>, <matplotlib.spines.Spine object at 0x78abcc3b27a0>, <matplotlib.spines.Spine object at 0x78abcc3b2230>, <matplotlib.spines.Spine object at 0x78abcc3b3e50>, <matplotlib.patches.Rectangle object at 0x78abcc3b1180>]
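Beyond listing the extra artists, one way to test whether bbox_inches='tight' accounts for the legend is to compare the legend's rendered extent with the figure's tight bounding box, which is the box savefig crops to before adding pad_inches. The sketch below assumes the Agg backend and a small hypothetical figure; legend_fits is a helper introduced here, not a matplotlib API. Both boxes are compared in inches, with a small tolerance for floating-point rounding.

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt


def legend_fits(fig, legend, tol=1e-6):
    """Hypothetical helper: does the legend's extent lie inside fig's tight bbox?"""
    fig.canvas.draw()  # run constrained layout so artist positions are final
    renderer = fig.canvas.get_renderer()
    tight = fig.get_tightbbox(renderer)  # in inches; uses the default extra artists
    # The legend extent is in display pixels; convert it to inches to match `tight`
    leg = legend.get_window_extent(renderer).transformed(fig.dpi_scale_trans.inverted())
    fits = (tight.x0 - tol <= leg.x0 and leg.x1 <= tight.x1 + tol
            and tight.y0 - tol <= leg.y0 and leg.y1 <= tight.y1 + tol)
    return bool(fits), tight, leg


fig = plt.figure(layout="constrained")
ax = fig.subplots()
for i in range(4):
    ax.plot([0, 1], [0, i], label=f"line {i}")
legend = fig.legend(loc="lower center", borderaxespad=-1.3, ncol=3)
fits, tight, leg = legend_fits(fig, legend)
print(f"legend fits: {fits}; tight bbox (in): {tight.bounds}; legend (in): {leg.bounds}")
plt.close(fig)
```

If fits is True while the saved file is still cropped, the mismatch would lie outside the tight-bbox computation (for example in how the legend is re-laid out at the save dpi); if it is False, the tight bbox itself is missing part of the legend.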
Operating system
Linux Mint 22
Matplotlib Version
3.10.0
Matplotlib Backend
print(matplotlib.get_backend())
Python version
3.10.14
Jupyter version
No response
Installation
pip