Skip to content

Commit 03e7111

Browse files
committed
Refactor onto a 'join_colormaps' function
1 parent 64a5fa8 commit 03e7111

File tree

1 file changed

+56
-10
lines changed

1 file changed

+56
-10
lines changed

lib/matplotlib/colors.py

Lines changed: 56 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,60 @@ def __delitem__(self, key, value):
9999
_colors_full_map = _ColorMapping(_colors_full_map)
100100

101101

102+
def join_colormaps(cmaps, fractions=None, name=None, N=None):
103+
"""
104+
Join a series of colormaps into one.
105+
106+
Parameters
107+
----------
108+
cmaps : a sequence of colormaps to be joined (length M)
109+
fractions : a sequence of floats or ints (length M)
110+
The fraction of the new colormap that each colormap should
111+
occupy. These are normalized so they sum to 1. By default, the
112+
fractions are the ``N`` attribute of each cmap.
113+
name : str, optional
114+
The name for the joined colormap. This defaults to
115+
``cmap[0].name + '+' + cmap[1].name + '+' ...``
116+
N : int
117+
The number of entries in the color map. This defaults to the
118+
sum of the ``N`` attributes of the cmaps.
119+
120+
Returns
121+
-------
122+
ListedColormap
123+
The joined colormap.
124+
125+
Examples
126+
--------
127+
import matplotlib.pyplat as plt
128+
cmap1 = plt.get_cmap('viridis', 128)
129+
cmap2 = plt.get_cmap('plasma_r', 64)
130+
cmap3 = plt.get_cmap('jet', 64)
131+
132+
joined_cmap = join_colormaps((cmap1, cmap2, cmap3))
133+
134+
See Also
135+
--------
136+
137+
:meth:`Colorbar.join` and :meth:`Colorbar.__add__` : a method
138+
implementation of this functionality
139+
"""
140+
if N is None:
141+
N = np.sum([cm.N for cm in cmaps])
142+
if fractions is None:
143+
fractions = [cm.N for cm in cmaps]
144+
fractions = np.array(fractions) / np.sum(fractions, dtype='float')
145+
if name is None:
146+
name = ""
147+
for cm in cmaps:
148+
name += cm.name + '+'
149+
name.rstrip('+')
150+
maps = [cm(np.linspace(0, 1, int(N * frac)))
151+
for cm, frac in zip(cmaps, fractions)]
152+
# N is set by len of the vstack'd array:
153+
return ListedColormap(np.vstack(maps), name, )
154+
155+
102156
def get_named_colors_mapping():
103157
"""Return the global mapping of names to named colors.
104158
"""
@@ -641,18 +695,10 @@ def join(self, other, frac_self=None, name=None, N=None):
641695
642696
joined_cmap = cmap1 + cmap2
643697
"""
644-
if N is None:
645-
N = self.N + other.N
646698
if frac_self is None:
647699
frac_self = self.N / (other.N + self.N)
648-
if name is None:
649-
name = '{}+{}'.format(self.name, other.name)
650-
if not (0 < frac_self and frac_self < 1):
651-
raise ValueError("frac_self must be in the interval (0.0, 1.0)")
652-
map0 = self(np.linspace(0, 1, int(N * frac_self)))
653-
map1 = other(np.linspace(0, 1, int(N * (1 - frac_self))))
654-
# N is set by len of the vstack'd array:
655-
return ListedColormap(np.vstack((map0, map1)), name, )
700+
fractions = [frac_self, 1 - frac_self]
701+
return join_colormaps([self, other], fractions, name, N)
656702

657703
__add__ = join
658704

0 commit comments

Comments
 (0)