diff --git a/doc/users/next_whats_new/colormap_repr.rst b/doc/users/next_whats_new/colormap_repr.rst new file mode 100644 index 000000000000..077f6fca12e9 --- /dev/null +++ b/doc/users/next_whats_new/colormap_repr.rst @@ -0,0 +1,6 @@ +IPython representations for Colormap objects +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +The `matplotlib.colors.Colormap` object now has image representations for +IPython / Jupyter backends. Cells returning a color map on the last line will +display an image of the color map. diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index 2471c8a9698c..5cc0be9bc363 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -65,12 +65,17 @@ .. _xkcd color survey: https://xkcd.com/color/rgb/ """ +import base64 from collections.abc import Sized import functools +import io import itertools from numbers import Number import re +from PIL import Image +from PIL.PngImagePlugin import PngInfo +import matplotlib as mpl import numpy as np import matplotlib.cbook as cbook from matplotlib import docstring @@ -691,6 +696,33 @@ def reversed(self, name=None): """ raise NotImplementedError() + def _repr_png_(self): + """Generate a PNG representation of the Colormap.""" + IMAGE_SIZE = (400, 50) + X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1)) + pixels = self(X, bytes=True) + png_bytes = io.BytesIO() + title = self.name + ' color map' + author = f'Matplotlib v{mpl.__version__}, https://matplotlib.org' + pnginfo = PngInfo() + pnginfo.add_text('Title', title) + pnginfo.add_text('Description', title) + pnginfo.add_text('Author', author) + pnginfo.add_text('Software', author) + Image.fromarray(pixels).save(png_bytes, format='png', pnginfo=pnginfo) + return png_bytes.getvalue() + + def _repr_html_(self): + """Generate an HTML representation of the Colormap.""" + png_bytes = self._repr_png_() + png_base64 = base64.b64encode(png_bytes).decode('ascii') + return ('' + self.name + '' + + '') + class LinearSegmentedColormap(Colormap): """ diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index ba1192d16ac3..00ed7fdfb907 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -1,7 +1,9 @@ import copy import itertools +from io import BytesIO import numpy as np +from PIL import Image import pytest from numpy.testing import assert_array_equal, assert_array_almost_equal @@ -1135,3 +1137,23 @@ def test_hex_shorthand_notation(): def test_DivergingNorm_deprecated(): with pytest.warns(cbook.MatplotlibDeprecationWarning): norm = mcolors.DivergingNorm(vcenter=0) + + +def test_repr_png(): + cmap = plt.get_cmap('viridis') + png = cmap._repr_png_() + assert len(png) > 0 + img = Image.open(BytesIO(png)) + assert img.width > 0 + assert img.height > 0 + assert 'Title' in img.text + assert 'Description' in img.text + assert 'Author' in img.text + assert 'Software' in img.text + + +def test_repr_html(): + cmap = plt.get_cmap('viridis') + html = cmap._repr_html_() + assert len(html) > 0 + assert cmap.name in html