Skip to content

Commit e3c9376

Browse files
authored
Merge pull request #13959 from efiring/scatter_ravel_consistency
Scatter: make "c" and "s" argument handling more consistent.
2 parents 71771ea + c1add44 commit e3c9376

File tree

2 files changed

+49
-38
lines changed

2 files changed

+49
-38
lines changed

lib/matplotlib/axes/_axes.py

+21-28
Original file line numberDiff line numberDiff line change
@@ -4138,7 +4138,7 @@ def dopatch(xs, ys, **kwargs):
41384138
medians=medians, fliers=fliers, means=means)
41394139

41404140
@staticmethod
4141-
def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
4141+
def _parse_scatter_color_args(c, edgecolors, kwargs, xsize,
41424142
get_next_color_func):
41434143
"""
41444144
Helper function to process color related arguments of `.Axes.scatter`.
@@ -4168,8 +4168,8 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41684168
Additional kwargs. If these keys exist, we pop and process them:
41694169
'facecolors', 'facecolor', 'edgecolor', 'color'
41704170
Note: The dict is modified by this function.
4171-
xshape, yshape : tuple of int
4172-
The shape of the x and y arrays passed to `.Axes.scatter`.
4171+
xsize : int
4172+
The size of the x and y arrays passed to `.Axes.scatter`.
41734173
get_next_color_func : callable
41744174
A callable that returns a color. This color is used as facecolor
41754175
if no other color is provided.
@@ -4192,9 +4192,6 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
41924192
The edgecolor specification.
41934193
41944194
"""
4195-
xsize = functools.reduce(operator.mul, xshape, 1)
4196-
ysize = functools.reduce(operator.mul, yshape, 1)
4197-
41984195
facecolors = kwargs.pop('facecolors', None)
41994196
facecolors = kwargs.pop('facecolor', facecolors)
42004197
edgecolors = kwargs.pop('edgecolor', edgecolors)
@@ -4234,7 +4231,7 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42344231
# favor of mapping, not rgb or rgba.
42354232
# Convenience vars to track shape mismatch *and* conversion failures.
42364233
valid_shape = True # will be put to the test!
4237-
n_elem = -1 # used only for (some) exceptions
4234+
csize = -1 # Number of colors; used for some exceptions.
42384235

