Skip to content

Scatter: make "c" and "s" argument handling more consistent. #13959

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 21 additions & 28 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4133,7 +4133,7 @@ def dopatch(xs, ys, **kwargs):
medians=medians, fliers=fliers, means=means)

@staticmethod
def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
def _parse_scatter_color_args(c, edgecolors, kwargs, xsize,
get_next_color_func):
"""
Helper function to process color related arguments of `.Axes.scatter`.
Expand Down Expand Up @@ -4163,8 +4163,8 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
Additional kwargs. If these keys exist, we pop and process them:
'facecolors', 'facecolor', 'edgecolor', 'color'
Note: The dict is modified by this function.
xshape, yshape : tuple of int
The shape of the x and y arrays passed to `.Axes.scatter`.
xsize : int
The size of the x and y arrays passed to `.Axes.scatter`.
get_next_color_func : callable
A callable that returns a color. This color is used as facecolor
if no other color is provided.
Expand All @@ -4187,9 +4187,6 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
The edgecolor specification.

"""
xsize = functools.reduce(operator.mul, xshape, 1)
ysize = functools.reduce(operator.mul, yshape, 1)

facecolors = kwargs.pop('facecolors', None)
facecolors = kwargs.pop('facecolor', facecolors)
edgecolors = kwargs.pop('edgecolor', edgecolors)
Expand Down Expand Up @@ -4229,7 +4226,7 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
# favor of mapping, not rgb or rgba.
# Convenience vars to track shape mismatch *and* conversion failures.
valid_shape = True # will be put to the test!
n_elem = -1 # used only for (some) exceptions
csize = -1 # Number of colors; used for some exceptions.

if (c_was_none or
kwcolor is not None or
Expand All @@ -4241,9 +4238,9 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
else:
try: # First, does 'c' look suitable for value-mapping?
c_array = np.asanyarray(c, dtype=float)
n_elem = c_array.shape[0]
if c_array.shape in [xshape, yshape]:
c = np.ma.ravel(c_array)
csize = c_array.size
if csize == xsize:
c = c_array.ravel()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

did we need this cast to masked here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so. The method will work regardless, since c_array is some kind of array, and I don't think we are later depending on c being a masked array.

else:
if c_array.shape in ((3,), (4,)):
_log.warning(
Expand All @@ -4262,28 +4259,23 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
if c_array is None:
try: # Then is 'c' acceptable as PathCollection facecolors?
colors = mcolors.to_rgba_array(c)
n_elem = colors.shape[0]
if colors.shape[0] not in (0, 1, xsize, ysize):
csize = colors.shape[0]
if csize not in (0, 1, xsize):
# NB: remember that a single color is also acceptable.
# Besides *colors* will be an empty array if c == 'none'.
valid_shape = False
raise ValueError
except ValueError:
if not valid_shape: # but at least one conversion succeeded.
raise ValueError(
"'c' argument has {nc} elements, which is not "
"acceptable for use with 'x' with size {xs}, "
"'y' with size {ys}."
.format(nc=n_elem, xs=xsize, ys=ysize)
)
f"'c' argument has {csize} elements, which is "
"inconsistent with 'x' and 'y' with size {xsize}.")
else:
# Both the mapping *and* the RGBA conversion failed: pretty
# severe failure => one may appreciate a verbose feedback.
raise ValueError(
"'c' argument must be a mpl color, a sequence of mpl "
"colors or a sequence of numbers, not {}."
.format(c) # note: could be long depending on c
)
f"'c' argument must be a mpl color, a sequence of mpl "
"colors, or a sequence of numbers, not {c}.")
else:
colors = None # use cmap, norm after collection is created
return c, colors, edgecolors
Expand All @@ -4301,7 +4293,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,

Parameters
----------
x, y : array_like, shape (n, )
x, y : scalar or array_like, shape (n, )
The data positions.

s : scalar or array_like, shape (n, ), optional
Expand All @@ -4313,8 +4305,8 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,

- A single color format string.
- A sequence of color specifications of length n.
- A sequence of n numbers to be mapped to colors using *cmap* and
*norm*.
- A scalar or sequence of n numbers to be mapped to colors using
*cmap* and *norm*.
- A 2-D array in which the rows are RGB or RGBA.

Note that *c* should not be a single numeric RGB or RGBA sequence
Expand Down Expand Up @@ -4403,7 +4395,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
plotted.

