Skip to content

In imsave()'s Pillow-handled case, don't create a temporary figure. #13263

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions lib/matplotlib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from math import ceil
import os
import logging
from pathlib import Path
import urllib.parse
import urllib.request

Expand Down Expand Up @@ -1432,24 +1433,48 @@ def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None,
The DPI to store in the metadata of the file. This does not affect the
resolution of the output image.
"""
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
if isinstance(fname, os.PathLike):
fname = os.fspath(fname)
if (format == 'png'
or (format is None
and isinstance(fname, str)
and fname.lower().endswith('.png'))):
image = AxesImage(None, cmap=cmap, origin=origin)
image.set_data(arr)
image.set_clim(vmin, vmax)
image.write_png(fname)
else:
if format is None:
format = (Path(fname).suffix[1:] if isinstance(fname, str)
else rcParams["savefig.format"]).lower()
if format in ["pdf", "ps", "eps", "svg"]:
# Vector formats that are not handled by PIL.
fig = Figure(dpi=dpi, frameon=False)
FigureCanvas(fig)
fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin,
resize=True)
fig.savefig(fname, dpi=dpi, format=format, transparent=True)
else:
# Don't bother creating an image; this avoids rounding errors on the
# size when dividing and then multiplying by dpi.
sm = cm.ScalarMappable(cmap=cmap)
sm.set_clim(vmin, vmax)
if origin is None:
origin = rcParams["image.origin"]
if origin == "lower":
arr = arr[::-1]
rgba = sm.to_rgba(arr, bytes=True)
if format == "png":
_png.write_png(rgba, fname, dpi=dpi)
else:
try:
from PIL import Image
except ImportError as exc:
raise ImportError(
f"Saving to {format} requires Pillow") from exc
pil_shape = (rgba.shape[1], rgba.shape[0])
image = Image.frombuffer(
"RGBA", pil_shape, rgba, "raw", "RGBA", 0, 1)
if format in ["jpg", "jpeg"]:
format = "jpeg" # Pillow doesn't recognize "jpg".
color = tuple(
int(x * 255)
for x in mcolors.to_rgb(rcParams["savefig.facecolor"]))
background = Image.new("RGB", pil_shape, color)
background.paste(image, image)
image = background
image.save(fname, format=format, dpi=(dpi, dpi))


def pil_to_array(pilImage):
Expand Down
24 changes: 16 additions & 8 deletions lib/matplotlib/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,12 @@ def test_imread_fspath():
assert np.sum(img) == 134184960


def test_imsave():
@pytest.mark.parametrize("fmt", ["png", "jpg", "jpeg", "tiff"])
def test_imsave(fmt):
if fmt in ["jpg", "jpeg", "tiff"]:
pytest.importorskip("PIL")
has_alpha = fmt not in ["jpg", "jpeg"]

# The goal here is that the user can specify an output logical DPI
# for the image, but this will not actually add any extra pixels
# to the image, it will merely be used for metadata purposes.
Expand All @@ -130,22 +135,25 @@ def test_imsave():
# == 100) and read the resulting PNG files back in and make sure
# the data is 100% identical.
np.random.seed(1)
data = np.random.rand(256, 128)
# The height of 1856 pixels was selected because going through creating an
# actual dpi=100 figure to save the image to a Pillow-provided format would
# cause a rounding error resulting in a final image of shape 1855.
data = np.random.rand(1856, 2)

buff_dpi1 = io.BytesIO()
plt.imsave(buff_dpi1, data, dpi=1)
plt.imsave(buff_dpi1, data, format=fmt, dpi=1)

buff_dpi100 = io.BytesIO()
plt.imsave(buff_dpi100, data, dpi=100)
plt.imsave(buff_dpi100, data, format=fmt, dpi=100)

buff_dpi1.seek(0)
arr_dpi1 = plt.imread(buff_dpi1)
arr_dpi1 = plt.imread(buff_dpi1, format=fmt)

buff_dpi100.seek(0)
arr_dpi100 = plt.imread(buff_dpi100)
arr_dpi100 = plt.imread(buff_dpi100, format=fmt)

assert arr_dpi1.shape == (256, 128, 4)
assert arr_dpi100.shape == (256, 128, 4)
assert arr_dpi1.shape == (1856, 2, 3 + has_alpha)
assert arr_dpi100.shape == (1856, 2, 3 + has_alpha)

assert_array_equal(arr_dpi1, arr_dpi100)

Expand Down