Skip to content

[NF] Add 'truncate' and 'join' methods to colormaps. #7716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 56 additions & 0 deletions doc/users/next_whats_new/colormap_utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
Colormap Utilities
------------------

Tools for joining, truncating, and resampling colormaps have been added. This grew out of https://gist.github.com/denis-bz/8052855, and http://stackoverflow.com/a/18926541/2121597.


Joining Colormaps
~~~~~~~~~~~~~~~~~

This includes the :func:`~matplotlib.colors.join_colormaps` function::

import matplotlib.pyplat as plt
from matplotlib.colors import join_colormaps

viridis = plt.get_cmap('viridis', 128)
plasma = plt.get_cmap('plasma_r', 64)
jet = plt.get_cmap('jet', 64)

joined_cmap = join_colormaps((viridis, plasma, jet))

This functionality has also been incorporated into the :meth:`~matplotlib.colors.colormap.join` and `~matplotlib.colors.colormap.__add__` methods, so that you can do things like::

plasma_jet = plasma.join(jet)

joined_cmap = viridis + plasma + jet # Same as `join_colormaps` function call above

Truncating and resampling colormaps
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

A :meth:`~matplotlib.colors.colormap.truncate` method has also been added::

sub_viridis = viridis.truncate(0.3, 0.8)

This gives a new colormap that goes from 30% to 80% of viridis. This functionality has also been implemented in the `~matplotlib.colors.colormap.__getitem__` method, so that the same colormap can be created by::

sub_viridis = viridis[0.3:0.8]

The `~matplotlib.colors.colormap.__getitem__` method also supports a range of other 'advanced indexing' options, including integer slice indexing::

sub_viridis2 = viridis[10:90:2]

integer list indexing, which may be particularly useful for creating discrete (low-N) colormaps::

sub_viridis3 = viridis[[4, 35, 59, 90, 110]]

and `numpy.mgrid` style complex indexing::

sub_viridis4 = viridis[0.2:0.4:64j]

See the `~matplotlib.colors.colormap.__getitem__` documentation for more details and examples of how to use these advanced indexing options.

Together, the join and truncate/resample methods allow the user to quickly construct new colormaps from existing ones::

new_cm = viridis[0.5:] + plasma[:0.3] + jet[0.2:0.5:64j]

I doubt this colormap will ever be useful to someone, but hopefully it gives you the idea.
269 changes: 264 additions & 5 deletions lib/matplotlib/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,60 @@ def __delitem__(self, key):
_colors_full_map = _ColorMapping(_colors_full_map)


def join_colormaps(cmaps, fractions=None, name=None, N=None):
"""
Join a series of colormaps into one.

Parameters
----------
cmaps : a sequence of colormaps to be joined (length M)
fractions : a sequence of floats or ints (length M)
The fraction of the new colormap that each colormap should
occupy. These are normalized so they sum to 1. By default, the
fractions are the ``N`` attribute of each cmap.
name : str, optional
The name for the joined colormap. This defaults to
``cmap[0].name + '+' + cmap[1].name + '+' ...``
N : int
The number of entries in the color map. This defaults to the
sum of the ``N`` attributes of the cmaps.

Returns
-------
ListedColormap
The joined colormap.

Examples
--------
import matplotlib.pyplat as plt
cmap1 = plt.get_cmap('viridis', 128)
cmap2 = plt.get_cmap('plasma_r', 64)
cmap3 = plt.get_cmap('jet', 64)

joined_cmap = join_colormaps((cmap1, cmap2, cmap3))

See Also
--------

:meth:`Colorbar.join` and :meth:`Colorbar.__add__` : a method
implementation of this functionality
"""
if N is None:
N = np.sum([cm.N for cm in cmaps])
if fractions is None:
fractions = [cm.N for cm in cmaps]
fractions = np.array(fractions) / np.sum(fractions, dtype='float')
if name is None:
name = ""
for cm in cmaps:
name += cm.name + '+'
name.rstrip('+')
maps = [cm(np.linspace(0, 1, int(round(N * frac))))
for cm, frac in zip(cmaps, fractions)]
# N is set by len of the vstack'd array:
return ListedColormap(np.vstack(maps), name, )


