diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index f5081825c2ab..b5f66bc33224 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -633,24 +633,20 @@ def __call__(self, X, alpha=None, bytes=False): xa[xa < 0] = self._i_under xa[mask_bad] = self._i_bad + lut = self._lut if bytes: - lut = (self._lut * 255).astype(np.uint8) - else: - lut = self._lut.copy() # Don't let alpha modify original _lut. + lut = (lut * 255).astype(np.uint8) - rgba = np.empty(shape=xa.shape + (4,), dtype=lut.dtype) - lut.take(xa, axis=0, mode='clip', out=rgba) + rgba = lut.take(xa, axis=0, mode='clip') if alpha is not None: - if np.iterable(alpha): - alpha = np.asarray(alpha) - if alpha.shape != xa.shape: - raise ValueError("alpha is array-like but its shape" - " %s doesn't match that of X %s" % - (alpha.shape, xa.shape)) alpha = np.clip(alpha, 0, 1) if bytes: - alpha = (alpha * 255).astype(np.uint8) + alpha *= 255 # Will be cast to uint8 upon assignment. + if alpha.shape not in [(), xa.shape]: + raise ValueError( + f"alpha is array-like but its shape {alpha.shape} does " + f"not match that of X {xa.shape}") rgba[..., -1] = alpha # If the "bad" color is all zeros, then ignore alpha input.