diff --git a/lib/matplotlib/animation.py b/lib/matplotlib/animation.py index 6a62384d1c54..ba61ccdc5ce2 100644 --- a/lib/matplotlib/animation.py +++ b/lib/matplotlib/animation.py @@ -222,7 +222,7 @@ def grab_frame(self, **savefig_kwargs): # Tell the figure to save its data to the sink, using the # frame format and dpi. self.fig.savefig(self._frame_sink(), format=self.frame_format, - dpi=self.dpi, **savefig_kwargs) + dpi=self.dpi, flatten=True, **savefig_kwargs) except RuntimeError: out, err = self._proc.communicate() verbose.report('MovieWriter -- Error ' @@ -358,7 +358,7 @@ def grab_frame(self, **savefig_kwargs): # frame format and dpi. myframesink = self._frame_sink() self.fig.savefig(myframesink, format=self.frame_format, - dpi=self.dpi, **savefig_kwargs) + dpi=self.dpi, flatten=True, **savefig_kwargs) myframesink.close() except RuntimeError: diff --git a/lib/matplotlib/backends/backend_agg.py b/lib/matplotlib/backends/backend_agg.py index 5dea52363e86..56299e4447ba 100644 --- a/lib/matplotlib/backends/backend_agg.py +++ b/lib/matplotlib/backends/backend_agg.py @@ -39,6 +39,7 @@ from matplotlib.path import Path from matplotlib.transforms import Bbox, BboxBase from matplotlib import colors as mcolors +from matplotlib import image as mimage from matplotlib.backends._backend_agg import RendererAgg as _RendererAgg from matplotlib import _png @@ -508,8 +509,16 @@ def print_raw(self, filename_or_obj, *args, **kwargs): else: fileobj = filename_or_obj close = False + img = renderer._renderer.buffer_rgba() + + # Flatten RGBA if used with fileformat that doesn't handle trnasparency + if kwargs.get('flatten', False): + w, h = int(renderer.width), int(renderer.height) + img = np.array(memoryview(img)).reshape((h, w, 4)) + img = mimage.flatten_rgba(img) + try: - fileobj.write(renderer._renderer.buffer_rgba()) + fileobj.write(img) finally: if close: filename_or_obj.close() @@ -527,8 +536,16 @@ def print_png(self, filename_or_obj, *args, **kwargs): else: close = False + # Flatten RGBA if used as intermediate fileformat for something + # that doesn't support transparency (ie: Animations) + img = renderer._renderer + if kwargs.get('flatten', False): + img = img.buffer_rgba() + w, h = int(renderer.width), int(renderer.height) + img = np.array(memoryview(img)).reshape((h, w, 4)) + img = mimage.flatten_rgba(img) try: - _png.write_png(renderer._renderer, filename_or_obj, self.figure.dpi) + _png.write_png(img, filename_or_obj, self.figure.dpi) finally: if close: filename_or_obj.close() @@ -563,24 +580,22 @@ def print_jpg(self, filename_or_obj, *args, **kwargs): *progressive*: If present, indicates that this image should be stored as a progressive JPEG file. """ - buf, size = self.print_to_buffer() + buf, (w, h) = self.print_to_buffer() + buf = np.array(memoryview(buf)).reshape((h, w, 4)) + if kwargs.pop("dryrun", False): return - # The image is "pasted" onto a white background image to safely - # handle any transparency - image = Image.frombuffer('RGBA', size, buf, 'raw', 'RGBA', 0, 1) - color = mcolors.colorConverter.to_rgb( - rcParams.get('savefig.facecolor', 'white')) - color = tuple([int(x * 255.0) for x in color]) - background = Image.new('RGB', size, color) - background.paste(image, image) + + # Flatten RGBA image to safely handle transparent regions + buf = mimage.flatten_rgba(buf) + img = Image.frombuffer('RGBA', (w, h), buf, 'raw', 'RGBA', 0, 1) + options = restrict_dict(kwargs, ['quality', 'optimize', 'progressive']) - if 'quality' not in options: options['quality'] = rcParams['savefig.jpeg_quality'] - return background.save(filename_or_obj, format='jpeg', **options) + return img.save(filename_or_obj, format='jpeg', **options) print_jpeg = print_jpg # add TIFF support diff --git a/lib/matplotlib/image.py b/lib/matplotlib/image.py index 7a3cd1dc9c39..4ae8cd1409f0 100644 --- a/lib/matplotlib/image.py +++ b/lib/matplotlib/image.py @@ -1502,3 +1502,46 @@ def thumbnail(infile, thumbfile, scale=0.1, interpolation='bilinear', ax.imshow(im, aspect='auto', resample=True, interpolation=interpolation) fig.savefig(thumbfile, dpi=dpi) return fig + + +def flatten_rgba(src, bg=None): + """ + Flatten an RGBA image *src* with a background color *bg*. + The resulting image will have an alpha channel, but no transparency. + This can be useful when interfacing with file formats that don't support + transparency or only support boolean transparency. + + Parameters + ---------- + src : MxNx4 Numpy array, dtype=uint8 + Image source in RGBA to be flattened. + + bg : Tuple(int,int,int), optional + Background color to merge *src* with. If no bg color is provided + the color from the rcParam 'savefig.facecolor' will be used. + + Returns + ------- + dest : MxNx4 Numpy array, dtype=uint8 + """ + + if bg is None: + bg = mcolors.colorConverter.to_rgb( + rcParams.get('savefig.facecolor', 'white')) + bg = tuple([int(x * 255.0) for x in bg]) + + # Numpy images have dtype=uint8 which will overflow for these calculations + src = src.astype(np.uint16) + + alpha = src[:, :, 3] + src_rgb = src[:, :, :3] + w, h, _ = src.shape + + dest = np.empty((w, h, 4)) + dest[:, :, 0] = (255 - alpha)*bg[0] + alpha*src_rgb[:, :, 0] + dest[:, :, 1] = (255 - alpha)*bg[1] + alpha*src_rgb[:, :, 1] + dest[:, :, 2] = (255 - alpha)*bg[2] + alpha*src_rgb[:, :, 2] + dest = (dest/255).astype(np.uint8) + dest[:, :, 3] = 255 + + return dest diff --git a/lib/matplotlib/tests/test_image.py b/lib/matplotlib/tests/test_image.py index fde7c7cc18c8..ce54aff1780e 100644 --- a/lib/matplotlib/tests/test_image.py +++ b/lib/matplotlib/tests/test_image.py @@ -455,7 +455,7 @@ def test_nonuniformimage_setnorm(): @knownfailureif(not HAS_PIL) @cleanup -def test_jpeg_alpha(): +def test_flatten(): plt.figure(figsize=(1, 1), dpi=300) # Create an image that is all black, with a gradient from 0-1 in # the alpha channel from left to right. @@ -464,21 +464,43 @@ def test_jpeg_alpha(): plt.figimage(im) - buff = io.BytesIO() + jpg_buf = io.BytesIO() + pngF_buf = io.BytesIO() + png_buf = io.BytesIO() + with rc_context({'savefig.facecolor': 'red'}): - plt.savefig(buff, transparent=True, format='jpg', dpi=300) + plt.savefig(jpg_buf, transparent=True, format='jpg', dpi=300) + plt.savefig(pngF_buf, transparent=True, format='png', + flatten=True, dpi=300) + plt.savefig(png_buf, transparent=True, format='png', dpi=300) - buff.seek(0) - image = Image.open(buff) + jpg_buf.seek(0) + pngF_buf.seek(0) + png_buf.seek(0) + + jpg_im = Image.open(jpg_buf) + pngF_im = Image.open(pngF_buf) + png_im = Image.open(png_buf) # If this fails, there will be only one color (all black). If this # is working, we should have all 256 shades of grey represented. - print("num colors: ", len(image.getcolors(256))) - assert len(image.getcolors(256)) >= 175 and len(image.getcolors(256)) <= 185 + print("num colors [jpg]: ", len(jpg_im.getcolors(256))) + print("num colors [png, flattened]: ", len(pngF_im.getcolors(256))) + print("num colors [png, not flattened]: ", len(png_im.getcolors(256))) + + assert len(jpg_im.getcolors(256)) >= 175 and len(jpg_im.getcolors(256)) <= 185 + assert len(pngF_im.getcolors(256)) == 256 + assert len(png_im.getcolors(256)) == 256 + # The fully transparent part should be red, not white or black - # or anything else - print("corner pixel: ", image.getpixel((0, 0))) - assert image.getpixel((0, 0)) == (254, 0, 0) + # or anything else when flattened. + print("corner pixel [jpg]: ", jpg_im.getpixel((0, 0))) + print("corner pixel [png, flattened]: ", pngF_im.getpixel((0,0))) + print("corner pixel [png, not flattened]: ", png_im.getpixel((0,0))) + + assert jpg_im.getpixel((0, 0)) == (254, 0, 0) + assert pngF_im.getpixel((0,0)) == (255, 0, 0, 255) + assert png_im.getpixel((0,0)) == (255, 255, 255, 0) if __name__=='__main__':