Skip to content

Commit 8633bbd

Browse files
committed
Store vertices as (masked) arrays in Poly3DCollection
This removes a bunch of very expensive list comprehensions, dropping rendering times by half for some 3D surface plots.
1 parent ef9fc20 commit 8633bbd

File tree

2 files changed

+82
-30
lines changed

2 files changed

+82
-30
lines changed

lib/mpl_toolkits/mplot3d/art3d.py

+52-27
Original file line numberDiff line numberDiff line change
@@ -598,16 +598,33 @@ def set_zsort(self, zsort):
598598
self.stale = True
599599

600600
def get_vector(self, segments3d):
601-
"""Optimize points for projection."""
602-
if len(segments3d):
603-
xs, ys, zs = np.row_stack(segments3d).T
604-
else: # row_stack can't stack zero arrays.
605-
xs, ys, zs = [], [], []
606-
ones = np.ones(len(xs))
607-
self._vec = np.array([xs, ys, zs, ones])
601+
"""
602+
Optimize points for projection.
608603
609-
indices = [0, *np.cumsum([len(segment) for segment in segments3d])]
610-
self._segslices = [*map(slice, indices[:-1], indices[1:])]
604+
Parameters
605+
----------
606+
segments3d : NumPy array or list of NumPy arrays
607+
List of vertices of the boundary of every segment. If all paths are
608+
of equal length and this argument is a NumPy arrray, then it should
609+
be of shape (num_faces, num_vertices, 3).
610+
"""
611+
if isinstance(segments3d, np.ndarray):
612+
if segments3d.ndim != 3 or segments3d.shape[-1] != 3:
613+
raise ValueError("segments3d must be a MxNx3 array, but got " +
614+
"shape {}".format(segments3d.shape))
615+
self._segments = segments3d
616+
else:
617+
num_faces = len(segments3d)
618+
num_verts = np.fromiter(map(len, segments3d), dtype=np.intp)
619+
max_verts = num_verts.max(initial=0)
620+
padded = np.empty((num_faces, max_verts, 3))
621+
for i, face in enumerate(segments3d):
622+
padded[i, :len(face)] = face
623+
mask = np.arange(max_verts) >= num_verts[:, None]
624+
mask = mask[..., None] # add a component axis
625+
# ma.array does not broadcast the mask for us
626+
mask = np.broadcast_to(mask, padded.shape)
627+
self._segments = np.ma.array(padded, mask=mask)
611628

612629
def set_verts(self, verts, closed=True):
613630
"""Set 3D vertices."""
@@ -649,37 +666,45 @@ def do_3d_projection(self, renderer):
649666
self.update_scalarmappable()
650667
self._facecolors3d = self._facecolors
651668

652-
txs, tys, tzs = proj3d._proj_transform_vec(self._vec, renderer.M)
653-
xyzlist = [(txs[sl], tys[sl], tzs[sl]) for sl in self._segslices]
669+
psegments = proj3d._proj_transform_vectors(self._segments, renderer.M)
670+
is_masked = isinstance(psegments, np.ma.MaskedArray)
671+
num_faces = len(psegments)
654672

655673
# This extra fuss is to re-order face / edge colors
656674
cface = self._facecolors3d
657675
cedge = self._edgecolors3d
658-
if len(cface) != len(xyzlist):
659-
cface = cface.repeat(len(xyzlist), axis=0)
660-
if len(cedge) != len(xyzlist):
676+
if len(cface) != num_faces:
677+
cface = cface.repeat(num_faces, axis=0)
678+
if len(cedge) != num_faces:
661679
if len(cedge) == 0:
662680
cedge = cface
663681
else:
664-
cedge = cedge.repeat(len(xyzlist), axis=0)
682+
cedge = cedge.repeat(num_faces, axis=0)
665683

666-
# sort by depth (furthest drawn first)
667-
z_segments_2d = sorted(
668-
((self._zsortfunc(zs), np.column_stack([xs, ys]), fc, ec, idx)
669-
for idx, ((xs, ys, zs), fc, ec)
670-
in enumerate(zip(xyzlist, cface, cedge))),
671-
key=lambda x: x[0], reverse=True)
684+
face_z = self._zsortfunc(psegments[..., 2], axis=-1)
685+
if is_masked:
686+
# NOTE: Unpacking .data is safe here, because every face has to
687+
# contain a valid vertex.
688+
face_z = face_z.data
689+
face_order = np.argsort(face_z, axis=-1)[::-1]
672690

