Skip to content

Commit 8b99d37

Browse files
committed
MultiNorm class
1 parent f2717a5 commit 8b99d37

File tree

2 files changed

+273
-0
lines changed

2 files changed

+273
-0
lines changed

lib/matplotlib/colors.py

Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2188,6 +2188,16 @@ def __init__(self, vmin=None, vmax=None, clip=False):
21882188
self._scale = None
21892189
self.callbacks = cbook.CallbackRegistry(signals=["changed"])
21902190

2191+
@property
2192+
def n_input(self):
2193+
# To be overridden by subclasses with multiple inputs
2194+
return 1
2195+
2196+
@property
2197+
def n_output(self):
2198+
# To be overridden by subclasses with multiple outputs
2199+
return 1
2200+
21912201
@property
21922202
def vmin(self):
21932203
return self._vmin
@@ -3087,6 +3097,237 @@ def inverse(self, value):
30873097
return value
30883098

30893099

3100+
class MultiNorm(Normalize):
3101+
"""
3102+
A mixin class which contains multiple scalar norms
3103+
"""
3104+
3105+
def __init__(self, norms, vmin=None, vmax=None, clip=False):
3106+
"""
3107+
Parameters
3108+
----------
3109+
norms : List of strings or `Normalize` objects
3110+
The constituent norms. The list must have a minimum length of 2.
3111+
vmin, vmax : float, None, or list of float or None
3112+
Limits of the constituent norms.
3113+
If a list, each each value is assigned to one of the constituent
3114+
norms. Single values are repeated to form a list of appropriate size.
3115+
3116+
clip : bool or list of bools, default: False
3117+
Determines the behavior for mapping values outside the range
3118+
``[vmin, vmax]`` for the constituent norms.
3119+
If a list, each each value is assigned to one of the constituent
3120+
norms. Single values are repeated to form a list of appropriate size.
3121+
3122+
"""
3123+
3124+
if isinstance(norms, str) or not np.iterable(norms):
3125+
raise ValueError("A MultiNorm must be assigned multiple norms")
3126+
norms = [n for n in norms]
3127+
for i, n in enumerate(norms):
3128+
if n is None:
3129+
norms[i] = Normalize()
3130+
elif isinstance(n, str):
3131+
try:
3132+
scale_cls = scale._scale_mapping[n]
3133+
except KeyError:
3134+
raise ValueError(
3135+
"Invalid norm str name; the following values are "
3136+
f"supported: {', '.join(scale._scale_mapping)}"
3137+
) from None
3138+
norms[i] = mpl.colorizer._auto_norm_from_scale(scale_cls)()
3139+
3140+
# Convert the list of norms to a tuple to make it immutable.
3141+
# If there is a use case for swapping a single norm, we can add support for
3142+
# that later
3143+
self._norms = tuple(n for n in norms)
3144+
3145+
self.callbacks = cbook.CallbackRegistry(signals=["changed"])
3146+
3147+
self.vmin = vmin
3148+
self.vmax = vmax
3149+
self.clip = clip
3150+
3151+
self._id_norms = [n.callbacks.connect('changed',
3152+
self._changed) for n in self._norms]
3153+
3154+
@property
3155+
def n_input(self):
3156+
return len(self._norms)
3157+
3158+
@property
3159+
def n_output(self):
3160+
return len(self._norms)
3161+
3162+
@property
3163+
def norms(self):
3164+
return self._norms
3165+
3166+
@property
3167+
def vmin(self):
3168+
return tuple(n.vmin for n in self._norms)
3169+
3170+
@vmin.setter
3171+
def vmin(self, value):
3172+
if not np.iterable(value):
3173+
value = [value]*self.n_input
3174+
if len(value) != self.n_input:
3175+
raise ValueError(f"Invalid vmin for `MultiNorm` with {self.n_input}"
3176+
" inputs.")
3177+
with self.callbacks.blocked():
3178+
for i, v in enumerate(value):
3179+
if v is not None:
3180+
self.norms[i].vmin = v
3181+
self._changed()
3182+
3183+
@property
3184+
def vmax(self):
3185+
return tuple(n.vmax for n in self._norms)
3186+
3187+
@vmax.setter
3188+
def vmax(self, value):
3189+
if not np.iterable(value):
3190+
value = [value]*self.n_input
3191+
if len(value) != self.n_input:
3192+
raise ValueError(f"Invalid vmax for `MultiNorm` with {self.n_input}"
3193+
" inputs.")
3194+
with self.callbacks.blocked():
3195+
for i, v in enumerate(value):
3196+
if v is not None:
3197+
self.norms[i].vmax = v
3198+
self._changed()
3199+
3200+
@property
3201+
def clip(self):
3202+
return tuple(n.clip for n in self._norms)
3203+
3204+
@clip.setter
3205+
def clip(self, value):
3206+
if not np.iterable(value):
3207+
value = [value]*self.n_input
3208+
with self.callbacks.blocked():
3209+
for i, v in enumerate(value):
3210+
if v is not None:
3211+
self.norms[i].clip = v
3212+
self._changed()
3213+
3214+
def _changed(self):
3215+
"""
3216+
Call this whenever the norm is changed to notify all the
3217+
callback listeners to the 'changed' signal.
3218+
"""
3219+
self.callbacks.process('changed')
3220+
3221+
def __call__(self, value, clip=None):
3222+
"""
3223+
Normalize the data and return the normalized data.
3224+
Each variate in the input is assigned to the a constituent norm.
3225+
3226+
Parameters
3227+
----------
3228+
value
3229+
Data to normalize. Must be of length `n_input` or have a data type with
3230+
`n_input` fields.
3231+
clip : List of bools or bool, optional
3232+
See the description of the parameter *clip* in Normalize.
3233+
If ``None``, defaults to ``self.clip`` (which defaults to
3234+
``False``).
3235+
3236+
Returns
3237+
-------
3238+
Data
3239+
Normalized input values as a list of length `n_input`
3240+
3241+
Notes
3242+
-----
3243+
If not already initialized, ``self.vmin`` and ``self.vmax`` are
3244+
initialized using ``self.autoscale_None(value)``.
3245+
"""
3246+
if clip is None:
3247+
clip = self.clip
3248+
else:
3249+
if not np.iterable(clip):
3250+
value = [value]*self.n_input
3251+
3252+
value = self._iterable_variates_in_data(value, self.n_input)
3253+
result = [n(v, clip=c) for n, v, c in zip(self.norms, value, clip)]
3254+
return result
3255+
3256+
def inverse(self, value):
3257+
"""
3258+
Maps the normalized value (i.e., index in the colormap) back to image
3259+
data value.
3260+
3261+
Parameters
3262+
----------
3263+
value
3264+
Normalized value. Must be of length `n_input` or have a data type with
3265+
`n_input` fields.
3266+
"""
3267+
value = self._iterable_variates_in_data(value, self.n_input)
3268+
result = [n.inverse(v) for n, v in zip(self.norms, value)]
3269+
return result
3270+
3271+
def autoscale(self, A):
3272+
"""
3273+
For each constituent norm, Set *vmin*, *vmax* to min, max of the corresponding
3274+
variate in *A*.
3275+
"""
3276+
with self.callbacks.blocked():
3277+
# Pause callbacks while we are updating so we only get
3278+
# a single update signal at the end
3279+
self.vmin = self.vmax = None
3280+
self.autoscale_None(A)
3281+
3282+
def autoscale_None(self, A):
3283+
"""
3284+
If *vmin* or *vmax* are not set on any constituent norm,
3285+
use the min/max of the corresponding variate in *A* to set them.
3286+
3287+
Parameters
3288+
----------
3289+
A
3290+
Data, must be of length `n_input` or be an np.ndarray type with
3291+
`n_input` fields.
3292+
"""
3293+
with self.callbacks.blocked():
3294+
A = self._iterable_variates_in_data(A, self.n_input)
3295+
for n, a in zip(self.norms, A):
3296+
n.autoscale_None(a)
3297+
self._changed()
3298+
3299+
def scaled(self):
3300+
"""Return whether both *vmin* and *vmax* are set on all constitient norms"""
3301+
return all([(n.vmin is not None and n.vmax is not None) for n in self.norms])
3302+
3303+
@staticmethod
3304+
def _iterable_variates_in_data(data, n_input):
3305+
"""
3306+
Provides an iterable over the variates contained in the data.
3307+
3308+
An input array with n_input fields is returned as a list of length n referencing
3309+
slices of the original array.
3310+
3311+
Parameters
3312+
----------
3313+
data : np.ndarray, tuple or list
3314+
The input array. It must either be an array with n_input fields or have
3315+
a length (n_input)
3316+
3317+
Returns
3318+
-------
3319+
list of np.ndarray
3320+
3321+
"""
3322+
if isinstance(data, np.ndarray) and data.dtype.fields is not None:
3323+
data = [data[descriptor[0]] for descriptor in data.dtype.descr]
3324+
if not len(data) == n_input:
3325+
raise ValueError("The input to this `MultiNorm` must be of shape "
3326+
f"({n_input}, ...), or have a data type with {n_input} "
3327+
"fields.")
3328+
return data
3329+
3330+
30903331
def rgb_to_hsv(arr):
30913332
"""
30923333
Convert an array of float RGB values (in the range [0, 1]) to HSV values.

lib/matplotlib/colors.pyi

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,10 @@ class Normalize:
248248
@vmax.setter
249249
def vmax(self, value: float | None) -> None: ...
250250
@property
251+
def n_input(self) -> int: ...
252+
@property
253+
def n_output(self) -> int: ...
254+
@property
251255
def clip(self) -> bool: ...
252256
@clip.setter
253257
def clip(self, value: bool) -> None: ...
@@ -372,6 +376,34 @@ class BoundaryNorm(Normalize):
372376

373377
class NoNorm(Normalize): ...
374378

379+
class MultiNorm(Normalize):
380+
# Here "type: ignore[override]" is used for functions with a return type
381+
# that differs from the function in the base class.
382+
# i.e. where `MultiNorm` returns a tuple and Normalize returns a `float` etc.
383+
def __init__(
384+
self,
385+
norms: ArrayLike,
386+
vmin: ArrayLike | float | None = ...,
387+
vmax: ArrayLike | float | None = ...,
388+
clip: ArrayLike | bool = ...
389+
) -> None: ...
390+
@property
391+
def norms(self) -> tuple: ...
392+
@property # type: ignore[override]
393+
def vmin(self) -> tuple[float | None]: ...
394+
@vmin.setter
395+
def vmin(self, value: ArrayLike | float | None) -> None: ...
396+
@property # type: ignore[override]
397+
def vmax(self) -> tuple[float | None]: ...
398+
@vmax.setter
399+
def vmax(self, value: ArrayLike | float | None) -> None: ...
400+
@property # type: ignore[override]
401+
def clip(self) -> tuple[bool]: ...
402+
@clip.setter
403+
def clip(self, value: ArrayLike | bool) -> None: ...
404+
def __call__(self, value: ArrayLike, clip: ArrayLike | bool | None) -> list: ... # type: ignore[override]
405+
def inverse(self, value: ArrayLike) -> list: ... # type: ignore[override]
406+
375407
def rgb_to_hsv(arr: ArrayLike) -> np.ndarray: ...
376408
def hsv_to_rgb(hsv: ArrayLike) -> np.ndarray: ...
377409

0 commit comments

Comments
 (0)