Skip to content

Commit d3ed88a

Browse files
committed
Fix picklability of make_norm_from_scale norms.
And also fix their qualname, which was missed in a00a909.
1 parent 1535cdc commit d3ed88a

File tree

3 files changed

+42
-8
lines changed

3 files changed

+42
-8
lines changed

lib/matplotlib/colors.py

+34-7
Original file line numberDiff line numberDiff line change
@@ -1507,9 +1507,26 @@ class norm_cls(Normalize):
15071507

15081508
if init is None:
15091509
def init(vmin=None, vmax=None, clip=False): pass
1510-
bound_init_signature = inspect.signature(init)
1510+
1511+
return _make_norm_from_scale(
1512+
scale_cls, base_norm_cls, inspect.signature(init))
1513+
1514+
1515+
@functools.lru_cache(None)
1516+
def _make_norm_from_scale(scale_cls, base_norm_cls, bound_init_signature):
1517+
"""
1518+
Helper for `make_norm_from_scale`.
1519+
1520+
This function is split out so that it takes a signature object as third
1521+
argument (as signatures are picklable, contrary to arbitrary lambdas);
1522+
caching is also used so that different unpickles reuse the same class.
1523+
"""
15111524

15121525
class Norm(base_norm_cls):
1526+
def __reduce__(self):
1527+
return (_picklable_norm_constructor,
1528+
(scale_cls, base_norm_cls, bound_init_signature),
1529+
self.__dict__)
15131530

15141531
def __init__(self, *args, **kwargs):
15151532
ba = bound_init_signature.bind(*args, **kwargs)
@@ -1519,6 +1536,10 @@ def __init__(self, *args, **kwargs):
15191536
self._scale = scale_cls(axis=None, **ba.arguments)
15201537
self._trf = self._scale.get_transform()
15211538

1539+
__init__.__signature__ = bound_init_signature.replace(parameters=[
1540+
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
1541+
*bound_init_signature.parameters.values()])
1542+
15221543
def __call__(self, value, clip=None):
15231544
value, is_scalar = self.process_value(value)
15241545
if self.vmin is None or self.vmax is None:
@@ -1566,17 +1587,23 @@ def autoscale_None(self, A):
15661587
in_trf_domain = np.extract(np.isfinite(self._trf.transform(A)), A)
15671588
return super().autoscale_None(in_trf_domain)
15681589

1569-
Norm.__name__ = (f"{scale_cls.__name__}Norm" if base_norm_cls is Normalize
1570-
else base_norm_cls.__name__)
1571-
Norm.__qualname__ = base_norm_cls.__qualname__
1590+
Norm.__name__ = (
1591+
f"{scale_cls.__name__}Norm" if base_norm_cls is Normalize
1592+
else base_norm_cls.__name__)
1593+
Norm.__qualname__ = (
1594+
f"{scale_cls.__qualname__}Norm" if base_norm_cls is Normalize
1595+
else base_norm_cls.__qualname__)
15721596
Norm.__module__ = base_norm_cls.__module__
15731597
Norm.__doc__ = base_norm_cls.__doc__
1574-
Norm.__init__.__signature__ = bound_init_signature.replace(parameters=[
1575-
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
1576-
*bound_init_signature.parameters.values()])
1598+
15771599
return Norm
15781600

15791601

1602+
def _picklable_norm_constructor(*args, **kwargs):
1603+
cls = _make_norm_from_scale(*args, **kwargs)
1604+
return cls.__new__(cls)
1605+
1606+
15801607
@make_norm_from_scale(
15811608
scale.FuncScale,
15821609
init=lambda functions, vmin=None, vmax=None, clip=False: None)

lib/matplotlib/tests/test_colors.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1480,4 +1480,4 @@ def test_norm_update_figs(fig_test, fig_ref):
14801480
def test_make_norm_from_scale_name():
14811481
logitnorm = mcolors.make_norm_from_scale(
14821482
mscale.LogitScale, mcolors.Normalize)
1483-
assert logitnorm.__name__ == "LogitScaleNorm"
1483+
assert logitnorm.__name__ == logitnorm.__qualname__ == "LogitScaleNorm"

lib/matplotlib/tests/test_pickle.py

+7
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,10 @@ def test_unpickle_canvas():
219219
def test_mpl_toolkits():
220220
ax = parasite_axes.host_axes([0, 0, 1, 1])
221221
assert type(pickle.loads(pickle.dumps(ax))) == parasite_axes.HostAxes
222+
223+
224+
def test_dynamic_norm():
225+
logit_norm_instance = mpl.colors.make_norm_from_scale(
226+
mpl.scale.LogitScale, mpl.colors.Normalize)()
227+
assert type(pickle.loads(pickle.dumps(logit_norm_instance))) \
228+
== type(logit_norm_instance)

0 commit comments

Comments
 (0)