Skip to content

FIX: make sure scalarmappable updates are handled correctly in 3D #18929

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 11, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 90 additions & 15 deletions lib/mpl_toolkits/mplot3d/art3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def do_3d_projection(self, renderer=None):
"""
Project the points according to renderer matrix.
"""
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)
xyslist = [proj3d.proj_trans_points(points, self.axes.M)
for points in self._segments3d]
segments_2d = [np.column_stack([xs, ys]) for xs, ys, zs in xyslist]
Expand Down Expand Up @@ -486,6 +488,8 @@ def set_3d_properties(self, zs, zdir):

@cbook._delete_parameter('3.4', 'renderer')
def do_3d_projection(self, renderer=None):
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
Expand Down Expand Up @@ -592,6 +596,8 @@ def set_linewidth(self, lw):

@cbook._delete_parameter('3.4', 'renderer')
def do_3d_projection(self, renderer=None):
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)
xs, ys, zs = self._offsets3d
vxs, vys, vzs, vis = proj3d.proj_transform_clip(xs, ys, zs,
self.axes.M)
Expand Down Expand Up @@ -635,6 +641,77 @@ def do_3d_projection(self, renderer=None):
return np.min(vzs) if vzs.size else np.nan


def _update_scalarmappable(sm):
"""
Update a 3D ScalarMappable.

With ScalarMappable objects if the data, colormap, or norm are
changed, we need to update the computed colors. This is handled
by the base class method update_scalarmappable. This method works
by, detecting if work needs to be done, and if so stashing it on
the ``self._facecolors`` attribute.

With 3D collections we internally sort the components so that
things that should be "in front" are rendered later to simulate
having a z-buffer (in addition to doing the projections). This is
handled in the ``do_3d_projection`` methods which are called from the
draw method of the 3D Axes. These methods:

- do the projection from 3D -> 2D
- internally sort based on depth
- stash the results of the above in the 2D analogs of state
- return the z-depth of the whole artist

the last step is so that we can, at the Axes level, sort the children by
depth.

The base `draw` method of the 2D artists unconditionally calls
update_scalarmappable and rely on the method's internal caching logic to
lazily evaluate.

These things together mean you can have the sequence of events:

- we create the artist, do the color mapping and stash the results
in a 3D specific state.
- change something about the ScalarMappable that marks it as in
need of an update (`ScalarMappable.changed` and friends).
- We call do_3d_projection and shuffle the stashed colors into the
2D version of face colors
- the draw method calls the update_scalarmappable method which
overwrites our shuffled colors
- we get a render that is wrong
- if we re-render (either with a second save or implicitly via
tight_layout / constrained_layout / bbox_inches='tight' (ex via
inline's defaults)) we again shuffle the 3D colors
- because the CM is not marked as changed update_scalarmappable is
a no-op and we get a correct looking render.

This function is an internal helper to:

- sort out if we need to do the color mapping at all (has data!)
- sort out if update_scalarmappable is going to be a no-op
- copy the data over from the 2D -> 3D version

This must be called first thing in do_3d_projection to make sure that
the correct colors get shuffled.

Parameters
----------
sm : ScalarMappable
The ScalarMappable to update and stash the 3D data from

"""
if sm._A is None:
return
copy_state = sm._update_dict['array']
ret = sm.update_scalarmappable()
if copy_state:
if sm._is_filled:
sm._facecolor3d = sm._facecolors
elif sm._is_stroked:
sm._edgecolor3d = sm._edgecolors


