From 5d7d9169e8ed6cb056066b2cfb67e2f76da3794c Mon Sep 17 00:00:00 2001 From: Alec Vercruysse Date: Thu, 29 Feb 2024 22:52:58 -0800 Subject: [PATCH] fix quiver3D arrow colors --- lib/mpl_toolkits/mplot3d/axes3d.py | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 4d8c7e6b8e84..467ad32e80fd 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -2984,15 +2984,10 @@ def calc_arrows(UVW): UVW = np.column_stack(input_args[3:]).astype(float) # Normalize rows of UVW - norm = np.linalg.norm(UVW, axis=1) - - # If any row of UVW is all zeros, don't make a quiver for it - mask = norm > 0 - XYZ = XYZ[mask] if normalize: - UVW = UVW[mask] / norm[mask].reshape((-1, 1)) - else: - UVW = UVW[mask] + norm = np.linalg.norm(UVW, axis=1) + norm[norm == 0] = 1 + UVW = UVW / norm.reshape((-1, 1)) if len(XYZ) > 0: # compute the shaft lines all at once with an outer product @@ -3003,10 +2998,13 @@ def calc_arrows(UVW): heads = shafts[:, :1] - np.multiply.outer(arrow_dt, head_dirs) # stack left and right head lines together heads = heads.reshape((len(arrow_dt), -1, 3)) - # transpose to get a list of lines + # transpose to get a list of lines, n heads x 2 points x 3 dimensions heads = heads.swapaxes(0, 1) - - lines = [*shafts, *heads] + # combine into a single line: n heads, 5 points, 3 dimensions + lines = np.empty((len(shafts), 5, 3), dtype=shafts.dtype) + lines[:, 0:2] = shafts[:, ::-1] + lines[:, 2:4] = heads[::2, ::-1] + lines[:, 4] = heads[1::2, 1] else: lines = []