diff --git a/doc/users/next_whats_new/colormap_utils.rst b/doc/users/next_whats_new/colormap_utils.rst new file mode 100644 index 000000000000..87ed5d6379a7 --- /dev/null +++ b/doc/users/next_whats_new/colormap_utils.rst @@ -0,0 +1,56 @@ +Colormap Utilities +------------------ + +Tools for joining, truncating, and resampling colormaps have been added. This grew out of https://gist.github.com/denis-bz/8052855, and http://stackoverflow.com/a/18926541/2121597. + + +Joining Colormaps +~~~~~~~~~~~~~~~~~ + +This includes the :func:`~matplotlib.colors.join_colormaps` function:: + + import matplotlib.pyplat as plt + from matplotlib.colors import join_colormaps + + viridis = plt.get_cmap('viridis', 128) + plasma = plt.get_cmap('plasma_r', 64) + jet = plt.get_cmap('jet', 64) + + joined_cmap = join_colormaps((viridis, plasma, jet)) + +This functionality has also been incorporated into the :meth:`~matplotlib.colors.colormap.join` and `~matplotlib.colors.colormap.__add__` methods, so that you can do things like:: + + plasma_jet = plasma.join(jet) + + joined_cmap = viridis + plasma + jet # Same as `join_colormaps` function call above + +Truncating and resampling colormaps +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +A :meth:`~matplotlib.colors.colormap.truncate` method has also been added:: + + sub_viridis = viridis.truncate(0.3, 0.8) + +This gives a new colormap that goes from 30% to 80% of viridis. This functionality has also been implemented in the `~matplotlib.colors.colormap.__getitem__` method, so that the same colormap can be created by:: + + sub_viridis = viridis[0.3:0.8] + +The `~matplotlib.colors.colormap.__getitem__` method also supports a range of other 'advanced indexing' options, including integer slice indexing:: + + sub_viridis2 = viridis[10:90:2] + +integer list indexing, which may be particularly useful for creating discrete (low-N) colormaps:: + + sub_viridis3 = viridis[[4, 35, 59, 90, 110]] + +and `numpy.mgrid` style complex indexing:: + + sub_viridis4 = viridis[0.2:0.4:64j] + +See the `~matplotlib.colors.colormap.__getitem__` documentation for more details and examples of how to use these advanced indexing options. + +Together, the join and truncate/resample methods allow the user to quickly construct new colormaps from existing ones:: + + new_cm = viridis[0.5:] + plasma[:0.3] + jet[0.2:0.5:64j] + +I doubt this colormap will ever be useful to someone, but hopefully it gives you the idea. diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index 82969ed18cb7..49f815136d9b 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -82,6 +82,60 @@ def __delitem__(self, key): _colors_full_map = _ColorMapping(_colors_full_map) +def join_colormaps(cmaps, fractions=None, name=None, N=None): + """ + Join a series of colormaps into one. + + Parameters + ---------- + cmaps : a sequence of colormaps to be joined (length M) + fractions : a sequence of floats or ints (length M) + The fraction of the new colormap that each colormap should + occupy. These are normalized so they sum to 1. By default, the + fractions are the ``N`` attribute of each cmap. + name : str, optional + The name for the joined colormap. This defaults to + ``cmap[0].name + '+' + cmap[1].name + '+' ...`` + N : int + The number of entries in the color map. This defaults to the + sum of the ``N`` attributes of the cmaps. + + Returns + ------- + ListedColormap + The joined colormap. + + Examples + -------- + import matplotlib.pyplat as plt + cmap1 = plt.get_cmap('viridis', 128) + cmap2 = plt.get_cmap('plasma_r', 64) + cmap3 = plt.get_cmap('jet', 64) + + joined_cmap = join_colormaps((cmap1, cmap2, cmap3)) + + See Also + -------- + + :meth:`Colorbar.join` and :meth:`Colorbar.__add__` : a method + implementation of this functionality + """ + if N is None: + N = np.sum([cm.N for cm in cmaps]) + if fractions is None: + fractions = [cm.N for cm in cmaps] + fractions = np.array(fractions) / np.sum(fractions, dtype='float') + if name is None: + name = "" + for cm in cmaps: + name += cm.name + '+' + name.rstrip('+') + maps = [cm(np.linspace(0, 1, int(round(N * frac)))) + for cm, frac in zip(cmaps, fractions)] + # N is set by len of the vstack'd array: + return ListedColormap(np.vstack(maps), name, ) + + def get_named_colors_mapping(): """Return the global mapping of names to named colors.""" return _colors_full_map @@ -186,8 +240,8 @@ def _to_rgba_no_colorcycle(c, alpha=None): match = re.match(r"\A#[a-fA-F0-9]{6}\Z", c) if match: return (tuple(int(n, 16) / 255 - for n in [c[1:3], c[3:5], c[5:7]]) - + (alpha if alpha is not None else 1.,)) + for n in [c[1:3], c[3:5], c[5:7]]) + + (alpha if alpha is not None else 1.,)) # hex color with alpha. match = re.match(r"\A#[a-fA-F0-9]{8}\Z", c) if match: @@ -231,8 +285,8 @@ def to_rgba_array(c, alpha=None): # Special-case inputs that are already arrays, for performance. (If the # array has the wrong kind or shape, raise the error during one-at-a-time # conversion.) - if (isinstance(c, np.ndarray) and c.dtype.kind in "if" - and c.ndim == 2 and c.shape[1] in [3, 4]): + if (isinstance(c, np.ndarray) and c.dtype.kind in "if" and + c.ndim == 2 and c.shape[1] in [3, 4]): if c.shape[1] == 3: result = np.column_stack([c, np.zeros(len(c))]) result[:, -1] = alpha if alpha is not None else 1. @@ -612,6 +666,211 @@ def reversed(self, name=None): """ raise NotImplementedError() + def join(self, other, frac_self=None, name=None, N=None): + """ + Join colormap `self` to `other` and return the new colormap. + + Parameters + ---------- + other : cmap + The other colormap to be joined to this one. + frac_self : float in the interval ``(0.0, 1.0)``, optional + The fraction of the new colormap that should be occupied + by self. By default, this is ``self.N / (self.N + + other.N)``. + name : str, optional + The name for the joined colormap. This defaults to + ``self.name + '+' + other.name`` + N : int + The number of entries in the color map. The default is ``None``, + in which case the number of entries is the sum of the + number of entries in the two colormaps to be joined. + + Returns + ------- + ListedColormap + The joined colormap. + + Examples + -------- + import matplotlib.pyplat as plt + cmap1 = plt.get_cmap('viridis', 128) + cmap2 = plt.get_cmap('plasma_r', 64) + + joined_cmap = cmap1.join(cmap2) + + # Note that `joined_cmap` will be 2/3 `cmap1`, and 1/3 `cmap2` + # because of the default behavior of frac_self + # (i.e. proportional to N of each cmap). + + # This is also available as :meth:`Colormap.__add__`, so that + # the following is possible: + + joined_cmap = cmap1 + cmap2 + """ + if frac_self is None: + frac_self = self.N / (other.N + self.N) + fractions = [frac_self, 1 - frac_self] + return join_colormaps([self, other], fractions, name, N) + + __add__ = join + + def truncate(self, minval=0.0, maxval=1.0, name=None, N=None): + """ + Truncate a colormap. + + Parameters + ---------- + minval : float in the interval ``(0.0, 1.0)``, optional + The lower fraction of the colormap you want to truncate + (default 0.0). + maxval : float in the interval ``(0.0, 1.0)``, optional + The upper limit of the colormap you want to keep. i.e. truncate + the section above this value (default 1.0). + name : str, optional + The name for the new truncated colormap. This defaults to + ``"trunc({},{:.2f},{:.2f})".format(self.name, minval, maxval)`` + N : int + The number of entries in the map. The default is *None*, + in which case the same color-step density is preserved, + i.e.: N = ceil(N * (maxval - minval)) + + Returns + ------- + ListedColormap + The truncated colormap. + + Examples + -------- + import matplotlib.pyplat as plt + cmap = plt.get_cmap('viridis') + + # This will return the `viridis` colormap with the bottom 20%, + # and top 30% removed: + cmap_trunc = cmap.truncate(0.2, 0.7) + + """ + if minval >= maxval: + raise ValueError("minval must be less than maxval") + if minval < 0 or minval >= 1 or maxval <= 0 or maxval > 1: + raise ValueError( + "minval and maxval must be in the interval (0.0, 1.0)" + ) + if minval == 0 and maxval == 1: + raise ValueError("This is not a truncation") + # This was taken largely from + # https://gist.github.com/denis-bz/8052855 + # Which, in turn was from @unutbu's SO answer: + # http://stackoverflow.com/a/18926541/2121597 + if N is None: + N = np.ceil(self.N * (maxval - minval)) + if name is None: + name = "trunc({},{:.2f},{:.2f})".format(self.name, minval, maxval) + return ListedColormap(self(np.linspace(minval, maxval, N)), name) + + def __getitem__(self, item): + """Advanced indexing for colorbars. + + Examples + -------- + import matplotlib.pyplat as plt + cmap = plt.get_cmap('viridis', 128) + + # ### float indexing + # for float-style indexing, the values must be in [0.0, 1.0] + # Truncate the colormap between 20 and 80%. + new_cm = cmap[0.2:0.6] + # `new_cm` will have the color-spacing as `cmap` (in this + # case: 0.6 - 0.2 = 40% of 128 = 51 colors) + + # negative values are supported + # this gives the same result as above + new_cm = cmap[0.2:-0.4] + + # Same as above, but specify the number of points + # using `np.mgrid` complex-indexing: + new_cm = cmap[0.2:-0.4:64j] + + # ### Int-style indexing + # for int-style indexing, the values must be in [0, self.N] + new_cm = cmap[12:100] + + # Same as above, but 4x fewer points + new_cm = cmap[12:100:4] + + # negative values are supported (same as above) + new_cm = cmap[12:-28:4] + + # And so is `np.mgrid` complex-indexing (same as above) + new_cm = cmap[12:-28:22j] + + # ### Array/list-style indexing + # In this case, you specify specific points in the colormap + # at which you'd like to create a new colormap. + + # You can index by integers, in which case + # all values must be ints in [-self.N, self.N]: + new_cm = cmap[[5, 10, 25, -38]] + + # Or by floats in the range [-1, 1] + new_cm = cmap[[0.04, 0.08, 0.2, -0.3]] + """ + if isinstance(item, slice): + sss = [item.start, item.stop, item.step] + name = self.name + '[{}:{}:{}]'.format(*sss) + if (all([s is None or abs(s) <= 1 for s in sss[:2]]) and + (sss[2] is None or abs(sss[2]) <= 1 or + isinstance(sss[2], complex))): + if sss[0] is None: + sss[0] = 0 + elif sss[0] < 0: + sss[0] += 1 + if sss[1] is None: + sss[1] = 1.0 + elif sss[1] < 0: + sss[1] += 1 + if sss[2] is None: + sss[2] = self.N * 1j * (sss[1] - sss[0]) + elif all([s is None or (s % 1 == 0) for s in sss[:2]]): + # This is an integer-style itemization + if sss[0] is None: + sss[0] = 0 + elif sss[0] < 0: + sss[0] = sss[0] % self.N + if sss[1] is None: + sss[1] = self.N + elif sss[1] < 0: + sss[1] = sss[1] % self.N + if sss[2] is None: + sss[2] = 1 + sss[0] = sss[0] / self.N + sss[1] = sss[1] / self.N + if not isinstance(sss[2], complex): + sss[2] = sss[2] / self.N + if sss[0] < 0 or sss[0] >= 1 or sss[1] <= 0 or sss[1] > 1: + raise IndexError("Invalid colorbar itemization - outside " + "bounds") + else: + raise IndexError("Invalid colorbar itemization") + points = np.mgrid[slice(*sss)] + elif isinstance(item, (list, np.ndarray)): + name = self.name + '[]' + if isinstance(item, list): + item = np.array(item) + if item.dtype.kind in ('u', 'i'): + item = item.astype('f') / self.N + item[item < 0] += 1 + if np.any(item > 1): + raise IndexError("Invalid colorbar itemization - outside " + "bounds") + points = item + else: + raise IndexError("Invalid colorbar itemization") + if len(points) <= 1: + raise IndexError("Invalid colorbar itemization - a colorbar must " + "contain >1 color.") + return ListedColormap(self(points), name=name) + class LinearSegmentedColormap(Colormap): """Colormap objects based on lookup tables using linear segments. @@ -1059,7 +1318,7 @@ class SymLogNorm(Normalize): *linthresh* allows the user to specify the size of this range (-*linthresh*, *linthresh*). """ - def __init__(self, linthresh, linscale=1.0, + def __init__(self, linthresh, linscale=1.0, vmin=None, vmax=None, clip=False): """ *linthresh*: diff --git a/lib/matplotlib/tests/test_colorbar.py b/lib/matplotlib/tests/test_colorbar.py index 2230c20e7d2f..6b0eac61da9e 100644 --- a/lib/matplotlib/tests/test_colorbar.py +++ b/lib/matplotlib/tests/test_colorbar.py @@ -4,7 +4,7 @@ from matplotlib import rc_context from matplotlib.testing.decorators import image_comparison import matplotlib.pyplot as plt -from matplotlib.colors import BoundaryNorm, LogNorm, PowerNorm +from matplotlib.colors import BoundaryNorm, LogNorm, PowerNorm, join_colormaps from matplotlib.cm import get_cmap from matplotlib.colorbar import ColorbarBase @@ -188,6 +188,165 @@ def test_gridspec_make_colorbar(): plt.subplots_adjust(top=0.95, right=0.95, bottom=0.2, hspace=0.25) +def test_join_colorbar(): + test_points = [0.1, 0.3, 0.9] + + # Jet is a LinearSegmentedColormap + cmap1 = plt.get_cmap('viridis', 5) + cmap2 = plt.get_cmap('jet', 5) + + # This should be a listed colormap. + cmap = cmap1.join(cmap2) + vals = cmap(test_points) + _vals = np.array( + [[0.229739, 0.322361, 0.545706, 1.], + [0.369214, 0.788888, 0.382914, 1.], + [0.5, 0., 0, 1.]] + ) + assert np.allclose(vals, _vals) + + # Use the 'frac_self' kwarg for the listed cmap + cmap = cmap1.join(cmap2, frac_self=0.7, N=50) + vals = cmap(test_points) + _vals = np.array( + [[0.267004, 0.004874, 0.329415, 1.], + [0.127568, 0.566949, 0.550556, 1.], + [1., 0.59259259, 0., 1.]] + ) + assert np.allclose(vals, _vals) + + # +code-coverage for name kwarg and when fractions is unspecified + cmap = join_colormaps([cmap1, cmap2, cmap1], name='test-map') + vals = cmap(test_points) + _vals = np.array( + [[0.229739, 0.322361, 0.545706, 1., ], + [0.993248, 0.906157, 0.143936, 1., ], + [0.369214, 0.788888, 0.382914, 1., ]] + ) + assert np.allclose(vals, _vals) + + +def test_truncate_colorbar(): + test_points = [0.1, 0.3, 0.9] + vir32 = plt.get_cmap('viridis', 32) + vir128 = plt.get_cmap('viridis', 128) + + cmap = vir32.truncate(0.2, 0.7) + vals = cmap(test_points) + _vals = np.array( + [[0.243113, 0.292092, 0.538516, 1.], + [0.19586, 0.395433, 0.555276, 1.], + [0.226397, 0.728888, 0.462789, 1.]] + ) + assert np.allclose(vals, _vals) + + # +code-coverage: N and name kwargs + cmap = vir32.truncate(0.2, 0.7, name='test-map', N=128) + vals = cmap(test_points) + _vals = np.array( + [[0.243113, 0.292092, 0.538516, 1., ], + [0.182256, 0.426184, 0.55712, 1., ], + [0.180653, 0.701402, 0.488189, 1., ]] + ) + assert np.allclose(vals, _vals) + + # Use __getitem__ fractional complex slicing with start:None + cmap = vir128[:-0.3:16j] + vals = cmap(test_points) + _vals = np.array( + [[0.278791, 0.062145, 0.386592, 1., ], + [0.262138, 0.242286, 0.520837, 1., ], + [0.19109, 0.708366, 0.482284, 1., ]] + ) + assert np.allclose(vals, _vals) + + # Use __getitem__ fractional slicing start:negative, end:None + cmap = vir128[-0.9:] + vals = cmap(test_points) + _vals = np.array( + [[0.262138, 0.242286, 0.520837, 1., ], + [0.175841, 0.44129, 0.557685, 1., ], + [0.772852, 0.877868, 0.131109, 1., ]] + ) + assert np.allclose(vals, _vals) + + # Use __getitem__ integer slicing + cmap = vir128[25:90] + vals = cmap(test_points) + _vals = np.array( + [[0.233603, 0.313828, 0.543914, 1.], + [0.185556, 0.41857, 0.556753, 1.], + [0.19109, 0.708366, 0.482284, 1.]] + ) + assert np.allclose(vals, _vals) + + # start:None, end:negative, integer jumping + cmap = vir128[:-10:4] + vals = cmap(test_points) + _vals = np.array( + [[0.282884, 0.13592, 0.453427, 1., ], + [0.214298, 0.355619, 0.551184, 1., ], + [0.606045, 0.850733, 0.236712, 1., ]] + ) + assert np.allclose(vals, _vals) + + # Use __getitem__ integer complex slicing + cmap = vir128[-100::16j] + vals = cmap(test_points) + _vals = np.array( + [[0.221989, 0.339161, 0.548752, 1., ], + [0.154815, 0.493313, 0.55784, 1., ], + [0.876168, 0.891125, 0.09525, 1., ]] + ) + assert np.allclose(vals, _vals) + + # Use __getitem__ discrete slicing + cmap = vir128[[10, 12, 15, 35, 60, 97]] + vals = cmap(test_points) + _vals = np.array( + [[0.283197, 0.11568, 0.436115, 1., ], + [0.282884, 0.13592, 0.453427, 1., ], + [0.395174, 0.797475, 0.367757, 1., ]] + ) + assert np.allclose(vals, _vals) + + +def test_truncate_colorbar_fail(): + vir128 = plt.get_cmap('viridis', 128) + + with pytest.raises(ValueError, match='less than'): + vir128.truncate(0.7, 0.3) + + with pytest.raises(ValueError, match='not a truncation'): + vir128.truncate(0, 1) + + with pytest.raises(ValueError, match='interval'): + vir128.truncate(0.3, 1.1) + + with pytest.raises(ValueError, match='interval'): + vir128.truncate(-0.1, 0.7) + + with pytest.raises(IndexError, match='must contain >1 color'): + vir128[[3]] + + with pytest.raises(IndexError, match='Invalid colorbar itemization'): + # Tuple indexing of colorbars not allowed. + vir128[3, 5, 9] + + with pytest.raises(IndexError, match='Invalid colorbar itemization'): + # Currently you can't mix-match fractional and int style indexing. + # This could be changed... + vir128[0.3:100] + + with pytest.raises(IndexError, match='Invalid colorbar itemization'): + # 150 is beyond the 128-bit colormap. + vir128[[10, 100, 150]] + + with pytest.raises(IndexError, match='Invalid colorbar itemization'): + # The first index can't be the end. + vir128[128:] + + @image_comparison(baseline_images=['colorbar_single_scatter'], extensions=['png'], remove_text=True, savefig_kwarg={'dpi': 40})