diff --git a/doc/users/next_whats_new/collection_offsets.rst b/doc/users/next_whats_new/collection_offsets.rst new file mode 100644 index 000000000000..d53d349a38fb --- /dev/null +++ b/doc/users/next_whats_new/collection_offsets.rst @@ -0,0 +1,5 @@ +Setting collection offset transform after initialization +-------------------------------------------------------- +`.collections.Collection.set_offset_transform()` was added. + +Previously the offset transform could not be set after initialization. This can be helpful when creating a `.collections.Collection` outside an axes object and later adding it with `.Axes.add_collection()` and settings the offset transform to `.Axes.transData`. diff --git a/lib/matplotlib/collections.py b/lib/matplotlib/collections.py index 736b1f016331..619c62b5ca14 100644 --- a/lib/matplotlib/collections.py +++ b/lib/matplotlib/collections.py @@ -62,7 +62,6 @@ class Collection(artist.Artist, cm.ScalarMappable): ignoring those that were manually passed in. """ _offsets = np.zeros((0, 2)) - _transOffset = transforms.IdentityTransform() #: Either a list of 3x3 arrays or an Nx3x3 array (representing N #: transforms), suitable for the `all_transforms` argument to #: `~matplotlib.backend_bases.RendererBase.draw_path_collection`; @@ -194,20 +193,17 @@ def __init__(self, else: self._joinstyle = None + # default to zeros self._offsets = np.zeros((1, 2)) - # save if offsets passed in were none... - self._offsetsNone = offsets is None - self._uniform_offsets = None + if offsets is not None: offsets = np.asanyarray(offsets, float) # Broadcast (2,) -> (1, 2) but nothing else. if offsets.shape == (2,): offsets = offsets[None, :] - if transOffset is not None: - self._offsets = offsets - self._transOffset = transOffset - else: - self._uniform_offsets = offsets + self._offsets = offsets + + self._transOffset = transOffset self._path_effects = None self.update(kwargs) @@ -223,11 +219,23 @@ def get_transforms(self): return self._transforms def get_offset_transform(self): - t = self._transOffset - if (not isinstance(t, transforms.Transform) - and hasattr(t, '_as_mpl_transform')): - t = t._as_mpl_transform(self.axes) - return t + """Return the `.Transform` instance used by this artist offset.""" + if self._transOffset is None: + self._transOffset = transforms.IdentityTransform() + elif (not isinstance(self._transOffset, transforms.Transform) + and hasattr(self._transOffset, '_as_mpl_transform')): + self._transOffset = self._transOffset._as_mpl_transform(self.axes) + return self._transOffset + + def set_offset_transform(self, transOffset): + """ + Set the artist offset transform. + + Parameters + ---------- + transOffset : `.Transform` + """ + self._transOffset = transOffset def get_datalim(self, transData): # Calculate the data limits and return them as a `.Bbox`. @@ -248,8 +256,8 @@ def get_datalim(self, transData): transform = self.get_transform() transOffset = self.get_offset_transform() - if (not self._offsetsNone and - not transOffset.contains_branch(transData)): + hasOffsets = np.any(self._offsets) # True if any non-zero offsets + if hasOffsets and not transOffset.contains_branch(transData): # if there are offsets but in some coords other than data, # then don't use them for autoscaling. return transforms.Bbox.null() @@ -279,7 +287,7 @@ def get_datalim(self, transData): self.get_transforms(), transOffset.transform_non_affine(offsets), transOffset.get_affine().frozen()) - if not self._offsetsNone: + if hasOffsets: # this is for collections that have their paths (shapes) # in physical, axes-relative, or figure-relative units # (i.e. like scatter). We can't uniquely set limits based on @@ -542,20 +550,12 @@ def set_offsets(self, offsets): offsets = np.asanyarray(offsets, float) if offsets.shape == (2,): # Broadcast (2,) -> (1, 2) but nothing else. offsets = offsets[None, :] - # This decision is based on how they are initialized above in __init__. - if self._uniform_offsets is None: - self._offsets = offsets - else: - self._uniform_offsets = offsets + self._offsets = offsets self.stale = True def get_offsets(self): """Return the offsets for the collection.""" - # This decision is based on how they are initialized above in __init__. - if self._uniform_offsets is None: - return self._offsets - else: - return self._uniform_offsets + return self._offsets def _get_default_linewidth(self): # This may be overridden in a subclass. @@ -1441,9 +1441,6 @@ def set_segments(self, segments): seg = np.asarray(seg, float) _segments.append(seg) - if self._uniform_offsets is not None: - _segments = self._add_offsets(_segments) - self._paths = [mpath.Path(_seg) for _seg in _segments] self.stale = True @@ -1474,19 +1471,6 @@ def get_segments(self): return segments - def _add_offsets(self, segs): - offsets = self._uniform_offsets - Nsegs = len(segs) - Noffs = offsets.shape[0] - if Noffs == 1: - for i in range(Nsegs): - segs[i] = segs[i] + i * offsets - else: - for i in range(Nsegs): - io = i % Noffs - segs[i] = segs[i] + offsets[io:io + 1] - return segs - def _get_default_linewidth(self): return mpl.rcParams['lines.linewidth'] diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 9aeb972d790b..a80c8d717416 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -1048,3 +1048,34 @@ def test_get_segments(): readback, = lc.get_segments() # these should comeback un-changed! assert np.all(segments == readback) + + +def test_set_offsets_late(): + identity = mtransforms.IdentityTransform() + sizes = [2] + + null = mcollections.CircleCollection(sizes=sizes) + + init = mcollections.CircleCollection(sizes=sizes, offsets=(10, 10)) + + late = mcollections.CircleCollection(sizes=sizes) + late.set_offsets((10, 10)) + + # Bbox.__eq__ doesn't compare bounds + null_bounds = null.get_datalim(identity).bounds + init_bounds = init.get_datalim(identity).bounds + late_bounds = late.get_datalim(identity).bounds + + # offsets and transform are applied when set after initialization + assert null_bounds != init_bounds + assert init_bounds == late_bounds + + +def test_set_offset_transform(): + skew = mtransforms.Affine2D().skew(2, 2) + init = mcollections.Collection([], transOffset=skew) + + late = mcollections.Collection([]) + late.set_offset_transform(skew) + + assert skew == init.get_offset_transform() == late.get_offset_transform()