Skip to content

Commit 991610f

Browse files
committed
__getitem__ for colors.BivarColormap
Adds support for __getitem__ on colors.BivarColormap, i.e.: BivarColormap[0] and BivarColormap[1], which returns (1D) Colormap objects along the selected axes
1 parent 7d682e7 commit 991610f

File tree

4 files changed

+74
-12
lines changed

4 files changed

+74
-12
lines changed

lib/matplotlib/_cm_bivar.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -1305,8 +1305,8 @@
13051305

13061306
cmaps = {
13071307
"BiPeak": SegmentedBivarColormap(
1308-
BiPeak, "BiPeak", 256, "square"),
1308+
BiPeak, "BiPeak", 256, "square", (128, 128)),
13091309
"BiOrangeBlue": SegmentedBivarColormap(
1310-
BiOrangeBlue, "BiOrangeBlue", 256, "square"),
1311-
"BiCone": SegmentedBivarColormap(BiPeak, "BiCone", 256, "circle"),
1310+
BiOrangeBlue, "BiOrangeBlue", 256, "square", (0, 0)),
1311+
"BiCone": SegmentedBivarColormap(BiPeak, "BiCone", 256, "circle", (128, 128)),
13121312
}

lib/matplotlib/colors.py

+42-5
Original file line numberDiff line numberDiff line change
@@ -1414,7 +1414,7 @@ class BivarColormap(ColormapBase):
14141414
lookup table. To be used with `~matplotlib.cm.VectorMappable`.
14151415
"""
14161416

1417-
def __init__(self, name, N=256, M=256, shape='square'):
1417+
def __init__(self, name, N=256, M=256, shape='square', origin=(0, 0)):
14181418
"""
14191419
Parameters
14201420
----------
@@ -1436,6 +1436,10 @@ def __init__(self, name, N=256, M=256, shape='square'):
14361436
- If 'circleignore' a circular mask is applied, but the data is not
14371437
clipped and instead assigned the 'outside' color
14381438
1439+
origin: (int, int)
1440+
The relative origin of the colormap. Typically (0, 0), for colormaps
1441+
that are linear on both axis, and (int(N*0.5), int(.5*M) for
1442+
circular colormaps.
14391443
"""
14401444

14411445
self.name = name
@@ -1446,6 +1450,7 @@ def __init__(self, name, N=256, M=256, shape='square'):
14461450
self._rgba_outside = (1.0, 0.0, 1.0, 1.0)
14471451
self._isinit = False
14481452
self.n_variates = 2
1453+
self._origin = origin
14491454
'''#: When this colormap exists on a scalar mappable and colorbar_extend
14501455
#: is not False, colorbar creation will pick up ``colorbar_extend`` as
14511456
#: the default value for the ``extend`` keyword in the
@@ -1692,6 +1697,30 @@ def _clip(self, X):
16921697
X[0][mask_outside] = -1
16931698
X[1][mask_outside] = -1
16941699

1700+
def __getitem__(self, item):
1701+
"""Creates and returns a colorbar along the selected axis"""
1702+
if not self._isinit:
1703+
self._init()
1704+
if item == 0:
1705+
cmap = Colormap(self.name+'0', self.N)
1706+
one_d_lut = self._lut[:, self._origin[1]]
1707+
elif item == 1:
1708+
cmap = Colormap(self.name+'1', self.M)
1709+
one_d_lut = self._lut[self._origin[0], :]
1710+
else:
1711+
raise KeyError(f"only 0 or 1 are"
1712+
f" valid keys for BivarColormap, not {item!r}")
1713+
cmap._lut = np.zeros((self.N + 3, 4), float)
1714+
cmap._lut[:-3] = one_d_lut
1715+
cmap.set_bad(self._rgba_bad)
1716+
self._rgba_outside
1717+
if self.shape in ['ignore', 'circleignore']:
1718+
cmap.set_under(self._rgba_outside)
1719+
cmap.set_over(self._rgba_outside)
1720+
cmap._set_extremes()
1721+
cmap._isinit = True
1722+
return cmap
1723+
16951724
def _repr_png_(self):
16961725
"""Generate a PNG representation of the BivarColormap."""
16971726
if not self._isinit:
@@ -1769,11 +1798,15 @@ class SegmentedBivarColormap(BivarColormap):
17691798
'outside' color
17701799
- If 'circleignore' a circular mask is applied, but the data is not clipped
17711800
1801+
origin: (int, int)
1802+
The relative origin of the colormap. Typically (0, 0), for colormaps
1803+
that are linear on both axis, and (int(N*0.5), int(.5*M) for
1804+
circular colormaps.
17721805
"""
17731806

1774-
def __init__(self, patch, name, N=256, shape='square'):
1807+
def __init__(self, patch, name, N=256, shape='square', origin=(0, 0)):
17751808
self.patch = patch
1776-
super().__init__(name, N, N, shape)
1809+
super().__init__(name, N, N, shape, origin)
17771810

17781811
def _init(self):
17791812
s = self.patch.shape
@@ -1809,9 +1842,13 @@ class BivarColormapFromImage(BivarColormap):
18091842
'outside' color
18101843
- If 'circleignore' a circular mask is applied, but the data is not clipped
18111844
1845+
origin: (int, int)
1846+
The relative origin of the colormap. Typically (0, 0), for colormaps
1847+
that are linear on both axis, and (int(N*0.5), int(.5*M) for
1848+
circular colormaps.
18121849
"""
18131850

1814-
def __init__(self, lut, name='', shape='square'):
1851+
def __init__(self, lut, name='', shape='square', origin=(0, 0)):
18151852
# We can allow for a PIL.Image as unput in the following way, but importing
18161853
# matplotlib.image.pil_to_array() results in a circular import
18171854
# For now, this function only accepts numpy arrays.
@@ -1822,7 +1859,7 @@ def __init__(self, lut, name='', shape='square'):
18221859
raise ValueError("The lut must be an array of shape (n, m, 3) or (n, m, 4)",
18231860
" or a PIL.image encoded as RGB or RGBA")
18241861
self._lut = lut
1825-
super().__init__(name, lut.shape[0], lut.shape[1], shape)
1862+
super().__init__(name, lut.shape[0], lut.shape[1], shape, origin)
18261863

18271864
def _init(self):
18281865
self._isinit = True

lib/matplotlib/colors.pyi

+5-3
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,8 @@ class BivarColormap(ColormapBase):
173173
M: int
174174
shape: str
175175
n_variates: int
176-
def __init__(self, name: str, N: int = ..., M: int | None = ..., shape: str = ...) -> None: ...
176+
def __init__(self, name: str, N: int = ..., M: int | None = ..., shape: str = ..., origin: tuple[int, int] = ...
177+
) -> None: ...
177178
@overload
178179
def __call__(
179180
self, X: Sequence[Sequence[float]] | np.ndarray, alpha: ArrayLike | None = ..., bytes: bool = ...
@@ -198,11 +199,12 @@ class BivarColormap(ColormapBase):
198199

199200
class SegmentedBivarColormap(BivarColormap):
200201
def __init__(
201-
self, patch: np.ndarray, name: str, N: int = ..., shape: str = ...,
202+
self, patch: np.ndarray, name: str, N: int = ..., shape: str = ..., origin: tuple[int, int] = ...
202203
) -> None: ...
203204

204205
class BivarColormapFromImage(BivarColormap):
205-
def __init__(self, lut: np.ndarray, name: str = ..., shape: str = ...) -> None: ...
206+
def __init__(self, lut: np.ndarray, name: str = ..., shape: str = ..., origin: tuple[int, int] = ...
207+
) -> None: ...
206208

207209
class Normalize:
208210
callbacks: cbook.CallbackRegistry

lib/matplotlib/tests/test_multivariate_colormaps.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,30 @@ def test_bivar_cmap_call():
281281
match="only implemented for use with with floats"):
282282
cs = cmap([(0, 5, 9, 0, 0, 9), (0, 0, 0, 5, 11, 11)])
283283

284-
284+
def test_bivar_getitem():
285+
'''Test __getitem__ on BivarColormap'''
286+
xA = ([.0, .25, .5, .75, 1., -1, 2], [.5]*7)
287+
xB = ([.5]*7, [.0, .25, .5, .75, 1., -1, 2])
288+
289+
cmaps = mpl.bivar_colormaps['BiPeak']
290+
assert_array_equal(cmaps(xA), cmaps[0](xA[0]))
291+
assert_array_equal(cmaps(xB), cmaps[1](xB[1]))
292+
293+
cmaps.shape = 'ignore'
294+
assert_array_equal(cmaps(xA), cmaps[0](xA[0]))
295+
assert_array_equal(cmaps(xB), cmaps[1](xB[1]))
296+
297+
xA = ([.0, .25, .5, .75, 1., -1, 2], [.0]*7)
298+
xB = ([.0]*7, [.0, .25, .5, .75, 1., -1, 2])
299+
cmaps = mpl.bivar_colormaps['BiOrangeBlue']
300+
assert_array_equal(cmaps(xA), cmaps[0](xA[0]))
301+
assert_array_equal(cmaps(xB), cmaps[1](xB[1]))
302+
303+
cmaps.shape = 'ignore'
304+
assert_array_equal(cmaps(xA), cmaps[0](xA[0]))
305+
assert_array_equal(cmaps(xB), cmaps[1](xB[1]))
306+
307+
285308
def test_bivar_cmap_bad_shape():
286309
"""
287310
Tests calling a bivariate colormap with integer values

0 commit comments

Comments
 (0)