* Fundamentally, scatter works with 1-D arrays; *x*, *y*, *s*, and *c*
may be input as 2-D arrays, but within scatter they will be
may be input as N-D arrays, but within scatter they will be
flattened. The exception is *c*, which will be flattened only if its
size matches the size of *x* and *y*.

Expand All @@ -4416,7 +4408,6 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,

# np.ma.ravel yields an ndarray, not a masked array,
# unless its argument is a masked array.
xshape, yshape = np.shape(x), np.shape(y)
x = np.ma.ravel(x)
y = np.ma.ravel(y)
if x.size != y.size:
Expand All @@ -4425,11 +4416,13 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
if s is None:
s = (20 if rcParams['_internal.classic_mode'] else
rcParams['lines.markersize'] ** 2.0)
s = np.ma.ravel(s) # This doesn't have to match x, y in size.
s = np.ma.ravel(s)
if len(s) not in (1, x.size):
raise ValueError("s must be a scalar, or the same size as x and y")

c, colors, edgecolors = \
self._parse_scatter_color_args(
c, edgecolors, kwargs, xshape, yshape,
c, edgecolors, kwargs, x.size,
get_next_color_func=self._get_patches_for_fill.get_next_color)

if plotnonfinite and colors is None:
Expand Down
38 changes: 28 additions & 10 deletions lib/matplotlib/tests/test_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1862,6 +1862,13 @@ def test_scatter_color(self):
with pytest.raises(ValueError):
plt.scatter([1, 2, 3], [1, 2, 3], color=[1, 2, 3])

def test_scatter_size_arg_size(self):
x = np.arange(4)
with pytest.raises(ValueError):
plt.scatter(x, x, x[1:])
with pytest.raises(ValueError):
plt.scatter(x[1:], x[1:], x)

@check_figures_equal(extensions=["png"])
def test_scatter_invalid_color(self, fig_test, fig_ref):
ax = fig_test.subplots()
Expand Down Expand Up @@ -1890,6 +1897,21 @@ def test_scatter_no_invalid_color(self, fig_test, fig_ref):
ax = fig_ref.subplots()
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)

@check_figures_equal(extensions=["png"])
def test_scatter_single_point(self, fig_test, fig_ref):
ax = fig_test.subplots()
ax.scatter(1, 1, c=1)
ax = fig_ref.subplots()
ax.scatter([1], [1], c=[1])

@check_figures_equal(extensions=["png"])
def test_scatter_different_shapes(self, fig_test, fig_ref):
x = np.arange(10)
ax = fig_test.subplots()
ax.scatter(x, x.reshape(2, 5), c=x.reshape(5, 2))
ax = fig_ref.subplots()
ax.scatter(x.reshape(5, 2), x, c=x.reshape(2, 5))

# Parameters for *test_scatter_c*. NB: assuming that the
# scatter plot will have 4 elements. The tuple scheme is:
# (*c* parameter case, exception regexp key or None if no exception)
Expand Down Expand Up @@ -1946,7 +1968,7 @@ def get_next_color():

from matplotlib.axes import Axes

xshape = yshape = (4,)
xsize = 4

# Additional checking of *c* (introduced in #11383).
REGEXP = {
Expand All @@ -1956,21 +1978,18 @@ def get_next_color():

if re_key is None:
Axes._parse_scatter_color_args(
c=c_case, edgecolors="black", kwargs={},
xshape=xshape, yshape=yshape,
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
get_next_color_func=get_next_color)
else:
with pytest.raises(ValueError, match=REGEXP[re_key]):
Axes._parse_scatter_color_args(
c=c_case, edgecolors="black", kwargs={},
xshape=xshape, yshape=yshape,
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
get_next_color_func=get_next_color)


def _params(c=None, xshape=(2,), yshape=(2,), **kwargs):
def _params(c=None, xsize=2, **kwargs):
edgecolors = kwargs.pop('edgecolors', None)
return (c, edgecolors, kwargs if kwargs is not None else {},
xshape, yshape)
return (c, edgecolors, kwargs if kwargs is not None else {}, xsize)
_result = namedtuple('_result', 'c, colors')


Expand Down Expand Up @@ -2022,8 +2041,7 @@ def get_next_color():
c = kwargs.pop('c', None)
edgecolors = kwargs.pop('edgecolors', None)
_, _, result_edgecolors = \
Axes._parse_scatter_color_args(c, edgecolors, kwargs,
xshape=(2,), yshape=(2,),
Axes._parse_scatter_color_args(c, edgecolors, kwargs, xsize=2,
get_next_color_func=get_next_color)
assert result_edgecolors == expected_edgecolors

Expand Down