diff --git a/doc/users/next_whats_new/colormap_repr.rst b/doc/users/next_whats_new/colormap_repr.rst
new file mode 100644
index 000000000000..077f6fca12e9
--- /dev/null
+++ b/doc/users/next_whats_new/colormap_repr.rst
@@ -0,0 +1,6 @@
+IPython representations for Colormap objects
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+The `matplotlib.colors.Colormap` object now has image representations for
+IPython / Jupyter backends. Cells returning a color map on the last line will
+display an image of the color map.
diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py
index 2471c8a9698c..5cc0be9bc363 100644
--- a/lib/matplotlib/colors.py
+++ b/lib/matplotlib/colors.py
@@ -65,12 +65,17 @@
.. _xkcd color survey: https://xkcd.com/color/rgb/
"""
+import base64
from collections.abc import Sized
import functools
+import io
import itertools
from numbers import Number
import re
+from PIL import Image
+from PIL.PngImagePlugin import PngInfo
+import matplotlib as mpl
import numpy as np
import matplotlib.cbook as cbook
from matplotlib import docstring
@@ -691,6 +696,33 @@ def reversed(self, name=None):
"""
raise NotImplementedError()
+ def _repr_png_(self):
+ """Generate a PNG representation of the Colormap."""
+ IMAGE_SIZE = (400, 50)
+ X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1))
+ pixels = self(X, bytes=True)
+ png_bytes = io.BytesIO()
+ title = self.name + ' color map'
+ author = f'Matplotlib v{mpl.__version__}, https://matplotlib.org'
+ pnginfo = PngInfo()
+ pnginfo.add_text('Title', title)
+ pnginfo.add_text('Description', title)
+ pnginfo.add_text('Author', author)
+ pnginfo.add_text('Software', author)
+ Image.fromarray(pixels).save(png_bytes, format='png', pnginfo=pnginfo)
+ return png_bytes.getvalue()
+
+ def _repr_html_(self):
+ """Generate an HTML representation of the Colormap."""
+ png_bytes = self._repr_png_()
+ png_base64 = base64.b64encode(png_bytes).decode('ascii')
+ return ('' + self.name + '' +
+ '
')
+
class LinearSegmentedColormap(Colormap):
"""
diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py
index ba1192d16ac3..00ed7fdfb907 100644
--- a/lib/matplotlib/tests/test_colors.py
+++ b/lib/matplotlib/tests/test_colors.py
@@ -1,7 +1,9 @@
import copy
import itertools
+from io import BytesIO
import numpy as np
+from PIL import Image
import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
@@ -1135,3 +1137,23 @@ def test_hex_shorthand_notation():
def test_DivergingNorm_deprecated():
with pytest.warns(cbook.MatplotlibDeprecationWarning):
norm = mcolors.DivergingNorm(vcenter=0)
+
+
+def test_repr_png():
+ cmap = plt.get_cmap('viridis')
+ png = cmap._repr_png_()
+ assert len(png) > 0
+ img = Image.open(BytesIO(png))
+ assert img.width > 0
+ assert img.height > 0
+ assert 'Title' in img.text
+ assert 'Description' in img.text
+ assert 'Author' in img.text
+ assert 'Software' in img.text
+
+
+def test_repr_html():
+ cmap = plt.get_cmap('viridis')
+ html = cmap._repr_html_()
+ assert len(html) > 0
+ assert cmap.name in html