Skip to content

Commit 2236a2c

Browse files
committed
Clip RGB data to vaild range in Axes.imshow
1 parent 7ee6499 commit 2236a2c

File tree

6 files changed

+59
-7
lines changed

6 files changed

+59
-7
lines changed

doc/users/credits.rst

+1
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,7 @@ Yu Feng,
386386
Yunfei Yang,
387387
Yuri D'Elia,
388388
Yuval Langer,
389+
Zac Hatfield-Dodds,
389390
Zach Pincus,
390391
Zair Mubashar,
391392
alex,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
`Axes.imshow` clips RGB values to the valid range
2+
-------------------------------------------------
3+
4+
When `Axes.imshow` is passed an RGB or RGBA value with out-of-range
5+
values, it now issues a warning and clips them to the valid range.
6+
The old behaviour, wrapping back in to the range, often hid outliers
7+
and made interpreting RGB images unreliable.

lib/matplotlib/axes/_axes.py

+10-5
Original file line numberDiff line numberDiff line change
@@ -5241,10 +5241,14 @@ def imshow(self, X, cmap=None, norm=None, aspect=None,
52415241
- MxNx3 -- RGB (float or uint8)
52425242
- MxNx4 -- RGBA (float or uint8)
52435243
5244-
The value for each component of MxNx3 and MxNx4 float arrays
5245-
should be in the range 0.0 to 1.0. MxN arrays are mapped
5246-
to colors based on the `norm` (mapping scalar to scalar)
5247-
and the `cmap` (mapping the normed scalar to a color).
5244+
MxN arrays are mapped to colors based on the `norm` (mapping
5245+
scalar to scalar) and the `cmap` (mapping the normed scalar to
5246+
a color).
5247+
5248+
Elements of RGB and RGBA arrays represent pixels of an MxN image.
5249+
All values should be in the range [0 .. 1] for floats or
5250+
[0 .. 255] for integers. Out-of-range values will be clipped to
5251+
these bounds.
52485252
52495253
cmap : `~matplotlib.colors.Colormap`, optional, default: None
52505254
If None, default to rc `image.cmap` value. `cmap` is ignored
@@ -5286,7 +5290,8 @@ def imshow(self, X, cmap=None, norm=None, aspect=None,
52865290
settings for `vmin` and `vmax` will be ignored.
52875291
52885292
alpha : scalar, optional, default: None
5289-
The alpha blending value, between 0 (transparent) and 1 (opaque)
5293+
The alpha blending value, between 0 (transparent) and 1 (opaque).
5294+
The ``alpha`` argument is ignored for RGBA input data.
52905295
52915296
origin : ['upper' | 'lower'], optional, default: None
52925297
Place the [0,0] index of the array in the upper left or lower left

lib/matplotlib/cm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
259259
xx = (xx * 255).astype(np.uint8)
260260
elif xx.dtype == np.uint8:
261261
if not bytes:
262-
xx = xx.astype(float) / 255
262+
xx = xx.astype(np.float32) / 255
263263
else:
264264
raise ValueError("Image RGB array must be uint8 or "
265265
"floating point; found %s" % xx.dtype)

lib/matplotlib/image.py

+20
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
from math import ceil
1515
import os
16+
import logging
1617

1718
import numpy as np
1819

@@ -34,6 +35,8 @@
3435
from matplotlib.transforms import (Affine2D, BboxBase, Bbox, BboxTransform,
3536
IdentityTransform, TransformedBbox)
3637

38+
_log = logging.getLogger(__name__)
39+
3740
# map interpolation strings to module constants
3841
_interpd_ = {
3942
'none': _image.NEAREST, # fall back to nearest when not supported
@@ -621,6 +624,23 @@ def set_data(self, A):
621624
or self._A.ndim == 3 and self._A.shape[-1] in [3, 4]):
622625
raise TypeError("Invalid dimensions for image data")
623626

627+
if self._A.ndim == 3:
628+
# If the input data has values outside the valid range (after
629+
# normalisation), we issue a warning and then clip X to the bounds
630+
# - otherwise casting wraps extreme values, hiding outliers and
631+
# making reliable interpretation impossible.
632+
high = 255 if np.issubdtype(self._A.dtype, np.integer) else 1
633+
if self._A.min() < 0 or high < self._A.max():
634+
_log.warning(
635+
'Clipping input data to the valid range for imshow with '
636+
'RGB data ([0..1] for floats or [0..255] for integers).'
637+
)
638+
self._A = np.clip(self._A, 0, high)
639+
# Cast unsupported integer types to uint8
640+
if self._A.dtype != np.uint8 and np.issubdtype(self._A.dtype,
641+
np.integer):
642+
self._A = self._A.astype(np.uint8)
643+
624644
self._imcache = None
625645
self._rgbacache = None
626646
self.stale = True

lib/matplotlib/tests/test_image.py

+20-1
Original file line numberDiff line numberDiff line change
@@ -602,7 +602,7 @@ def test_minimized_rasterized():
602602
def test_load_from_url():
603603
req = six.moves.urllib.request.urlopen(
604604
"http://matplotlib.org/_static/logo_sidebar_horiz.png")
605-
Z = plt.imread(req)
605+
plt.imread(req)
606606

607607

608608
@image_comparison(baseline_images=['log_scale_image'],
@@ -795,6 +795,25 @@ def test_imshow_no_warn_invalid():
795795
assert len(warns) == 0
796796

797797

798+
@pytest.mark.parametrize(
799+
'dtype', [np.dtype(s) for s in 'u2 u4 i2 i4 i8 f4 f8'.split()])
800+
def test_imshow_clips_rgb_to_valid_range(dtype):
801+
arr = np.arange(300, dtype=dtype).reshape((10, 10, 3))
802+
if dtype.kind != 'u':
803+
arr -= 10
804+
if dtype.kind == 'f':
805+
arr = arr / 255
806+
_, ax = plt.subplots()
807+
out = ax.imshow(arr).get_array()
808+
assert out.min() == 0
809+
if dtype.kind == 'f':
810+
assert out.max() == 1
811+
assert out.dtype.kind == 'f'
812+
else:
813+
assert out.max == 255
814+
assert out.dtype == np.uint8
815+
816+
798817
@image_comparison(baseline_images=['imshow_flatfield'],
799818
remove_text=True, style='mpl20',
800819
extensions=['png'])

0 commit comments

Comments
 (0)