Skip to content

Commit 3a4fe48

Browse files
committed
Set norms using scale names.
1 parent abfaa7a commit 3a4fe48

File tree

4 files changed

+73
-19
lines changed

4 files changed

+73
-19
lines changed

lib/matplotlib/cm.py

Lines changed: 44 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,13 @@
1616
"""
1717

1818
from collections.abc import Mapping, MutableMapping
19+
import functools
1920

2021
import numpy as np
2122
from numpy import ma
2223

2324
import matplotlib as mpl
24-
from matplotlib import _api, colors, cbook
25+
from matplotlib import _api, colors, cbook, scale
2526
from matplotlib._cm import datad
2627
from matplotlib._cm_listed import cmaps as cmaps_listed
2728

@@ -331,6 +332,35 @@ def unregister_cmap(name):
331332
return _cmap_registry.pop(name)
332333

333334

335+
@functools.lru_cache(None)
336+
def _auto_norm_from_scale(scale_cls):
337+
"""
338+
Automatically generate a norm class from *scale_cls*.
339+
340+
This differs from `.colors.make_norm_from_scale` in the following points:
341+
342+
- This function is not a class decorator, but directly returns a norm class
343+
(as if decorating `.Normalize`).
344+
- The scale is automatically constructed with ``nonpositive="mask"``, if it
345+
supports such a parameter, to work around the difference in defaults
346+
between standard scales (which use "clip") and norms (which use "mask").
347+
- The returned norm class is memoized and reused for later calls.
348+
(`.colors.make_norm_from_scale` also memoizes, but we call it with a
349+
`functools.partial` instances which always compare unequal, so the cache
350+
doesn't get hit.)
351+
"""
352+
# Actually try to construct an instance, to verify whether
353+
# ``nonpositive="mask"`` is supported.
354+
try:
355+
norm = colors.make_norm_from_scale(
356+
functools.partial(scale_cls, nonpositive="mask"))(
357+
colors.Normalize)()
358+
except TypeError:
359+
norm = colors.make_norm_from_scale(scale_cls)(
360+
colors.Normalize)()
361+
return type(norm)
362+
363+
334364
class ScalarMappable:
335365
"""
336366
A mixin class to map scalar data to RGBA.
@@ -341,12 +371,13 @@ class ScalarMappable:
341371

342372
def __init__(self, norm=None, cmap=None):
343373
"""
344-
345374
Parameters
346375
----------
347-
norm : `matplotlib.colors.Normalize` (or subclass thereof)
376+
norm : `.Normalize` (or subclass thereof) or str or None
348377
The normalizing object which scales data, typically into the
349378
interval ``[0, 1]``.
379+
If a `str`, a `.Normalize` subclass is dynamically generated based
380+
on the scale with the corresponding name.
350381
If *None*, *norm* defaults to a *colors.Normalize* object which
351382
initializes its scaling based on the first data processed.
352383
cmap : str or `~matplotlib.colors.Colormap`
@@ -376,11 +407,11 @@ def _scale_norm(self, norm, vmin, vmax):
376407
"""
377408
if vmin is not None or vmax is not None:
378409
self.set_clim(vmin, vmax)
379-
if norm is not None:
410+
if isinstance(norm, colors.Normalize):
380411
raise ValueError(
381-
"Passing parameters norm and vmin/vmax simultaneously is "
382-
"not supported. Please pass vmin/vmax directly to the "
383-
"norm when creating it.")
412+
"Passing a Normalize instance simultaneously with "
413+
"vmin/vmax is not supported. Please pass vmin/vmax "
414+
"directly to the norm when creating it.")
384415

385416
# always resolve the autoscaling so we have concrete limits
386417
# rather than deferring to draw time.
@@ -554,9 +585,13 @@ def norm(self):
554585

555586
@norm.setter
556587
def norm(self, norm):
557-
_api.check_isinstance((colors.Normalize, None), norm=norm)
588+
_api.check_isinstance((colors.Normalize, str, None), norm=norm)
558589
if norm is None:
559590
norm = colors.Normalize()
591+
elif isinstance(norm, str):
592+
# case-insensitive, consistently with scale_factory.
593+
scale_cls = scale._scale_mapping[norm.lower()]
594+
norm = _auto_norm_from_scale(scale_cls)()
560595

