diff --git a/doc/release/next_whats_new/arrow3d.rst b/doc/release/next_whats_new/arrow3d.rst new file mode 100644 index 000000000000..3728ea68375d --- /dev/null +++ b/doc/release/next_whats_new/arrow3d.rst @@ -0,0 +1,5 @@ +New ``arrow3d`` Method +---------------------- + +The new ``arrow3d`` method for `~.Axes3D` allows users to plot 3D arrows, +which can be used for representing vectors or directional data. diff --git a/galleries/examples/mplot3d/arrow3d.py b/galleries/examples/mplot3d/arrow3d.py new file mode 100644 index 000000000000..c8194a1ceeb3 --- /dev/null +++ b/galleries/examples/mplot3d/arrow3d.py @@ -0,0 +1,37 @@ +""" +============= +3D arrow plot +============= + +Demonstrates plotting arrows in a 3D space. + +Here we plot two arrows from the same start point to different +end points. The properties of the second arrow is changed by passing +additional parameters other than ``end`` and ``start`` to +`.patches.FancyArrowPatch`. +""" + +import matplotlib.pyplot as plt +import numpy as np + +fig = plt.figure() +ax = fig.add_subplot(111, projection='3d') + +# Define the start and end points of the arrow +start = np.array([0, 0, 0]) +end = np.array([1, 1, 1]) + +# Create the arrow +ax.arrow3d(end, start) + +end1 = np.array([1, 2, 3]) +# Passing additional keyword arguments to control properties of the arrow. +# If the `start` parameter is not passed, the arrow is drawn from (0, 0, 0). +ax.arrow3d(end1, mutation_scale=20, color='r', arrowstyle='->', linewidth=2) + +plt.show() + +# %% +# .. tags:: +# plot-type: 3D, +# level: beginner diff --git a/lib/mpl_toolkits/mplot3d/art3d.py b/lib/mpl_toolkits/mplot3d/art3d.py index e051e44fb23d..5a28db548979 100644 --- a/lib/mpl_toolkits/mplot3d/art3d.py +++ b/lib/mpl_toolkits/mplot3d/art3d.py @@ -18,7 +18,7 @@ path as mpath, rcParams) from matplotlib.collections import ( Collection, LineCollection, PolyCollection, PatchCollection, PathCollection) -from matplotlib.patches import Patch +from matplotlib.patches import Patch, FancyArrowPatch from . import proj3d @@ -1665,3 +1665,25 @@ def norm(x): colors = np.asanyarray(color).copy() return colors + + +class Arrow3D(FancyArrowPatch): + """ + 3D FancyArrowPatch object. + """ + def __init__(self, posA, posB, *args, **kwargs): + """ + Parameters + ---------- + posA, posB : array-like + The coordinates of the arrow's start and end points. + """ + super().__init__((0,0), (0,0), *args, **kwargs) + self._verts3d = list(zip(posA, posB)) + + def do_3d_projection(self, renderer=None): + """Projects the points according to the renderer matrix.""" + xs3d, ys3d, zs3d = self._verts3d + xs, ys, zs = proj3d.proj_transform(xs3d, ys3d, zs3d, self.axes.M) + self.set_positions((xs[0], ys[0]), (xs[1], ys[1])) + return np.min(zs) diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index 32da8dfde7aa..b41d4c6979b5 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -4057,6 +4057,53 @@ def stem(self, x, y, z, *, linefmt='C0-', markerfmt='C0o', basefmt='C3-', stem3D = stem + @_preprocess_data() + def arrow3d(self, end, start=None, **kwargs): + """ + 3D plot of a single arrow + + Parameters + ---------- + end : 1D array + an array of shape (3,). + + start : 1D array, default: (0,0,0) + an array of shape (3,). + + data : indexable object, optional + DATA_PARAMETER_PLACEHOLDER + + **kwargs + All other keyword arguments are passed on to + `~mpl_toolkits.mplot3d.art3d.Arrow3D`. + + Returns + ------- + arrow : `~mpl_toolkits.mplot3d.art3d.Arrow3D` + + """ + had_data = self.has_data() + + if start is None: + start = np.zeros_like(end) + if np.shape(end) != (3,): + raise ValueError("end must be an array of length 3") + if np.shape(start) != (3,): + raise ValueError("start must be an array of length 3") + + # Set default arrow properties and update with any additional keyword args + arrow_props = dict( + mutation_scale=20, arrowstyle="-|>", shrinkA=0, shrinkB=0 + ) + arrow_props.update(kwargs) + + arrow = art3d.Arrow3D(start, end, **arrow_props) + self.add_artist(arrow) + xs, ys, zs = list(zip(start, end)) + self.auto_scale_xyz(xs, ys, zs, had_data) + + return arrow + def get_test_data(delta=0.05): """Return a tuple X, Y, Z with a test data set.""" diff --git a/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/arrow3d_custom_props.png b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/arrow3d_custom_props.png new file mode 100644 index 000000000000..ac70fb025bfe Binary files /dev/null and b/lib/mpl_toolkits/mplot3d/tests/baseline_images/test_axes3d/arrow3d_custom_props.png differ diff --git a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py index e38df4f80ba4..d1fe8363120f 100644 --- a/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py +++ b/lib/mpl_toolkits/mplot3d/tests/test_axes3d.py @@ -2740,3 +2740,30 @@ def test_axes3d_set_aspect_deperecated_params(): with pytest.raises(ValueError, match="adjustable"): ax.set_aspect('equal', adjustable='invalid_value') + + +@check_figures_equal() +def test_arrow3d_default(fig_test, fig_ref): + ax_ref = fig_ref.add_subplot(projection='3d') + start = [0, 0, 0] + end = [1, 2, 3] + ax_ref.arrow3d(end, start) + + ax_test = fig_test.add_subplot(projection='3d') + ax_test.arrow3d(end) + + +@mpl3d_image_comparison(['arrow3d_custom_props.png'], style='mpl20', + tol=0.02 if sys.platform == 'darwin' else 0) +def test_arrow3d_custom_props(): + fig = plt.figure() + ax = fig.add_subplot(projection='3d') + + start1 = [1, 2, 3] + end1 = [4, 5, 6] + ax.arrow3d(end1, start1, + arrowstyle="->, head_length=0.6, head_width=0.3", color='red') + + start2 = [2, 5, 7] + end2 = [4, 6, -8] + ax.arrow3d(end2, start2, color='violet', ls='--', lw=2)