42394236
if (c_was_none or
42404237
kwcolor is not None or
@@ -4246,9 +4243,9 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42464243
else:
42474244
try: # First, does 'c' look suitable for value-mapping?
42484245
c_array = np.asanyarray(c, dtype=float)
4249-
n_elem = c_array.shape[0]
4250-
if c_array.shape in [xshape, yshape]:
4251-
c = np.ma.ravel(c_array)
4246+
csize = c_array.size
4247+
if csize == xsize:
4248+
c = c_array.ravel()
42524249
else:
42534250
if c_array.shape in ((3,), (4,)):
42544251
_log.warning(
@@ -4267,28 +4264,23 @@ def _parse_scatter_color_args(c, edgecolors, kwargs, xshape, yshape,
42674264
if c_array is None:
42684265
try: # Then is 'c' acceptable as PathCollection facecolors?
42694266
colors = mcolors.to_rgba_array(c)
4270-
n_elem = colors.shape[0]
4271-
if colors.shape[0] not in (0, 1, xsize, ysize):
4267+
csize = colors.shape[0]
4268+
if csize not in (0, 1, xsize):
42724269
# NB: remember that a single color is also acceptable.
42734270
# Besides *colors* will be an empty array if c == 'none'.
42744271
valid_shape = False
42754272
raise ValueError
42764273
except ValueError:
42774274
if not valid_shape: # but at least one conversion succeeded.
42784275
raise ValueError(
4279-
"'c' argument has {nc} elements, which is not "
4280-
"acceptable for use with 'x' with size {xs}, "
4281-
"'y' with size {ys}."
4282-
.format(nc=n_elem, xs=xsize, ys=ysize)
4283-
)
4276+
f"'c' argument has {csize} elements, which is "
4277+
"inconsistent with 'x' and 'y' with size {xsize}.")
42844278
else:
42854279
# Both the mapping *and* the RGBA conversion failed: pretty
42864280
# severe failure => one may appreciate a verbose feedback.
42874281
raise ValueError(
4288-
"'c' argument must be a mpl color, a sequence of mpl "
4289-
"colors or a sequence of numbers, not {}."
4290-
.format(c) # note: could be long depending on c
4291-
)
4282+
f"'c' argument must be a mpl color, a sequence of mpl "
4283+
"colors, or a sequence of numbers, not {c}.")
42924284
else:
42934285
colors = None # use cmap, norm after collection is created
42944286
return c, colors, edgecolors
@@ -4306,7 +4298,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43064298
43074299
Parameters
43084300
----------
4309-
x, y : array_like, shape (n, )
4301+
x, y : scalar or array_like, shape (n, )
43104302
The data positions.
43114303
43124304
s : scalar or array_like, shape (n, ), optional
@@ -4318,8 +4310,8 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
43184310
43194311
- A single color format string.
43204312
- A sequence of color specifications of length n.
4321-
- A sequence of n numbers to be mapped to colors using *cmap* and
4322-
*norm*.
4313+
- A scalar or sequence of n numbers to be mapped to colors using
4314+
*cmap* and *norm*.
43234315
- A 2-D array in which the rows are RGB or RGBA.
43244316
43254317
Note that *c* should not be a single numeric RGB or RGBA sequence
@@ -4408,7 +4400,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44084400
plotted.
44094401
44104402
* Fundamentally, scatter works with 1-D arrays; *x*, *y*, *s*, and *c*
4411-
may be input as 2-D arrays, but within scatter they will be
4403+
may be input as N-D arrays, but within scatter they will be
44124404
flattened. The exception is *c*, which will be flattened only if its
44134405
size matches the size of *x* and *y*.
44144406
@@ -4421,7 +4413,6 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44214413

44224414
# np.ma.ravel yields an ndarray, not a masked array,
44234415
# unless its argument is a masked array.
4424-
xshape, yshape = np.shape(x), np.shape(y)
44254416
x = np.ma.ravel(x)
44264417
y = np.ma.ravel(y)
44274418
if x.size != y.size:
@@ -4430,11 +4421,13 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
44304421
if s is None:
44314422
s = (20 if rcParams['_internal.classic_mode'] else
44324423
rcParams['lines.markersize'] ** 2.0)
4433-
s = np.ma.ravel(s) # This doesn't have to match x, y in size.
4424+
s = np.ma.ravel(s)
4425+
if len(s) not in (1, x.size):
4426+
raise ValueError("s must be a scalar, or the same size as x and y")
44344427

44354428
c, colors, edgecolors = \
44364429
self._parse_scatter_color_args(
4437-
c, edgecolors, kwargs, xshape, yshape,
4430+
c, edgecolors, kwargs, x.size,
44384431
get_next_color_func=self._get_patches_for_fill.get_next_color)
44394432

44404433
if plotnonfinite and colors is None:

lib/matplotlib/tests/test_axes.py

+28-10
Original file line numberDiff line numberDiff line change
@@ -1820,6 +1820,13 @@ def test_scatter_color(self):
18201820
with pytest.raises(ValueError):
18211821
plt.scatter([1, 2, 3], [1, 2, 3], color=[1, 2, 3])
18221822

1823+
def test_scatter_size_arg_size(self):
1824+
x = np.arange(4)
1825+
with pytest.raises(ValueError):
1826+
plt.scatter(x, x, x[1:])
1827+
with pytest.raises(ValueError):
1828+
plt.scatter(x[1:], x[1:], x)
1829+
18231830
@check_figures_equal(extensions=["png"])
18241831
def test_scatter_invalid_color(self, fig_test, fig_ref):
18251832
ax = fig_test.subplots()
@@ -1848,6 +1855,21 @@ def test_scatter_no_invalid_color(self, fig_test, fig_ref):
18481855
ax = fig_ref.subplots()
18491856
ax.scatter([0, 2], [0, 2], c=[1, 2], s=[1, 3], cmap=cmap)
18501857

1858+
@check_figures_equal(extensions=["png"])
1859+
def test_scatter_single_point(self, fig_test, fig_ref):
1860+
ax = fig_test.subplots()
1861+
ax.scatter(1, 1, c=1)
1862+
ax = fig_ref.subplots()
1863+
ax.scatter([1], [1], c=[1])
1864+
1865+
@check_figures_equal(extensions=["png"])
1866+
def test_scatter_different_shapes(self, fig_test, fig_ref):
1867+
x = np.arange(10)
1868+
ax = fig_test.subplots()
1869+
ax.scatter(x, x.reshape(2, 5), c=x.reshape(5, 2))
1870+
ax = fig_ref.subplots()
1871+
ax.scatter(x.reshape(5, 2), x, c=x.reshape(2, 5))
1872+
18511873
# Parameters for *test_scatter_c*. NB: assuming that the
18521874
# scatter plot will have 4 elements. The tuple scheme is:
18531875
# (*c* parameter case, exception regexp key or None if no exception)
@@ -1904,7 +1926,7 @@ def get_next_color():
19041926

19051927
from matplotlib.axes import Axes
19061928

1907-
xshape = yshape = (4,)
1929+
xsize = 4
19081930

19091931
# Additional checking of *c* (introduced in #11383).
19101932
REGEXP = {
@@ -1914,21 +1936,18 @@ def get_next_color():
19141936

19151937
if re_key is None:
19161938
Axes._parse_scatter_color_args(
1917-
c=c_case, edgecolors="black", kwargs={},
1918-
xshape=xshape, yshape=yshape,
1939+
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
19191940
get_next_color_func=get_next_color)
19201941
else:
19211942
with pytest.raises(ValueError, match=REGEXP[re_key]):
19221943
Axes._parse_scatter_color_args(
1923-
c=c_case, edgecolors="black", kwargs={},
1924-
xshape=xshape, yshape=yshape,
1944+
c=c_case, edgecolors="black", kwargs={}, xsize=xsize,
19251945
get_next_color_func=get_next_color)
19261946

19271947

1928-
def _params(c=None, xshape=(2,), yshape=(2,), **kwargs):
1948+
def _params(c=None, xsize=2, **kwargs):
19291949
edgecolors = kwargs.pop('edgecolors', None)
1930-
return (c, edgecolors, kwargs if kwargs is not None else {},
1931-
xshape, yshape)
1950+
return (c, edgecolors, kwargs if kwargs is not None else {}, xsize)
19321951
_result = namedtuple('_result', 'c, colors')
19331952

19341953

@@ -1980,8 +1999,7 @@ def get_next_color():
19801999
c = kwargs.pop('c', None)
19812000
edgecolors = kwargs.pop('edgecolors', None)
19822001
_, _, result_edgecolors = \
1983-
Axes._parse_scatter_color_args(c, edgecolors, kwargs,
1984-
xshape=(2,), yshape=(2,),
2002+
Axes._parse_scatter_color_args(c, edgecolors, kwargs, xsize=2,
19852003
get_next_color_func=get_next_color)
19862004
assert result_edgecolors == expected_edgecolors
19872005

0 commit comments

Comments
 (0)