Skip to content

Commit 5d7d916

Browse files
fix quiver3D arrow colors
1 parent c2aa4ee commit 5d7d916

File tree

1 file changed

+9
-11
lines changed

1 file changed

+9
-11
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

+9-11
Original file line numberDiff line numberDiff line change
@@ -2984,15 +2984,10 @@ def calc_arrows(UVW):
29842984
UVW = np.column_stack(input_args[3:]).astype(float)
29852985

29862986
# Normalize rows of UVW
2987-
norm = np.linalg.norm(UVW, axis=1)
2988-
2989-
# If any row of UVW is all zeros, don't make a quiver for it
2990-
mask = norm > 0
2991-
XYZ = XYZ[mask]
29922987
if normalize:
2993-
UVW = UVW[mask] / norm[mask].reshape((-1, 1))
2994-
else:
2995-
UVW = UVW[mask]
2988+
norm = np.linalg.norm(UVW, axis=1)
2989+
norm[norm == 0] = 1
2990+
UVW = UVW / norm.reshape((-1, 1))
29962991

29972992
if len(XYZ) > 0:
29982993
# compute the shaft lines all at once with an outer product
@@ -3003,10 +2998,13 @@ def calc_arrows(UVW):
30032998
heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs)
30042999
# stack left and right head lines together
30053000
heads = heads.reshape((len(arrow_dt), -1, 3))
3006-
# transpose to get a list of lines
3001+
# transpose to get a list of lines, n heads x 2 points x 3 dimensions
30073002
heads = heads.swapaxes(0, 1)
3008-
3009-
lines = [*shafts, *heads]
3003+
# combine into a single line: n heads, 5 points, 3 dimensions
3004+
lines = np.empty((len(shafts), 5, 3), dtype=shafts.dtype)
3005+
lines[:, 0:2] = shafts[:, ::-1]
3006+
lines[:, 2:4] = heads[::2, ::-1]
3007+
lines[:, 4] = heads[1::2, 1]
30103008
else:
30113009
lines = []
30123010

0 commit comments

Comments
 (0)