diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index 8b0669775eaf..f206213e9d65 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -737,6 +737,18 @@ def set_zsort(self, zsort): self._sort_zpos = None self.stale = True + def _zsortval(self, zs): + """ + Compute the value to use for z-sorting given the viewer z + coordinates of an object `zs`, with larger values drawn underneath + smaller values. + + This function should never return `nan`, and returns `np.inf` if no + non-nan value is computable. + """ + nans = np.isnan(zs) + return np.inf if nans.all() else self._zsortfunc(zs[~nans]) + def get_vector(self, segments3d): """Optimize points for projection.""" if len(segments3d): @@ -815,7 +827,7 @@ def do_3d_projection(self, renderer=None): if xyzlist: # sort by depth (furthest drawn first) z_segments_2d = sorted( - ((self._zsortfunc(zs), np.column_stack([xs, ys]), fc, ec, idx) + ((self._zsortval(zs), np.column_stack([xs, ys]), fc, ec, idx) for idx, ((xs, ys, zs), fc, ec) in enumerate(zip(xyzlist, cface, cedge))), key=lambda x: x[0], reverse=True)