Skip to content

MNT: Use WeakKeyDictionary and WeakSet in Grouper #25352

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions doc/api/next_api_changes/deprecations/25352-GL.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
``Grouper.clean()``
~~~~~~~~~~~~~~~~~~~

with no replacement. The Grouper class now cleans itself up automatically.
2 changes: 0 additions & 2 deletions lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,8 +1363,6 @@ def __clear(self):
self.xaxis.set_clip_path(self.patch)
self.yaxis.set_clip_path(self.patch)

self._shared_axes["x"].clean()
self._shared_axes["y"].clean()
if self._sharex is not None:
self.xaxis.set_visible(xaxis_visible)
self.patch.set_visible(patch_visible)
Expand Down
42 changes: 16 additions & 26 deletions lib/matplotlib/cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,78 +786,68 @@ class Grouper:
"""

def __init__(self, init=()):
self._mapping = {weakref.ref(x): [weakref.ref(x)] for x in init}
self._mapping = weakref.WeakKeyDictionary(
{x: weakref.WeakSet([x]) for x in init})

def __getstate__(self):
return {
**vars(self),
# Convert weak refs to strong ones.
"_mapping": {k(): [v() for v in vs] for k, vs in self._mapping.items()},
"_mapping": {k: set(v) for k, v in self._mapping.items()},
}

def __setstate__(self, state):
vars(self).update(state)
# Convert strong refs to weak ones.
self._mapping = {weakref.ref(k): [*map(weakref.ref, vs)]
for k, vs in self._mapping.items()}
self._mapping = weakref.WeakKeyDictionary(
{k: weakref.WeakSet(v) for k, v in self._mapping.items()})

def __contains__(self, item):
return weakref.ref(item) in self._mapping
return item in self._mapping

@_api.deprecated("3.8", alternative="none, you no longer need to clean a Grouper")
def clean(self):
"""Clean dead weak references from the dictionary."""
mapping = self._mapping
to_drop = [key for key in mapping if key() is None]
for key in to_drop:
val = mapping.pop(key)
val.remove(key)

def join(self, a, *args):
"""
Join given arguments into the same set. Accepts one or more arguments.
"""
mapping = self._mapping
set_a = mapping.setdefault(weakref.ref(a), [weakref.ref(a)])
set_a = mapping.setdefault(a, weakref.WeakSet([a]))

for arg in args:
set_b = mapping.get(weakref.ref(arg), [weakref.ref(arg)])
set_b = mapping.get(arg, weakref.WeakSet([arg]))
if set_b is not set_a:
if len(set_b) > len(set_a):
set_a, set_b = set_b, set_a
set_a.extend(set_b)
set_a.update(set_b)
for elem in set_b:
mapping[elem] = set_a

self.clean()

def joined(self, a, b):
"""Return whether *a* and *b* are members of the same set."""
self.clean()
return (self._mapping.get(weakref.ref(a), object())
is self._mapping.get(weakref.ref(b)))
return (self._mapping.get(a, object()) is self._mapping.get(b))

def remove(self, a):
self.clean()
set_a = self._mapping.pop(weakref.ref(a), None)
set_a = self._mapping.pop(a, None)
if set_a:
set_a.remove(weakref.ref(a))
set_a.remove(a)

def __iter__(self):
"""
Iterate over each of the disjoint sets as a list.

The iterator is invalid if interleaved with calls to join().
"""
self.clean()
unique_groups = {id(group): group for group in self._mapping.values()}
for group in unique_groups.values():
yield [x() for x in group]
yield [x for x in group]

def get_siblings(self, a):
"""Return all of the items joined with *a*, including itself."""
self.clean()
siblings = self._mapping.get(weakref.ref(a), [weakref.ref(a)])
return [x() for x in siblings]
siblings = self._mapping.get(a, [a])
return [x for x in siblings]


class GrouperView:
Expand Down
7 changes: 3 additions & 4 deletions lib/matplotlib/tests/test_cbook.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import itertools
import pickle

from weakref import ref
from unittest.mock import patch, Mock

from datetime import datetime, date, timedelta
Expand Down Expand Up @@ -590,11 +589,11 @@ class Dummy:
mapping = g._mapping

for o in objs:
assert ref(o) in mapping
assert o in mapping

base_set = mapping[ref(objs[0])]
base_set = mapping[objs[0]]
for o in objs[1:]:
assert mapping[ref(o)] is base_set
assert mapping[o] is base_set


def test_flatiter():
Expand Down
3 changes: 0 additions & 3 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True,
_tight = self._tight = bool(tight)

if scalex and self.get_autoscalex_on():
self._shared_axes["x"].clean()
x0, x1 = self.xy_dataLim.intervalx
xlocator = self.xaxis.get_major_locator()
x0, x1 = xlocator.nonsingular(x0, x1)
Expand All @@ -653,7 +652,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True,
self.set_xbound(x0, x1)

if scaley and self.get_autoscaley_on():
self._shared_axes["y"].clean()
y0, y1 = self.xy_dataLim.intervaly
ylocator = self.yaxis.get_major_locator()
y0, y1 = ylocator.nonsingular(y0, y1)
Expand All @@ -666,7 +664,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True,
self.set_ybound(y0, y1)

if scalez and self.get_autoscalez_on():
self._shared_axes["z"].clean()
z0, z1 = self.zz_dataLim.intervalx
zlocator = self.zaxis.get_major_locator()
z0, z1 = zlocator.nonsingular(z0, z1)
Expand Down