def patch_collection_2d_to_3d(col, zs=0, zdir='z', depthshade=True):
"""
Convert a :class:`~matplotlib.collections.PatchCollection` into a
Expand Down Expand Up @@ -757,8 +834,8 @@ def set_3d_properties(self):
self.update_scalarmappable()
self._sort_zpos = None
self.set_zsort('average')
self._facecolors3d = PolyCollection.get_facecolor(self)
self._edgecolors3d = PolyCollection.get_edgecolor(self)
self._facecolor3d = PolyCollection.get_facecolor(self)
self._edgecolor3d = PolyCollection.get_edgecolor(self)
self._alpha3d = PolyCollection.get_alpha(self)
self.stale = True

Expand All @@ -772,17 +849,15 @@ def do_3d_projection(self, renderer=None):
"""
Perform the 3D projection for this object.
"""
# FIXME: This may no longer be needed?
if self._A is not None:
self.update_scalarmappable()
self._facecolors3d = self._facecolors
# see _update_scalarmappable docstring for why this must be here
_update_scalarmappable(self)

txs, tys, tzs = proj3d._proj_transform_vec(self._vec, self.axes.M)
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]

# This extra fuss is to re-order face / edge colors
cface = self._facecolors3d
cedge = self._edgecolors3d
cface = self._facecolor3d
cedge = self._edgecolor3d
if len(cface) != len(xyzlist):
cface = cface.repeat(len(xyzlist), axis=0)
if len(cedge) != len(xyzlist):
Expand All @@ -807,8 +882,8 @@ def do_3d_projection(self, renderer=None):
else:
PolyCollection.set_verts(self, segments_2d, self._closed)

if len(self._edgecolors3d) != len(cface):
self._edgecolors2d = self._edgecolors3d
if len(self._edgecolor3d) != len(cface):
self._edgecolors2d = self._edgecolor3d

# Return zorder value
if self._sort_zpos is not None:
Expand All @@ -826,24 +901,24 @@ def do_3d_projection(self, renderer=None):
def set_facecolor(self, colors):
# docstring inherited
super().set_facecolor(colors)
self._facecolors3d = PolyCollection.get_facecolor(self)
self._facecolor3d = PolyCollection.get_facecolor(self)

def set_edgecolor(self, colors):
# docstring inherited
super().set_edgecolor(colors)
self._edgecolors3d = PolyCollection.get_edgecolor(self)
self._edgecolor3d = PolyCollection.get_edgecolor(self)

def set_alpha(self, alpha):
# docstring inherited
artist.Artist.set_alpha(self, alpha)
try:
self._facecolors3d = mcolors.to_rgba_array(
self._facecolors3d, self._alpha)
self._facecolor3d = mcolors.to_rgba_array(
self._facecolor3d, self._alpha)
except (AttributeError, TypeError, IndexError):
pass
try:
self._edgecolors = mcolors.to_rgba_array(
self._edgecolors3d, self._alpha)
self._edgecolor3d, self._alpha)
except (AttributeError, TypeError, IndexError):
pass
self.stale = True
Expand Down
24 changes: 23 additions & 1 deletion lib/mpl_toolkits/tests/test_mplot3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def test_bar3d_lightsource():
# the top facecolors compared to the default, and that those colors are
# precisely the colors from the colormap, due to the illumination parallel
# to the z-axis.
np.testing.assert_array_equal(color, collection._facecolors3d[1::6])
np.testing.assert_array_equal(color, collection._facecolor3d[1::6])


@mpl3d_image_comparison(['contour3d.png'])
Expand Down Expand Up @@ -1302,3 +1302,25 @@ def convert_lim(dmin, dmax):
assert x_center != pytest.approx(x_center0)
assert y_center != pytest.approx(y_center0)
assert z_center != pytest.approx(z_center0)


@pytest.mark.style('default')
@check_figures_equal(extensions=["png"])
def test_scalarmap_update(fig_test, fig_ref):

x, y, z = np.array((list(itertools.product(*[np.arange(0, 5, 1),
np.arange(0, 5, 1),
np.arange(0, 5, 1)])))).T
c = x + y

# test
ax_test = fig_test.add_subplot(111, projection='3d')
sc_test = ax_test.scatter(x, y, z, c=c, s=40, cmap='viridis')
# force a draw
fig_test.canvas.draw()
# mark it as "stale"
sc_test.changed()

# ref
ax_ref = fig_ref.add_subplot(111, projection='3d')
sc_ref = ax_ref.scatter(x, y, z, c=c, s=40, cmap='viridis')