From c38c4055b71239e6170eb03a379cd4fe004003d0 Mon Sep 17 00:00:00 2001
From: Greg Lucas
Date: Sun, 26 Feb 2023 19:20:50 -0700
Subject: [PATCH] MNT: Use WeakKeyDictionary and WeakSet in Grouper

Rather than handling the weakrefs ourselves, just use the builtin
WeakKeyDictionary instead. This will automatically remove dead
references meaning we can remove the clean() method.
---
 .../deprecations/25352-GL.rst      |  4 ++
 lib/matplotlib/axes/_base.py       |  2 -
 lib/matplotlib/cbook.py            | 42 +++++++------------
 lib/matplotlib/tests/test_cbook.py |  7 ++--
 lib/mpl_toolkits/mplot3d/axes3d.py |  3 --
 5 files changed, 23 insertions(+), 35 deletions(-)
 create mode 100644 doc/api/next_api_changes/deprecations/25352-GL.rst

diff --git a/doc/api/next_api_changes/deprecations/25352-GL.rst b/doc/api/next_api_changes/deprecations/25352-GL.rst
new file mode 100644
index 000000000000..e7edd57a6453
--- /dev/null
+++ b/doc/api/next_api_changes/deprecations/25352-GL.rst
@@ -0,0 +1,4 @@
+``Grouper.clean()``
+~~~~~~~~~~~~~~~~~~~
+
+with no replacement. The Grouper class now cleans itself up automatically.
diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py
index 8e348fea4675..7893d7434f0d 100644
--- a/lib/matplotlib/axes/_base.py
+++ b/lib/matplotlib/axes/_base.py
@@ -1363,8 +1363,6 @@ def __clear(self):
         self.xaxis.set_clip_path(self.patch)
         self.yaxis.set_clip_path(self.patch)
 
-        self._shared_axes["x"].clean()
-        self._shared_axes["y"].clean()
         if self._sharex is not None:
             self.xaxis.set_visible(xaxis_visible)
             self.patch.set_visible(patch_visible)
diff --git a/lib/matplotlib/cbook.py b/lib/matplotlib/cbook.py
index 1a64331e201d..3c97e26f6316 100644
--- a/lib/matplotlib/cbook.py
+++ b/lib/matplotlib/cbook.py
@@ -786,61 +786,53 @@ class Grouper:
     """
 
     def __init__(self, init=()):
-        self._mapping = {weakref.ref(x): [weakref.ref(x)] for x in init}
+        self._mapping = weakref.WeakKeyDictionary(
+            {x: weakref.WeakSet([x]) for x in init})
 
     def __getstate__(self):
         return {
             **vars(self),
             # Convert weak refs to strong ones.
-            "_mapping": {k(): [v() for v in vs] for k, vs in self._mapping.items()},
+            "_mapping": {k: set(v) for k, v in self._mapping.items()},
         }
 
     def __setstate__(self, state):
         vars(self).update(state)
         # Convert strong refs to weak ones.
-        self._mapping = {weakref.ref(k): [*map(weakref.ref, vs)]
-                         for k, vs in self._mapping.items()}
+        self._mapping = weakref.WeakKeyDictionary(
+            {k: weakref.WeakSet(v) for k, v in self._mapping.items()})
 
     def __contains__(self, item):
-        return weakref.ref(item) in self._mapping
+        return item in self._mapping
 
+    @_api.deprecated("3.8", alternative="none, you no longer need to clean a Grouper")
     def clean(self):
         """Clean dead weak references from the dictionary."""
-        mapping = self._mapping
-        to_drop = [key for key in mapping if key() is None]
-        for key in to_drop:
-            val = mapping.pop(key)
-            val.remove(key)
 
     def join(self, a, *args):
         """
         Join given arguments into the same set. Accepts one or more
         arguments.
""" mapping = self._mapping - set_a = mapping.setdefault(weakref.ref(a), [weakref.ref(a)]) + set_a = mapping.setdefault(a, weakref.WeakSet([a])) for arg in args: - set_b = mapping.get(weakref.ref(arg), [weakref.ref(arg)]) + set_b = mapping.get(arg, weakref.WeakSet([arg])) if set_b is not set_a: if len(set_b) > len(set_a): set_a, set_b = set_b, set_a - set_a.extend(set_b) + set_a.update(set_b) for elem in set_b: mapping[elem] = set_a - self.clean() - def joined(self, a, b): """Return whether *a* and *b* are members of the same set.""" - self.clean() - return (self._mapping.get(weakref.ref(a), object()) - is self._mapping.get(weakref.ref(b))) + return (self._mapping.get(a, object()) is self._mapping.get(b)) def remove(self, a): - self.clean() - set_a = self._mapping.pop(weakref.ref(a), None) + set_a = self._mapping.pop(a, None) if set_a: - set_a.remove(weakref.ref(a)) + set_a.remove(a) def __iter__(self): """ @@ -848,16 +840,14 @@ def __iter__(self): The iterator is invalid if interleaved with calls to join(). """ - self.clean() unique_groups = {id(group): group for group in self._mapping.values()} for group in unique_groups.values(): - yield [x() for x in group] + yield [x for x in group] def get_siblings(self, a): """Return all of the items joined with *a*, including itself.""" - self.clean() - siblings = self._mapping.get(weakref.ref(a), [weakref.ref(a)]) - return [x() for x in siblings] + siblings = self._mapping.get(a, [a]) + return [x for x in siblings] class GrouperView: diff --git a/lib/matplotlib/tests/test_cbook.py b/lib/matplotlib/tests/test_cbook.py index aa5c999b7079..da9c187a323a 100644 --- a/lib/matplotlib/tests/test_cbook.py +++ b/lib/matplotlib/tests/test_cbook.py @@ -1,7 +1,6 @@ import itertools import pickle -from weakref import ref from unittest.mock import patch, Mock from datetime import datetime, date, timedelta @@ -590,11 +589,11 @@ class Dummy: mapping = g._mapping for o in objs: - assert ref(o) in mapping + assert o in mapping - base_set = mapping[ref(objs[0])] + base_set = mapping[objs[0]] for o in objs[1:]: - assert mapping[ref(o)] is base_set + assert mapping[o] is base_set def test_flatiter(): diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index cb31aca6459e..67f438f107dd 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -640,7 +640,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True, _tight = self._tight = bool(tight) if scalex and self.get_autoscalex_on(): - self._shared_axes["x"].clean() x0, x1 = self.xy_dataLim.intervalx xlocator = self.xaxis.get_major_locator() x0, x1 = xlocator.nonsingular(x0, x1) @@ -653,7 +652,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True, self.set_xbound(x0, x1) if scaley and self.get_autoscaley_on(): - self._shared_axes["y"].clean() y0, y1 = self.xy_dataLim.intervaly ylocator = self.yaxis.get_major_locator() y0, y1 = ylocator.nonsingular(y0, y1) @@ -666,7 +664,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True, self.set_ybound(y0, y1) if scalez and self.get_autoscalez_on(): - self._shared_axes["z"].clean() z0, z1 = self.zz_dataLim.intervalx zlocator = self.zaxis.get_major_locator() z0, z1 = zlocator.nonsingular(z0, z1)