Skip to content

Properly handle transparency for animations #5415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
4 changes: 2 additions & 2 deletions lib/matplotlib/animation.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def grab_frame(self, **savefig_kwargs):
# Tell the figure to save its data to the sink, using the
# frame format and dpi.
self.fig.savefig(self._frame_sink(), format=self.frame_format,
dpi=self.dpi, **savefig_kwargs)
dpi=self.dpi, flatten=True, **savefig_kwargs)
except RuntimeError:
out, err = self._proc.communicate()
verbose.report('MovieWriter -- Error '
Expand Down Expand Up @@ -358,7 +358,7 @@ def grab_frame(self, **savefig_kwargs):
# frame format and dpi.
myframesink = self._frame_sink()
self.fig.savefig(myframesink, format=self.frame_format,
dpi=self.dpi, **savefig_kwargs)
dpi=self.dpi, flatten=True, **savefig_kwargs)
myframesink.close()

except RuntimeError:
Expand Down
41 changes: 28 additions & 13 deletions lib/matplotlib/backends/backend_agg.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from matplotlib.path import Path
from matplotlib.transforms import Bbox, BboxBase
from matplotlib import colors as mcolors
from matplotlib import image as mimage

from matplotlib.backends._backend_agg import RendererAgg as _RendererAgg
from matplotlib import _png
Expand Down Expand Up @@ -508,8 +509,16 @@ def print_raw(self, filename_or_obj, *args, **kwargs):
else:
fileobj = filename_or_obj
close = False
img = renderer._renderer.buffer_rgba()

# Flatten RGBA if used with fileformat that doesn't handle trnasparency
if kwargs.get('flatten', False):
w, h = int(renderer.width), int(renderer.height)
img = np.array(memoryview(img)).reshape((h, w, 4))
img = mimage.flatten_rgba(img)

try:
fileobj.write(renderer._renderer.buffer_rgba())
fileobj.write(img)
finally:
if close:
filename_or_obj.close()
Expand All @@ -527,8 +536,16 @@ def print_png(self, filename_or_obj, *args, **kwargs):
else:
close = False

# Flatten RGBA if used as intermediate fileformat for something
# that doesn't support transparency (ie: Animations)
img = renderer._renderer
if kwargs.get('flatten', False):
img = img.buffer_rgba()
w, h = int(renderer.width), int(renderer.height)
img = np.array(memoryview(img)).reshape((h, w, 4))
img = mimage.flatten_rgba(img)
try:
_png.write_png(renderer._renderer, filename_or_obj, self.figure.dpi)
_png.write_png(img, filename_or_obj, self.figure.dpi)
finally:
if close:
filename_or_obj.close()
Expand Down Expand Up @@ -563,24 +580,22 @@ def print_jpg(self, filename_or_obj, *args, **kwargs):
*progressive*: If present, indicates that this image
should be stored as a progressive JPEG file.
"""
buf, size = self.print_to_buffer()
buf, (w, h) = self.print_to_buffer()
buf = np.array(memoryview(buf)).reshape((h, w, 4))

if kwargs.pop("dryrun", False):
return
# The image is "pasted" onto a white background image to safely
# handle any transparency
image = Image.frombuffer('RGBA', size, buf, 'raw', 'RGBA', 0, 1)
color = mcolors.colorConverter.to_rgb(
rcParams.get('savefig.facecolor', 'white'))
color = tuple([int(x * 255.0) for x in color])
background = Image.new('RGB', size, color)
background.paste(image, image)

# Flatten RGBA image to safely handle transparent regions
buf = mimage.flatten_rgba(buf)
img = Image.frombuffer('RGBA', (w, h), buf, 'raw', 'RGBA', 0, 1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Confusingly, the (w, h) I think is correct here because this is for PIL, not Numpy.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I think the images looked correct even though I decoded the stream as a (w,h) Numpy array because the pixels were still being matched correctly for flattening. Why does matplotlib use HxW for serializing its images? Isn't WxH a commonly accepted standard?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The images are in row-major order (i.e. as English text is read), and Numpy dimensions are in the order of "slowest moving" to "fastest moving" dimension, thus (h, w, 4). It's a Numpy/array programming convention. PIL, as an image library, not an array library, uses the more common image convention (w, h).

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks!

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think several days of my life have gone to sorting out (w, h) vs (r, c) issues!

This is one of the major selling points of xray.


options = restrict_dict(kwargs, ['quality', 'optimize',
'progressive'])

if 'quality' not in options:
options['quality'] = rcParams['savefig.jpeg_quality']

return background.save(filename_or_obj, format='jpeg', **options)
return img.save(filename_or_obj, format='jpeg', **options)
print_jpeg = print_jpg

# add TIFF support
Expand Down
43 changes: 43 additions & 0 deletions lib/matplotlib/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -1502,3 +1502,46 @@ def thumbnail(infile, thumbfile, scale=0.1, interpolation='bilinear',
ax.imshow(im, aspect='auto', resample=True, interpolation=interpolation)
fig.savefig(thumbfile, dpi=dpi)
return fig


def flatten_rgba(src, bg=None):
"""
Flatten an RGBA image *src* with a background color *bg*.
The resulting image will have an alpha channel, but no transparency.
This can be useful when interfacing with file formats that don't support
transparency or only support boolean transparency.

