@@ -633,24 +633,20 @@ def __call__(self, X, alpha=None, bytes=False):
633
633
xa [xa < 0 ] = self ._i_under
634
634
xa [mask_bad ] = self ._i_bad
635
635
636
+ lut = self ._lut
636
637
if bytes :
637
- lut = (self ._lut * 255 ).astype (np .uint8 )
638
- else :
639
- lut = self ._lut .copy () # Don't let alpha modify original _lut.
638
+ lut = (lut * 255 ).astype (np .uint8 )
640
639
641
- rgba = np .empty (shape = xa .shape + (4 ,), dtype = lut .dtype )
642
- lut .take (xa , axis = 0 , mode = 'clip' , out = rgba )
640
+ rgba = lut .take (xa , axis = 0 , mode = 'clip' )
643
641
644
642
if alpha is not None :
645
- if np .iterable (alpha ):
646
- alpha = np .asarray (alpha )
647
- if alpha .shape != xa .shape :
648
- raise ValueError ("alpha is array-like but its shape"
649
- " %s doesn't match that of X %s" %
650
- (alpha .shape , xa .shape ))
651
643
alpha = np .clip (alpha , 0 , 1 )
652
644
if bytes :
653
- alpha = (alpha * 255 ).astype (np .uint8 )
645
+ alpha *= 255 # Will be cast to uint8 upon assignment.
646
+ if alpha .shape not in [(), xa .shape ]:
647
+ raise ValueError (
648
+ f"alpha is array-like but its shape { alpha .shape } does "
649
+ f"not match that of X { xa .shape } " )
654
650
rgba [..., - 1 ] = alpha
655
651
656
652
# If the "bad" color is all zeros, then ignore alpha input.
0 commit comments