diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index d4a2699b93e4..a7241b427a35 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -1438,7 +1438,7 @@ def plot(self, xs, ys, *args, zdir='z', **kwargs): zs = kwargs.pop('zs', 0) # Match length - zs = np.broadcast_to(zs, len(xs)) + zs = np.broadcast_to(zs, np.shape(xs)) lines = super().plot(xs, ys, *args, **kwargs) for line in lines: diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index 592096b740f9..b79d94530fc2 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -160,6 +160,14 @@ def test_lines3d(): ax.plot(x, y, z) +@check_figures_equal(extensions=["png"]) +def test_plot_scalar(fig_test, fig_ref): + ax1 = fig_test.gca(projection='3d') + ax1.plot([1], [1], "o") + ax2 = fig_ref.gca(projection='3d') + ax2.plot(1, 1, "o") + + @image_comparison(['mixedsubplot.png'], remove_text=True) def test_mixedsubplots(): def f(t):