561596
if norm is self.norm:
562597
# We aren't updating anything
@@ -578,7 +613,7 @@ def set_norm(self, norm):
578613
579614
Parameters
580615
----------
581-
norm : `.Normalize` or None
616+
norm : `.Normalize` or str or None
582617
583618
Notes
584619
-----

lib/matplotlib/colors.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1576,12 +1576,15 @@ def inverse(self, value):
15761576
.reshape(np.shape(value)))
15771577
return value[0] if is_scalar else value
15781578

1579-
Norm.__name__ = (
1580-
f"{scale_cls.__name__}Norm" if base_norm_cls is Normalize
1581-
else base_norm_cls.__name__)
1582-
Norm.__qualname__ = (
1583-
f"{scale_cls.__qualname__}Norm" if base_norm_cls is Normalize
1584-
else base_norm_cls.__qualname__)
1579+
if base_norm_cls is Normalize:
1580+
name_source = (scale_cls.func
1581+
if isinstance(scale_cls, functools.partial)
1582+
else scale_cls)
1583+
Norm.__name__ = f"{name_source.__name__}Norm"
1584+
Norm.__qualname__ = f"{name_source.__qualname__}Norm"
1585+
else:
1586+
Norm.__name__ = base_norm_cls.__name__
1587+
Norm.__qualname__ = base_norm_cls.__qualname__
15851588
Norm.__module__ = base_norm_cls.__module__
15861589
Norm.__doc__ = base_norm_cls.__doc__
15871590

lib/matplotlib/tests/test_axes.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -901,8 +901,8 @@ def test_imshow_norm_vminvmax():
901901
a = [[1, 2], [3, 4]]
902902
ax = plt.axes()
903903
with pytest.raises(ValueError,
904-
match="Passing parameters norm and vmin/vmax "
905-
"simultaneously is not supported."):
904+
match="Passing a Normalize instance simultaneously "
905+
"with vmin/vmax is not supported."):
906906
ax.imshow(a, norm=mcolors.Normalize(-10, 10), vmin=0, vmax=5)
907907

908908

@@ -2263,8 +2263,8 @@ def test_scatter_norm_vminvmax(self):
22632263
x = [1, 2, 3]
22642264
ax = plt.axes()
22652265
with pytest.raises(ValueError,
2266-
match="Passing parameters norm and vmin/vmax "
2267-
"simultaneously is not supported."):
2266+
match="Passing a Normalize instance simultaneously "
2267+
"with vmin/vmax is not supported."):
22682268
ax.scatter(x, x, c=x, norm=mcolors.Normalize(-10, 10),
22692269
vmin=0, vmax=5)
22702270

lib/matplotlib/tests/test_image.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1376,3 +1376,19 @@ def test_rgba_antialias():
13761376
# alternating red and blue stripes become purple
13771377
axs[3].imshow(aa, interpolation='antialiased', interpolation_stage='rgba',
13781378
cmap=cmap, vmin=-1.2, vmax=1.2)
1379+
1380+
1381+
@check_figures_equal(extensions=["png"])
1382+
def test_str_norms(fig_test, fig_ref):
1383+
t = np.random.rand(10, 10) * .8 + .1 # between 0 and 1
1384+
axs = fig_test.subplots(1, 4)
1385+
axs[0].imshow(t, norm="log")
1386+
axs[1].imshow(t, norm="log", vmin=.2)
1387+
axs[2].imshow(t, norm="symlog")
1388+
axs[3].imshow(t, norm="symlog", vmin=.3, vmax=.7)
1389+
axs = fig_ref.subplots(1, 4)
1390+
axs[0].imshow(t, norm=colors.LogNorm())
1391+
axs[1].imshow(t, norm=colors.LogNorm(vmin=.2))
1392+
# same linthresh as SymmetricalLogScale's default.
1393+
axs[2].imshow(t, norm=colors.SymLogNorm(linthresh=2))
1394+
axs[3].imshow(t, norm=colors.SymLogNorm(linthresh=2, vmin=.3, vmax=.7))

0 commit comments

Comments
 (0)