From 5873efb40dbd3a304edc252c83801534cea42060 Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Sat, 25 May 2019 23:50:16 +0200 Subject: [PATCH] Correctly apply PNG palette when building ImageBase through Pillow. --- lib/matplotlib/image.py | 9 ++++++++- lib/matplotlib/tests/test_image.py | 17 +++++++++++------ 2 files changed, 19 insertions(+), 7 deletions(-) diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index 527a0597505d..3f1c25f79e94 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -665,8 +665,15 @@ def set_data(self, A): Parameters ---------- - A : array-like + A : array-like or `PIL.Image.Image` """ + try: + from PIL import Image + except ImportError: + pass + else: + if isinstance(A, Image.Image): + A = pil_to_array(A) # Needed e.g. to apply png palette. self._A = cbook.safe_masked_invalid(A, copy=True) if (self._A.dtype != np.uint8 and diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index afdee7f98023..065de4ee207d 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -14,7 +14,7 @@ from numpy.testing import assert_array_equal from matplotlib import ( - colors, image as mimage, patches, pyplot as plt, + colors, image as mimage, patches, pyplot as plt, style, rc_context, rcParams) from matplotlib.cbook import MatplotlibDeprecationWarning from matplotlib.image import (AxesImage, BboxImage, FigureImage, @@ -117,11 +117,16 @@ def test_image_python_io(): @check_figures_equal() def test_imshow_pil(fig_test, fig_ref): - pytest.importorskip("PIL") - img = plt.imread(os.path.join(os.path.dirname(__file__), - 'baseline_images', 'test_image', 'uint16.tif')) - fig_test.subplots().imshow(img) - fig_ref.subplots().imshow(np.asarray(img)) + style.use("default") + PIL = pytest.importorskip("PIL") + png_path = Path(__file__).parent / "baseline_images/pngsuite/basn3p04.png" + tiff_path = Path(__file__).parent / "baseline_images/test_image/uint16.tif" + axs = fig_test.subplots(2) + axs[0].imshow(PIL.Image.open(png_path)) + axs[1].imshow(PIL.Image.open(tiff_path)) + axs = fig_ref.subplots(2) + axs[0].imshow(plt.imread(str(png_path))) + axs[1].imshow(plt.imread(tiff_path)) def test_imread_pil_uint16():