def get_named_colors_mapping():
"""Return the global mapping of names to named colors."""
return _colors_full_map
Expand Down Expand Up @@ -186,8 +240,8 @@ def _to_rgba_no_colorcycle(c, alpha=None):
match = re.match(r"\A#[a-fA-F0-9]{6}\Z", c)
if match:
return (tuple(int(n, 16) / 255
for n in [c[1:3], c[3:5], c[5:7]])
+ (alpha if alpha is not None else 1.,))
for n in [c[1:3], c[3:5], c[5:7]]) +
(alpha if alpha is not None else 1.,))
# hex color with alpha.
match = re.match(r"\A#[a-fA-F0-9]{8}\Z", c)
if match:
Expand Down Expand Up @@ -231,8 +285,8 @@ def to_rgba_array(c, alpha=None):
# Special-case inputs that are already arrays, for performance. (If the
# array has the wrong kind or shape, raise the error during one-at-a-time
# conversion.)
if (isinstance(c, np.ndarray) and c.dtype.kind in "if"
and c.ndim == 2 and c.shape[1] in [3, 4]):
if (isinstance(c, np.ndarray) and c.dtype.kind in "if" and
c.ndim == 2 and c.shape[1] in [3, 4]):
if c.shape[1] == 3:
result = np.column_stack([c, np.zeros(len(c))])
result[:, -1] = alpha if alpha is not None else 1.
Expand Down Expand Up @@ -612,6 +666,211 @@ def reversed(self, name=None):
"""
raise NotImplementedError()

def join(self, other, frac_self=None, name=None, N=None):
"""
Join colormap `self` to `other` and return the new colormap.

Parameters
----------
other : cmap
The other colormap to be joined to this one.
frac_self : float in the interval ``(0.0, 1.0)``, optional
The fraction of the new colormap that should be occupied
by self. By default, this is ``self.N / (self.N +
other.N)``.
name : str, optional
The name for the joined colormap. This defaults to
``self.name + '+' + other.name``
N : int
The number of entries in the color map. The default is ``None``,
in which case the number of entries is the sum of the
number of entries in the two colormaps to be joined.

Returns
-------
ListedColormap
The joined colormap.

Examples
--------
import matplotlib.pyplat as plt
cmap1 = plt.get_cmap('viridis', 128)
cmap2 = plt.get_cmap('plasma_r', 64)

joined_cmap = cmap1.join(cmap2)

# Note that `joined_cmap` will be 2/3 `cmap1`, and 1/3 `cmap2`
# because of the default behavior of frac_self
# (i.e. proportional to N of each cmap).

# This is also available as :meth:`Colormap.__add__`, so that
# the following is possible:

joined_cmap = cmap1 + cmap2
"""
if frac_self is None:
frac_self = self.N / (other.N + self.N)
fractions = [frac_self, 1 - frac_self]
return join_colormaps([self, other], fractions, name, N)

__add__ = join

def truncate(self, minval=0.0, maxval=1.0, name=None, N=None):
"""
Truncate a colormap.

Parameters
----------
minval : float in the interval ``(0.0, 1.0)``, optional
The lower fraction of the colormap you want to truncate
(default 0.0).
maxval : float in the interval ``(0.0, 1.0)``, optional
The upper limit of the colormap you want to keep. i.e. truncate
the section above this value (default 1.0).
name : str, optional
The name for the new truncated colormap. This defaults to
``"trunc({},{:.2f},{:.2f})".format(self.name, minval, maxval)``
N : int
The number of entries in the map. The default is *None*,
in which case the same color-step density is preserved,
i.e.: N = ceil(N * (maxval - minval))

Returns
-------
ListedColormap
The truncated colormap.

Examples
--------
import matplotlib.pyplat as plt
cmap = plt.get_cmap('viridis')

# This will return the `viridis` colormap with the bottom 20%,
# and top 30% removed:
cmap_trunc = cmap.truncate(0.2, 0.7)

"""
if minval >= maxval:
raise ValueError("minval must be less than maxval")
if minval < 0 or minval >= 1 or maxval <= 0 or maxval > 1:
raise ValueError(
"minval and maxval must be in the interval (0.0, 1.0)"
)
if minval == 0 and maxval == 1:
raise ValueError("This is not a truncation")
# This was taken largely from
# https://gist.github.com/denis-bz/8052855
# Which, in turn was from @unutbu's SO answer:
# http://stackoverflow.com/a/18926541/2121597
if N is None:
N = np.ceil(self.N * (maxval - minval))
if name is None:
name = "trunc({},{:.2f},{:.2f})".format(self.name, minval, maxval)
return ListedColormap(self(np.linspace(minval, maxval, N)), name)

def __getitem__(self, item):
"""Advanced indexing for colorbars.

