Skip to content

Various small fixes for streamplot(). #21558

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions lib/matplotlib/streamplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
if use_multicolor_lines:
if color.shape != grid.shape:
raise ValueError("If 'color' is given, it must match the shape of "
"'Grid(x, y)'")
line_colors = []
"the (x, y) grid")
line_colors = [[]] # Empty entry allows concatenation of zero arrays.
color = np.ma.masked_invalid(color)
else:
line_kw['color'] = color
Expand All @@ -126,7 +126,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
if isinstance(linewidth, np.ndarray):
if linewidth.shape != grid.shape:
raise ValueError("If 'linewidth' is given, it must match the "
"shape of 'Grid(x, y)'")
"shape of the (x, y) grid")
line_kw['linewidth'] = []
else:
line_kw['linewidth'] = linewidth
Expand All @@ -137,7 +137,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,

# Sanity checks.
if u.shape != grid.shape or v.shape != grid.shape:
raise ValueError("'u' and 'v' must match the shape of 'Grid(x, y)'")
raise ValueError("'u' and 'v' must match the shape of the (x, y) grid")

u = np.ma.masked_invalid(u)
v = np.ma.masked_invalid(v)
Expand Down Expand Up @@ -310,21 +310,22 @@ class Grid:
"""Grid of data."""
def __init__(self, x, y):

if x.ndim == 1:
if np.ndim(x) == 1:
pass
elif x.ndim == 2:
x_row = x[0, :]
elif np.ndim(x) == 2:
x_row = x[0]
if not np.allclose(x_row, x):
raise ValueError("The rows of 'x' must be equal")
x = x_row
else:
raise ValueError("'x' can have at maximum 2 dimensions")

if y.ndim == 1:
if np.ndim(y) == 1:
pass
elif y.ndim == 2:
y_col = y[:, 0]
if not np.allclose(y_col, y.T):
elif np.ndim(y) == 2:
yt = np.transpose(y) # Also works for nested lists.
y_col = yt[0]
if not np.allclose(y_col, yt):
raise ValueError("The columns of 'y' must be equal")
y = y_col
else:
Expand Down
10 changes: 10 additions & 0 deletions lib/matplotlib/tests/test_streamplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,13 @@ def test_streamplot_grid():

with pytest.raises(ValueError, match="'y' must be strictly increasing"):
plt.streamplot(x, y, u, v)


def test_streamplot_inputs(): # test no exception occurs.
# fully-masked
plt.streamplot(np.arange(3), np.arange(3),
np.full((3, 3), np.nan), np.full((3, 3), np.nan),
color=np.random.rand(3, 3))
# array-likes
plt.streamplot(range(3), range(3),
np.random.rand(3, 3), np.random.rand(3, 3))