673-
segments_2d = [s for z, s, fc, ec, idx in z_segments_2d]
691+
segments_2d = psegments[face_order, :, :2]
674692
if self._codes3d is not None:
675-
codes = [self._codes3d[idx] for z, s, fc, ec, idx in z_segments_2d]
693+
if is_masked:
694+
# NOTE: We cannot assert the same on segments_2d, as it is a
695+
# result of advanced indexing, so its mask is a newly
696+
# allocated tensor. However, both of those asserts are
697+
# equivalent.
698+
assert psegments.mask.strides[-1] == 0
699+
segments_2d = [s.compressed().reshape(-1, 2) for s in segments_2d]
700+
codes = [self._codes3d[idx] for idx in face_order]
676701
PolyCollection.set_verts_and_codes(self, segments_2d, codes)
677702
else:
678703
PolyCollection.set_verts(self, segments_2d, self._closed)
679704

680-
self._facecolors2d = [fc for z, s, fc, ec, idx in z_segments_2d]
705+
self._facecolors2d = cface[face_order]
681706
if len(self._edgecolors3d) == len(cface):
682-
self._edgecolors2d = [ec for z, s, fc, ec, idx in z_segments_2d]
707+
self._edgecolors2d = cedge[face_order]
683708
else:
684709
self._edgecolors2d = self._edgecolors3d
685710

@@ -688,11 +713,11 @@ def do_3d_projection(self, renderer):
688713
zvec = np.array([[0], [0], [self._sort_zpos], [1]])
689714
ztrans = proj3d._proj_transform_vec(zvec, renderer.M)
690715
return ztrans[2][0]
691-
elif tzs.size > 0:
716+
elif psegments.size > 0:
692717
# FIXME: Some results still don't look quite right.
693718
# In particular, examine contourf3d_demo2.py
694719
# with az = -54 and elev = -45.
695-
return np.min(tzs)
720+
return np.min(psegments[..., 2])
696721
else:
697722
return np.nan
698723

lib/mpl_toolkits/mplot3d/proj3d.py

+30-3
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,8 @@ def ortho_transformation(zfront, zback):
103103

104104
def _proj_transform_vec(vec, M):
105105
vecw = np.dot(M, vec)
106-
w = vecw[3]
107106
# clip here..
108-
txs, tys, tzs = vecw[0]/w, vecw[1]/w, vecw[2]/w
109-
return txs, tys, tzs
107+
return vecw[:3] / vecw[3]
110108

111109

112110
def _proj_transform_vec_clip(vec, M):
@@ -146,6 +144,35 @@ def proj_transform(xs, ys, zs, M):
146144
transform = proj_transform
147145

148146

147+
def _proj_transform_vectors(vecs, M):
148+
"""Vector version of ``project_transform`` able to handle MaskedArrays.
149+
150+
Parameters
151+
----------
152+
vecs : ... x 3 np.ndarray or np.ma.MaskedArray
153+
Input vectors
154+
155+
M : 4 x 4 np.ndarray
156+
Projection matrix
157+
"""
158+
vecs_shape = vecs.shape
159+
vecs = vecs.reshape(-1, 3).T
160+
161+
is_masked = isinstance(vecs, np.ma.MaskedArray)
162+
if is_masked:
163+
mask = vecs.mask
164+
165+
vecs_pad = np.empty((vecs.shape[0] + 1,) + vecs.shape[1:])
166+
vecs_pad[:-1] = vecs
167+
vecs_pad[-1] = 1
168+
product = np.dot(M, vecs_pad)
169+
tvecs = product[:3] / product[3]
170+
171+
if is_masked:
172+
tvecs = np.ma.array(tvecs, mask=mask)
173+
return tvecs.T.reshape(vecs_shape)
174+
175+
149176
def proj_transform_clip(xs, ys, zs, M):
150177
"""
151178
Transform the points by the projection matrix

0 commit comments

Comments
 (0)