diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index b1032f45091e..e4ff16cf9410 100755 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -2209,7 +2209,7 @@ def add_collection3d(self, col, zs=0, zdir='z'): Axes.add_collection(self, col) - def scatter(self, xs, ys, zs=0, zdir='z', s=20, c='b', depthshade=True, + def scatter(self, xs, ys, zs=0, zdir='z', s=20, c=None, depthshade=True, *args, **kwargs): ''' Create a scatter plot. @@ -2233,7 +2233,9 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c='b', depthshade=True, that *c* should not be a single numeric RGB or RGBA sequence because that is indistinguishable from an array of values to be colormapped. *c* can be a 2-D array in - which the rows are RGB or RGBA, however. + which the rows are RGB or RGBA, however, including the + case of a single row to specify the same color for + all points. *depthshade* Whether or not to shade the scatter markers to give @@ -2262,13 +2264,15 @@ def scatter(self, xs, ys, zs=0, zdir='z', s=20, c='b', depthshade=True, s = np.ma.ravel(s) # This doesn't have to match x, y in size. - cstr = cbook.is_string_like(c) or cbook.is_sequence_of_strings(c) - if not cstr: - c = np.asanyarray(c) - if c.size == xs.size: - c = np.ma.ravel(c) - - xs, ys, zs, s, c = cbook.delete_masked_points(xs, ys, zs, s, c) + if c is not None: + cstr = cbook.is_string_like(c) or cbook.is_sequence_of_strings(c) + if not cstr: + c = np.asanyarray(c) + if c.size == xs.size: + c = np.ma.ravel(c) + xs, ys, zs, s, c = cbook.delete_masked_points(xs, ys, zs, s, c) + else: + xs, ys, zs, s = cbook.delete_masked_points(xs, ys, zs, s) patches = Axes.scatter(self, xs, ys, s=s, c=c, *args, **kwargs) if not cbook.iterable(zs): diff --git a/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/scatter3d_color.png b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/scatter3d_color.png new file mode 100644 index 000000000000..1ded1b338850 Binary files /dev/null and b/lib/mpl_toolkits/tests/baseline_images/test_mplot3d/scatter3d_color.png differ diff --git a/lib/mpl_toolkits/tests/test_mplot3d.py b/lib/mpl_toolkits/tests/test_mplot3d.py index 42718166d32d..5b8485f802b8 100644 --- a/lib/mpl_toolkits/tests/test_mplot3d.py +++ b/lib/mpl_toolkits/tests/test_mplot3d.py @@ -110,6 +110,17 @@ def test_scatter3d(): c='b', marker='^') +@image_comparison(baseline_images=['scatter3d_color'], remove_text=True, + extensions=['png']) +def test_scatter3d_color(): + fig = plt.figure() + ax = fig.add_subplot(111, projection='3d') + ax.scatter(np.arange(10), np.arange(10), np.arange(10), + color='r', marker='o') + ax.scatter(np.arange(10, 20), np.arange(10, 20), np.arange(10, 20), + color='b', marker='s') + + @image_comparison(baseline_images=['surface3d'], remove_text=True) def test_surface3d(): fig = plt.figure()