Skip to content

Fix picklability of make_norm_from_scale norms. #21916

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 34 additions & 7 deletions lib/matplotlib/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1507,9 +1507,26 @@ class norm_cls(Normalize):

if init is None:
def init(vmin=None, vmax=None, clip=False): pass
bound_init_signature = inspect.signature(init)

return _make_norm_from_scale(
scale_cls, base_norm_cls, inspect.signature(init))


@functools.lru_cache(None)
def _make_norm_from_scale(scale_cls, base_norm_cls, bound_init_signature):
"""
Helper for `make_norm_from_scale`.

This function is split out so that it takes a signature object as third
argument (as signatures are picklable, contrary to arbitrary lambdas);
caching is also used so that different unpickles reuse the same class.
"""

class Norm(base_norm_cls):
def __reduce__(self):
return (_picklable_norm_constructor,
(scale_cls, base_norm_cls, bound_init_signature),
self.__dict__)

def __init__(self, *args, **kwargs):
ba = bound_init_signature.bind(*args, **kwargs)
Expand All @@ -1519,6 +1536,10 @@ def __init__(self, *args, **kwargs):
self._scale = scale_cls(axis=None, **ba.arguments)
self._trf = self._scale.get_transform()

__init__.__signature__ = bound_init_signature.replace(parameters=[
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
*bound_init_signature.parameters.values()])

def __call__(self, value, clip=None):
value, is_scalar = self.process_value(value)
if self.vmin is None or self.vmax is None:
Expand Down Expand Up @@ -1566,17 +1587,23 @@ def autoscale_None(self, A):
in_trf_domain = np.extract(np.isfinite(self._trf.transform(A)), A)
return super().autoscale_None(in_trf_domain)

Norm.__name__ = (f"{scale_cls.__name__}Norm" if base_norm_cls is Normalize
else base_norm_cls.__name__)
Norm.__qualname__ = base_norm_cls.__qualname__
Norm.__name__ = (
f"{scale_cls.__name__}Norm" if base_norm_cls is Normalize
else base_norm_cls.__name__)
Norm.__qualname__ = (
f"{scale_cls.__qualname__}Norm" if base_norm_cls is Normalize
else base_norm_cls.__qualname__)
Norm.__module__ = base_norm_cls.__module__
Norm.__doc__ = base_norm_cls.__doc__
Norm.__init__.__signature__ = bound_init_signature.replace(parameters=[
inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD),
*bound_init_signature.parameters.values()])

return Norm


def _picklable_norm_constructor(*args):
cls = _make_norm_from_scale(*args)
return cls.__new__(cls)


@make_norm_from_scale(
scale.FuncScale,
init=lambda functions, vmin=None, vmax=None, clip=False: None)
Expand Down
2 changes: 1 addition & 1 deletion lib/matplotlib/tests/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1480,4 +1480,4 @@ def test_norm_update_figs(fig_test, fig_ref):
def test_make_norm_from_scale_name():
logitnorm = mcolors.make_norm_from_scale(
mscale.LogitScale, mcolors.Normalize)
assert logitnorm.__name__ == "LogitScaleNorm"
assert logitnorm.__name__ == logitnorm.__qualname__ == "LogitScaleNorm"
7 changes: 7 additions & 0 deletions lib/matplotlib/tests/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,3 +219,10 @@ def test_unpickle_canvas():
def test_mpl_toolkits():
ax = parasite_axes.host_axes([0, 0, 1, 1])
assert type(pickle.loads(pickle.dumps(ax))) == parasite_axes.HostAxes


def test_dynamic_norm():
logit_norm_instance = mpl.colors.make_norm_from_scale(
mpl.scale.LogitScale, mpl.colors.Normalize)()
assert type(pickle.loads(pickle.dumps(logit_norm_instance))) \
== type(logit_norm_instance)