Skip to content

Commit dd94ac3

Browse files
committed
FIX: handle nans in RGBA input with ScalarMappables
1 parent 0da2da0 commit dd94ac3

File tree

2 files changed

+33
-1
lines changed

2 files changed

+33
-1
lines changed

lib/matplotlib/cm.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,7 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
328328
treated as an RGB or RGBA array, and no mapping will be done.
329329
The array can be `~numpy.uint8`, or it can be floats with
330330
values in the 0-1 range; otherwise a ValueError will be raised.
331-
If it is a masked array, any masked elements will be set to 0 alpha.
331+
Any NaNs or masked elements will be set to 0 alpha.
332332
If the last dimension is 3, the *alpha* kwarg (defaulting to 1)
333333
will be used to fill in the transparency. If the last dimension
334334
is 4, the *alpha* kwarg is ignored; it does not
@@ -360,6 +360,12 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
360360
else:
361361
raise ValueError("Third dimension must be 3 or 4")
362362
if xx.dtype.kind == 'f':
363+
# If any of R, G, B, or A is nan, set to 0
364+
if np.any(nans := np.isnan(x)):
365+
if xx.shape[2] == 4:
366+
xx = xx.copy()
367+
xx[np.any(nans, axis=2), :] = 0
368+
363369
if norm and (xx.max() > 1 or xx.min() < 0):
364370
raise ValueError("Floating point image RGB values "
365371
"must be in the 0..1 range.")

lib/matplotlib/tests/test_colors.py

+26
Original file line numberDiff line numberDiff line change
@@ -1352,6 +1352,32 @@ def test_scalarmappable_to_rgba(bytes):
13521352
np.testing.assert_almost_equal(sm.to_rgba(xm[..., :3], bytes=bytes), expected)
13531353

13541354

1355+
@pytest.mark.parametrize("bytes", (True, False))
1356+
def test_scalarmappable_nan_to_rgba(bytes):
1357+
sm = cm.ScalarMappable()
1358+
1359+
# RGBA
1360+
x = np.ones((2, 3, 4), dtype=float) * 0.5
1361+
x[0, 0, 0] = np.nan
1362+
expected = x.copy()
1363+
expected[0, 0, :] = 0
1364+
if bytes:
1365+
expected = (expected * 255).astype(np.uint8)
1366+
np.testing.assert_almost_equal(sm.to_rgba(x, bytes=bytes), expected)
1367+
assert np.any(np.isnan(x)) # Input array should not be changed
1368+
1369+
# RGB
1370+
expected[..., 3] = 255 if bytes else 1
1371+
expected[0, 0, 3] = 0
1372+
np.testing.assert_almost_equal(sm.to_rgba(x[..., :3], bytes=bytes), expected)
1373+
assert np.any(np.isnan(x)) # Input array should not be changed
1374+
1375+
# Out-of-range fail
1376+
x[1, 0, 0] = 42
1377+
with pytest.raises(ValueError, match='0..1 range'):
1378+
sm.to_rgba(x[..., :3], bytes=bytes)
1379+
1380+
13551381
def test_failed_conversions():
13561382
with pytest.raises(ValueError):
13571383
mcolors.to_rgba('5')

0 commit comments

Comments
 (0)