Parameters
----------
src : MxNx4 Numpy array, dtype=uint8
Image source in RGBA to be flattened.

bg : Tuple(int,int,int), optional
Background color to merge *src* with. If no bg color is provided
the color from the rcParam 'savefig.facecolor' will be used.

Returns
-------
dest : MxNx4 Numpy array, dtype=uint8
"""

if bg is None:
bg = mcolors.colorConverter.to_rgb(
rcParams.get('savefig.facecolor', 'white'))
bg = tuple([int(x * 255.0) for x in bg])

# Numpy images have dtype=uint8 which will overflow for these calculations
src = src.astype(np.uint16)

alpha = src[:, :, 3]
src_rgb = src[:, :, :3]
w, h, _ = src.shape

dest = np.empty((w, h, 4))
dest[:, :, 0] = (255 - alpha)*bg[0] + alpha*src_rgb[:, :, 0]
dest[:, :, 1] = (255 - alpha)*bg[1] + alpha*src_rgb[:, :, 1]
dest[:, :, 2] = (255 - alpha)*bg[2] + alpha*src_rgb[:, :, 2]
dest = (dest/255).astype(np.uint8)
dest[:, :, 3] = 255

return dest
42 changes: 32 additions & 10 deletions lib/matplotlib/tests/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ def test_nonuniformimage_setnorm():

@knownfailureif(not HAS_PIL)
@cleanup
def test_jpeg_alpha():
def test_flatten():
plt.figure(figsize=(1, 1), dpi=300)
# Create an image that is all black, with a gradient from 0-1 in
# the alpha channel from left to right.
Expand All @@ -464,21 +464,43 @@ def test_jpeg_alpha():

plt.figimage(im)

buff = io.BytesIO()
jpg_buf = io.BytesIO()
pngF_buf = io.BytesIO()
png_buf = io.BytesIO()

with rc_context({'savefig.facecolor': 'red'}):
plt.savefig(buff, transparent=True, format='jpg', dpi=300)
plt.savefig(jpg_buf, transparent=True, format='jpg', dpi=300)
plt.savefig(pngF_buf, transparent=True, format='png',
flatten=True, dpi=300)
plt.savefig(png_buf, transparent=True, format='png', dpi=300)

buff.seek(0)
image = Image.open(buff)
jpg_buf.seek(0)
pngF_buf.seek(0)
png_buf.seek(0)

jpg_im = Image.open(jpg_buf)
pngF_im = Image.open(pngF_buf)
png_im = Image.open(png_buf)

# If this fails, there will be only one color (all black). If this
# is working, we should have all 256 shades of grey represented.
print("num colors: ", len(image.getcolors(256)))
assert len(image.getcolors(256)) >= 175 and len(image.getcolors(256)) <= 185
print("num colors [jpg]: ", len(jpg_im.getcolors(256)))
print("num colors [png, flattened]: ", len(pngF_im.getcolors(256)))
print("num colors [png, not flattened]: ", len(png_im.getcolors(256)))

assert len(jpg_im.getcolors(256)) >= 175 and len(jpg_im.getcolors(256)) <= 185
assert len(pngF_im.getcolors(256)) == 256
assert len(png_im.getcolors(256)) == 256

# The fully transparent part should be red, not white or black
# or anything else
print("corner pixel: ", image.getpixel((0, 0)))
assert image.getpixel((0, 0)) == (254, 0, 0)
# or anything else when flattened.
print("corner pixel [jpg]: ", jpg_im.getpixel((0, 0)))
print("corner pixel [png, flattened]: ", pngF_im.getpixel((0,0)))
print("corner pixel [png, not flattened]: ", png_im.getpixel((0,0)))

assert jpg_im.getpixel((0, 0)) == (254, 0, 0)
assert pngF_im.getpixel((0,0)) == (255, 0, 0, 255)
assert png_im.getpixel((0,0)) == (255, 255, 255, 0)


if __name__=='__main__':
Expand Down