Skip to content

Commit 2ed292d

Browse files
authored
Merge pull request #9991 from eric-wieser/normals-cleanup
MAINT: Use vectorization in plot_trisurf, simplifying greatly
2 parents 8d53ebf + e36bb0b commit 2ed292d

File tree

3 files changed

+32
-30
lines changed

3 files changed

+32
-30
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 13 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,47 +1997,30 @@ def plot_trisurf(self, *args, **kwargs):
19971997
args = args[1:]
19981998

19991999
triangles = tri.get_masked_triangles()
2000-
xt = tri.x[triangles][..., np.newaxis]
2001-
yt = tri.y[triangles][..., np.newaxis]
2002-
zt = z[triangles][..., np.newaxis]
2000+
xt = tri.x[triangles]
2001+
yt = tri.y[triangles]
2002+
zt = z[triangles]
20032003

2004-
verts = np.concatenate((xt, yt, zt), axis=2)
2005-
2006-
# Only need these vectors to shade if there is no cmap
2007-
if cmap is None and shade:
2008-
totpts = len(verts)
2009-
v1 = np.empty((totpts, 3))
2010-
v2 = np.empty((totpts, 3))
2011-
# This indexes the vertex points
2012-
which_pt = 0
2013-
2014-
colset = []
2015-
for i in xrange(len(verts)):
2016-
avgzsum = verts[i,0,2] + verts[i,1,2] + verts[i,2,2]
2017-
colset.append(avgzsum / 3.0)
2018-
2019-
# Only need vectors to shade if no cmap
2020-
if cmap is None and shade:
2021-
v1[which_pt] = np.array(verts[i,0]) - np.array(verts[i,1])
2022-
v2[which_pt] = np.array(verts[i,1]) - np.array(verts[i,2])
2023-
which_pt += 1
2024-
2025-
if cmap is None and shade:
2026-
normals = np.cross(v1, v2)
2027-
else:
2028-
normals = []
2004+
# verts = np.stack((xt, yt, zt), axis=-1)
2005+
verts = np.concatenate((
2006+
xt[..., np.newaxis], yt[..., np.newaxis], zt[..., np.newaxis]
2007+
), axis=-1)
20292008

20302009
polyc = art3d.Poly3DCollection(verts, *args, **kwargs)
20312010

20322011
if cmap:
2033-
colset = np.array(colset)
2034-
polyc.set_array(colset)
2012+
# average over the three points of each triangle
2013+
avg_z = verts[:, :, 2].mean(axis=1)
2014+
polyc.set_array(avg_z)
20352015
if vmin is not None or vmax is not None:
20362016
polyc.set_clim(vmin, vmax)
20372017
if norm is not None:
20382018
polyc.set_norm(norm)
20392019
else:
20402020
if shade:
2021+
v1 = verts[:, 0, :] - verts[:, 1, :]
2022+
v2 = verts[:, 1, :] - verts[:, 2, :]
2023+
normals = np.cross(v1, v2)
20412024
colset = self._shade_colors(color, normals)
20422025
else:
20432026
colset = color

lib/mpl_toolkits/tests/test_mplot3d.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,25 @@ def test_trisurf3d():
232232
ax.plot_trisurf(x, y, z, cmap=cm.jet, linewidth=0.2)
233233

234234

235+
@image_comparison(baseline_images=['trisurf3d_shaded'], remove_text=True,
236+
tol=0.03, extensions=['png'])
237+
def test_trisurf3d_shaded():
238+
n_angles = 36
239+
n_radii = 8
240+
radii = np.linspace(0.125, 1.0, n_radii)
241+
angles = np.linspace(0, 2*np.pi, n_angles, endpoint=False)
242+
angles = np.repeat(angles[..., np.newaxis], n_radii, axis=1)
243+
angles[:, 1::2] += np.pi/n_angles
244+
245+
x = np.append(0, (radii*np.cos(angles)).flatten())
246+
y = np.append(0, (radii*np.sin(angles)).flatten())
247+
z = np.sin(-x*y)
248+
249+
fig = plt.figure()
250+
ax = fig.gca(projection='3d')
251+
ax.plot_trisurf(x, y, z, color=[1, 0.5, 0], linewidth=0.2)
252+
253+
235254
@image_comparison(baseline_images=['wireframe3d'], remove_text=True)
236255
def test_wireframe3d():
237256
fig = plt.figure()

0 commit comments

Comments
 (0)