diff --git a/lib/matplotlib/streamplot.py b/lib/matplotlib/streamplot.py index f076209e5dc9..a22d6824aea9 100644 --- a/lib/matplotlib/streamplot.py +++ b/lib/matplotlib/streamplot.py @@ -116,8 +116,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, if use_multicolor_lines: if color.shape != grid.shape: raise ValueError("If 'color' is given, it must match the shape of " - "'Grid(x, y)'") - line_colors = [] + "the (x, y) grid") + line_colors = [[]] # Empty entry allows concatenation of zero arrays. color = np.ma.masked_invalid(color) else: line_kw['color'] = color @@ -126,7 +126,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, if isinstance(linewidth, np.ndarray): if linewidth.shape != grid.shape: raise ValueError("If 'linewidth' is given, it must match the " - "shape of 'Grid(x, y)'") + "shape of the (x, y) grid") line_kw['linewidth'] = [] else: line_kw['linewidth'] = linewidth @@ -137,7 +137,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, # Sanity checks. if u.shape != grid.shape or v.shape != grid.shape: - raise ValueError("'u' and 'v' must match the shape of 'Grid(x, y)'") + raise ValueError("'u' and 'v' must match the shape of the (x, y) grid") u = np.ma.masked_invalid(u) v = np.ma.masked_invalid(v) @@ -310,21 +310,22 @@ class Grid: """Grid of data.""" def __init__(self, x, y): - if x.ndim == 1: + if np.ndim(x) == 1: pass - elif x.ndim == 2: - x_row = x[0, :] + elif np.ndim(x) == 2: + x_row = x[0] if not np.allclose(x_row, x): raise ValueError("The rows of 'x' must be equal") x = x_row else: raise ValueError("'x' can have at maximum 2 dimensions") - if y.ndim == 1: + if np.ndim(y) == 1: pass - elif y.ndim == 2: - y_col = y[:, 0] - if not np.allclose(y_col, y.T): + elif np.ndim(y) == 2: + yt = np.transpose(y) # Also works for nested lists. + y_col = yt[0] + if not np.allclose(y_col, yt): raise ValueError("The columns of 'y' must be equal") y = y_col else: diff --git a/lib/matplotlib/tests/test_streamplot.py b/lib/matplotlib/tests/test_streamplot.py index 88c3ec2768e9..3708d862dac7 100644 --- a/lib/matplotlib/tests/test_streamplot.py +++ b/lib/matplotlib/tests/test_streamplot.py @@ -144,3 +144,13 @@ def test_streamplot_grid(): with pytest.raises(ValueError, match="'y' must be strictly increasing"): plt.streamplot(x, y, u, v) + + +def test_streamplot_inputs(): # test no exception occurs. + # fully-masked + plt.streamplot(np.arange(3), np.arange(3), + np.full((3, 3), np.nan), np.full((3, 3), np.nan), + color=np.random.rand(3, 3)) + # array-likes + plt.streamplot(range(3), range(3), + np.random.rand(3, 3), np.random.rand(3, 3))