diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 33a08816cf69..492f14ee37b4 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -2729,12 +2729,30 @@ def calc_arrow(uvw, angle=15): # transpose to get a list of lines heads = heads.swapaxes(0, 1) - lines = [*shafts, *heads] + # create seperate lists for lines, left arrow head lines, + # and right arrow head lines from the computed lines and + # heads to be used for generating gradients + lines = shafts + arrows1 = [] + arrows2 = [] + for a in range(len(heads)): + if a % 2 == 0: + arrows1.append(heads[a]) + else: + arrows2.append(heads[a]) else: + arrows1 = [] + arrows2 = [] lines = [] + # generate 3D lines with gradient for body, left arrow head lines, + # and right arrow head lines linec = art3d.Line3DCollection(lines, *args[argi:], **kwargs) + arrow1c = art3d.Line3DCollection(arrows1, *args[argi:], **kwargs) + arrow2c = art3d.Line3DCollection(arrows2, *args[argi:], **kwargs) self.add_collection(linec) + self.add_collection(arrow1c) + self.add_collection(arrow2c) self.auto_scale_xyz(XYZ[:, 0], XYZ[:, 1], XYZ[:, 2], had_data)