Skip to content

Commit 89db355

Browse files
authored
Merge pull request #14326 from anntzer/pillowpngpalette
Correctly apply PNG palette when building ImageBase through Pillow.
2 parents 30354dd + 5873efb commit 89db355

File tree

2 files changed

+19
-7
lines changed

2 files changed

+19
-7
lines changed

lib/matplotlib/image.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,8 +633,15 @@ def set_data(self, A):
633633
634634
Parameters
635635
----------
636-
A : array-like
636+
A : array-like or `PIL.Image.Image`
637637
"""
638+
try:
639+
from PIL import Image
640+
except ImportError:
641+
pass
642+
else:
643+
if isinstance(A, Image.Image):
644+
A = pil_to_array(A) # Needed e.g. to apply png palette.
638645
self._A = cbook.safe_masked_invalid(A, copy=True)
639646

640647
if (self._A.dtype != np.uint8 and

lib/matplotlib/tests/test_image.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from numpy.testing import assert_array_equal
1515

1616
from matplotlib import (
17-
colors, image as mimage, patches, pyplot as plt,
17+
colors, image as mimage, patches, pyplot as plt, style,
1818
rc_context, rcParams)
1919
from matplotlib.cbook import MatplotlibDeprecationWarning
2020
from matplotlib.image import (AxesImage, BboxImage, FigureImage,
@@ -117,11 +117,16 @@ def test_image_python_io():
117117

118118
@check_figures_equal()
119119
def test_imshow_pil(fig_test, fig_ref):
120-
pytest.importorskip("PIL")
121-
img = plt.imread(os.path.join(os.path.dirname(__file__),
122-
'baseline_images', 'test_image', 'uint16.tif'))
123-
fig_test.subplots().imshow(img)
124-
fig_ref.subplots().imshow(np.asarray(img))
120+
style.use("default")
121+
PIL = pytest.importorskip("PIL")
122+
png_path = Path(__file__).parent / "baseline_images/pngsuite/basn3p04.png"
123+
tiff_path = Path(__file__).parent / "baseline_images/test_image/uint16.tif"
124+
axs = fig_test.subplots(2)
125+
axs[0].imshow(PIL.Image.open(png_path))
126+
axs[1].imshow(PIL.Image.open(tiff_path))
127+
axs = fig_ref.subplots(2)
128+
axs[0].imshow(plt.imread(str(png_path)))
129+
axs[1].imshow(plt.imread(tiff_path))
125130

126131

127132
def test_imread_pil_uint16():

0 commit comments

Comments
 (0)