Examples
--------
import matplotlib.pyplat as plt
cmap = plt.get_cmap('viridis', 128)

# ### float indexing
# for float-style indexing, the values must be in [0.0, 1.0]
# Truncate the colormap between 20 and 80%.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Typo? 80->60?

new_cm = cmap[0.2:0.6]
# `new_cm` will have the color-spacing as `cmap` (in this
# case: 0.6 - 0.2 = 40% of 128 = 51 colors)

# negative values are supported
# this gives the same result as above
new_cm = cmap[0.2:-0.4]

# Same as above, but specify the number of points
# using `np.mgrid` complex-indexing:
new_cm = cmap[0.2:-0.4:64j]

# ### Int-style indexing
# for int-style indexing, the values must be in [0, self.N]
new_cm = cmap[12:100]

# Same as above, but 4x fewer points
new_cm = cmap[12:100:4]

# negative values are supported (same as above)
new_cm = cmap[12:-28:4]

# And so is `np.mgrid` complex-indexing (same as above)
new_cm = cmap[12:-28:22j]

# ### Array/list-style indexing
# In this case, you specify specific points in the colormap
# at which you'd like to create a new colormap.

# You can index by integers, in which case
# all values must be ints in [-self.N, self.N]:
new_cm = cmap[[5, 10, 25, -38]]

# Or by floats in the range [-1, 1]
new_cm = cmap[[0.04, 0.08, 0.2, -0.3]]
"""
if isinstance(item, slice):
sss = [item.start, item.stop, item.step]
name = self.name + '[{}:{}:{}]'.format(*sss)
if (all([s is None or abs(s) <= 1 for s in sss[:2]]) and
(sss[2] is None or abs(sss[2]) <= 1 or
isinstance(sss[2], complex))):
if sss[0] is None:
sss[0] = 0
elif sss[0] < 0:
sss[0] += 1
if sss[1] is None:
sss[1] = 1.0
elif sss[1] < 0:
sss[1] += 1
if sss[2] is None:
sss[2] = self.N * 1j * (sss[1] - sss[0])
elif all([s is None or (s % 1 == 0) for s in sss[:2]]):
# This is an integer-style itemization
if sss[0] is None:
sss[0] = 0
elif sss[0] < 0:
sss[0] = sss[0] % self.N
if sss[1] is None:
sss[1] = self.N
elif sss[1] < 0:
sss[1] = sss[1] % self.N
if sss[2] is None:
sss[2] = 1
sss[0] = sss[0] / self.N
sss[1] = sss[1] / self.N
if not isinstance(sss[2], complex):
sss[2] = sss[2] / self.N
if sss[0] < 0 or sss[0] >= 1 or sss[1] <= 0 or sss[1] > 1:
raise IndexError("Invalid colorbar itemization - outside "
"bounds")
else:
raise IndexError("Invalid colorbar itemization")
points = np.mgrid[slice(*sss)]
elif isinstance(item, (list, np.ndarray)):
name = self.name + '[<indexed>]'
if isinstance(item, list):
item = np.array(item)
if item.dtype.kind in ('u', 'i'):
item = item.astype('f') / self.N
item[item < 0] += 1
if np.any(item > 1):
raise IndexError("Invalid colorbar itemization - outside "
"bounds")
points = item
else:
raise IndexError("Invalid colorbar itemization")
if len(points) <= 1:
raise IndexError("Invalid colorbar itemization - a colorbar must "
"contain >1 color.")
return ListedColormap(self(points), name=name)


class LinearSegmentedColormap(Colormap):
"""Colormap objects based on lookup tables using linear segments.
Expand Down Expand Up @@ -1059,7 +1318,7 @@ class SymLogNorm(Normalize):
*linthresh* allows the user to specify the size of this range
(-*linthresh*, *linthresh*).
"""
def __init__(self, linthresh, linscale=1.0,
def __init__(self, linthresh, linscale=1.0,
vmin=None, vmax=None, clip=False):
"""
*linthresh*:
Expand Down
Loading