Skip to content

Commit 010f2f4

Browse files
authored
Merge pull request #13263 from anntzer/dont-go-through-figure-in-imsave
MNT: In imsave()'s Pillow-handled case, don't create a temporary figure.
2 parents 4d044e0 + cb7f555 commit 010f2f4

File tree

2 files changed

+52
-19
lines changed

2 files changed

+52
-19
lines changed

lib/matplotlib/image.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from math import ceil
88
import os
99
import logging
10+
from pathlib import Path
1011
import urllib.parse
1112
import urllib.request
1213

@@ -1430,24 +1431,48 @@ def imsave(fname, arr, vmin=None, vmax=None, cmap=None, format=None,
14301431
The DPI to store in the metadata of the file. This does not affect the
14311432
resolution of the output image.
14321433
"""
1433-
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
14341434
from matplotlib.figure import Figure
14351435
if isinstance(fname, os.PathLike):
14361436
fname = os.fspath(fname)
1437-
if (format == 'png'
1438-
or (format is None
1439-
and isinstance(fname, str)
1440-
and fname.lower().endswith('.png'))):
1441-
image = AxesImage(None, cmap=cmap, origin=origin)
1442-
image.set_data(arr)
1443-
image.set_clim(vmin, vmax)
1444-
image.write_png(fname)
1445-
else:
1437+
if format is None:
1438+
format = (Path(fname).suffix[1:] if isinstance(fname, str)
1439+
else rcParams["savefig.format"]).lower()
1440+
if format in ["pdf", "ps", "eps", "svg"]:
1441+
# Vector formats that are not handled by PIL.
14461442
fig = Figure(dpi=dpi, frameon=False)
1447-
FigureCanvas(fig)
14481443
fig.figimage(arr, cmap=cmap, vmin=vmin, vmax=vmax, origin=origin,
14491444
resize=True)
14501445
fig.savefig(fname, dpi=dpi, format=format, transparent=True)
1446+
else:
1447+
# Don't bother creating an image; this avoids rounding errors on the
1448+
# size when dividing and then multiplying by dpi.
1449+
sm = cm.ScalarMappable(cmap=cmap)
1450+
sm.set_clim(vmin, vmax)
1451+
if origin is None:
1452+
origin = rcParams["image.origin"]
1453+
if origin == "lower":
1454+
arr = arr[::-1]
1455+
rgba = sm.to_rgba(arr, bytes=True)
1456+
if format == "png":
1457+
_png.write_png(rgba, fname, dpi=dpi)
1458+
else:
1459+
try:
1460+
from PIL import Image
1461+
except ImportError as exc:
1462+
raise ImportError(
1463+
f"Saving to {format} requires Pillow") from exc
1464+
pil_shape = (rgba.shape[1], rgba.shape[0])
1465+
image = Image.frombuffer(
1466+
"RGBA", pil_shape, rgba, "raw", "RGBA", 0, 1)
1467+
if format in ["jpg", "jpeg"]:
1468+
format = "jpeg" # Pillow doesn't recognize "jpg".
1469+
color = tuple(
1470+
int(x * 255)
1471+
for x in mcolors.to_rgb(rcParams["savefig.facecolor"]))
1472+
background = Image.new("RGB", pil_shape, color)
1473+
background.paste(image, image)
1474+
image = background
1475+
image.save(fname, format=format, dpi=(dpi, dpi))
14511476

14521477

14531478
def pil_to_array(pilImage):

lib/matplotlib/tests/test_image.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,12 @@ def test_imread_fspath():
121121
assert np.sum(img) == 134184960
122122

123123

124-
def test_imsave():
124+
@pytest.mark.parametrize("fmt", ["png", "jpg", "jpeg", "tiff"])
125+
def test_imsave(fmt):
126+
if fmt in ["jpg", "jpeg", "tiff"]:
127+
pytest.importorskip("PIL")
128+
has_alpha = fmt not in ["jpg", "jpeg"]
129+
125130
# The goal here is that the user can specify an output logical DPI
126131
# for the image, but this will not actually add any extra pixels
127132
# to the image, it will merely be used for metadata purposes.
@@ -130,22 +135,25 @@ def test_imsave():
130135
# == 100) and read the resulting PNG files back in and make sure
131136
# the data is 100% identical.
132137
np.random.seed(1)
133-
data = np.random.rand(256, 128)
138+
# The height of 1856 pixels was selected because going through creating an
139+
# actual dpi=100 figure to save the image to a Pillow-provided format would
140+
# cause a rounding error resulting in a final image of shape 1855.
141+
data = np.random.rand(1856, 2)
134142

135143
buff_dpi1 = io.BytesIO()
136-
plt.imsave(buff_dpi1, data, dpi=1)
144+
plt.imsave(buff_dpi1, data, format=fmt, dpi=1)
137145

138146
buff_dpi100 = io.BytesIO()
139-
plt.imsave(buff_dpi100, data, dpi=100)
147+
plt.imsave(buff_dpi100, data, format=fmt, dpi=100)
140148

141149
buff_dpi1.seek(0)
142-
arr_dpi1 = plt.imread(buff_dpi1)
150+
arr_dpi1 = plt.imread(buff_dpi1, format=fmt)
143151

144152
buff_dpi100.seek(0)
145-
arr_dpi100 = plt.imread(buff_dpi100)
153+
arr_dpi100 = plt.imread(buff_dpi100, format=fmt)
146154

147-
assert arr_dpi1.shape == (256, 128, 4)
148-
assert arr_dpi100.shape == (256, 128, 4)
155+
assert arr_dpi1.shape == (1856, 2, 3 + has_alpha)
156+
assert arr_dpi100.shape == (1856, 2, 3 + has_alpha)
149157

150158
assert_array_equal(arr_dpi1, arr_dpi100)
151159

0 commit comments

Comments
 (0)