Skip to content

Fix collection offsets #20717

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 6, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions doc/users/next_whats_new/collection_offsets.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Setting collection offset transform after initialization
--------------------------------------------------------
`.collections.Collection.set_offset_transform()` was added.

Previously the offset transform could not be set after initialization. This can be helpful when creating a `.collections.Collection` outside an axes object and later adding it with `.Axes.add_collection()` and settings the offset transform to `.Axes.transData`.
70 changes: 27 additions & 43 deletions lib/matplotlib/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ class Collection(artist.Artist, cm.ScalarMappable):
ignoring those that were manually passed in.
"""
_offsets = np.zeros((0, 2))
_transOffset = transforms.IdentityTransform()
#: Either a list of 3x3 arrays or an Nx3x3 array (representing N
#: transforms), suitable for the `all_transforms` argument to
#: `~matplotlib.backend_bases.RendererBase.draw_path_collection`;
Expand Down Expand Up @@ -194,20 +193,17 @@ def __init__(self,
else:
self._joinstyle = None

# default to zeros
self._offsets = np.zeros((1, 2))
# save if offsets passed in were none...
self._offsetsNone = offsets is None
self._uniform_offsets = None

if offsets is not None:
offsets = np.asanyarray(offsets, float)
# Broadcast (2,) -> (1, 2) but nothing else.
if offsets.shape == (2,):
offsets = offsets[None, :]
if transOffset is not None:
self._offsets = offsets
self._transOffset = transOffset
else:
self._uniform_offsets = offsets
self._offsets = offsets

self._transOffset = transOffset

self._path_effects = None
self.update(kwargs)
Expand All @@ -223,11 +219,23 @@ def get_transforms(self):
return self._transforms

def get_offset_transform(self):
t = self._transOffset
if (not isinstance(t, transforms.Transform)
and hasattr(t, '_as_mpl_transform')):
t = t._as_mpl_transform(self.axes)
return t
"""Return the `.Transform` instance used by this artist offset."""
if self._transOffset is None:
self._transOffset = transforms.IdentityTransform()
elif (not isinstance(self._transOffset, transforms.Transform)
and hasattr(self._transOffset, '_as_mpl_transform')):
self._transOffset = self._transOffset._as_mpl_transform(self.axes)
return self._transOffset

def set_offset_transform(self, transOffset):
"""
Set the artist offset transform.

Parameters
----------
transOffset : `.Transform`
"""
self._transOffset = transOffset

def get_datalim(self, transData):
# Calculate the data limits and return them as a `.Bbox`.
Expand All @@ -248,8 +256,8 @@ def get_datalim(self, transData):

transform = self.get_transform()
transOffset = self.get_offset_transform()
if (not self._offsetsNone and
not transOffset.contains_branch(transData)):
hasOffsets = np.any(self._offsets) # True if any non-zero offsets
if hasOffsets and not transOffset.contains_branch(transData):
# if there are offsets but in some coords other than data,
# then don't use them for autoscaling.
return transforms.Bbox.null()
Expand Down Expand Up @@ -279,7 +287,7 @@ def get_datalim(self, transData):
self.get_transforms(),
transOffset.transform_non_affine(offsets),
transOffset.get_affine().frozen())
if not self._offsetsNone:
if hasOffsets:
# this is for collections that have their paths (shapes)
# in physical, axes-relative, or figure-relative units
# (i.e. like scatter). We can't uniquely set limits based on
Expand Down Expand Up @@ -542,20 +550,12 @@ def set_offsets(self, offsets):
offsets = np.asanyarray(offsets, float)
if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else.
offsets = offsets[None, :]
# This decision is based on how they are initialized above in __init__.
if self._uniform_offsets is None:
self._offsets = offsets
else:
self._uniform_offsets = offsets
self._offsets = offsets
self.stale = True

def get_offsets(self):
"""Return the offsets for the collection."""
# This decision is based on how they are initialized above in __init__.
if self._uniform_offsets is None:
return self._offsets
else:
return self._uniform_offsets
return self._offsets

def _get_default_linewidth(self):
# This may be overridden in a subclass.
Expand Down Expand Up @@ -1441,9 +1441,6 @@ def set_segments(self, segments):
seg = np.asarray(seg, float)
_segments.append(seg)

if self._uniform_offsets is not None:
_segments = self._add_offsets(_segments)

self._paths = [mpath.Path(_seg) for _seg in _segments]
self.stale = True

Expand Down Expand Up @@ -1474,19 +1471,6 @@ def get_segments(self):

return segments

def _add_offsets(self, segs):
offsets = self._uniform_offsets
Nsegs = len(segs)
Noffs = offsets.shape[0]
if Noffs == 1:
for i in range(Nsegs):
segs[i] = segs[i] + i * offsets
else:
for i in range(Nsegs):
io = i % Noffs
segs[i] = segs[i] + offsets[io:io + 1]
return segs

def _get_default_linewidth(self):
return mpl.rcParams['lines.linewidth']

Expand Down
31 changes: 31 additions & 0 deletions lib/matplotlib/tests/test_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,3 +1048,34 @@ def test_get_segments():
readback, = lc.get_segments()
# these should comeback un-changed!
assert np.all(segments == readback)


def test_set_offsets_late():
identity = mtransforms.IdentityTransform()
sizes = [2]

null = mcollections.CircleCollection(sizes=sizes)

init = mcollections.CircleCollection(sizes=sizes, offsets=(10, 10))

late = mcollections.CircleCollection(sizes=sizes)
late.set_offsets((10, 10))

# Bbox.__eq__ doesn't compare bounds
null_bounds = null.get_datalim(identity).bounds
init_bounds = init.get_datalim(identity).bounds
late_bounds = late.get_datalim(identity).bounds

# offsets and transform are applied when set after initialization
assert null_bounds != init_bounds
assert init_bounds == late_bounds


def test_set_offset_transform():
skew = mtransforms.Affine2D().skew(2, 2)
init = mcollections.Collection([], transOffset=skew)

late = mcollections.Collection([])
late.set_offset_transform(skew)

assert skew == init.get_offset_transform() == late.get_offset_transform()