diff --git a/lib/matplotlib/streamplot.py b/lib/matplotlib/streamplot.py index 76a9d61aec73..1bfacdb70339 100644 --- a/lib/matplotlib/streamplot.py +++ b/lib/matplotlib/streamplot.py @@ -92,8 +92,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, if use_multicolor_lines: assert color.shape == grid.shape line_colors = [] - if np.any(np.isnan(color)): - color = np.ma.array(color, mask=np.isnan(color)) + color = np.ma.masked_invalid(color) else: line_kw['color'] = color arrow_kw['color'] = color @@ -112,10 +111,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, assert u.shape == grid.shape assert v.shape == grid.shape - if np.any(np.isnan(u)): - u = np.ma.array(u, mask=np.isnan(u)) - if np.any(np.isnan(v)): - v = np.ma.array(v, mask=np.isnan(v)) + u = np.ma.masked_invalid(u) + v = np.ma.masked_invalid(v) integrate = get_integrator(u, v, dmap, minlength) @@ -160,7 +157,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, if use_multicolor_lines: color_values = interpgrid(color, tgx, tgy)[:-1] - line_colors.extend(color_values) + line_colors.append(color_values) arrow_kw['color'] = cmap(norm(color_values[n])) p = patches.FancyArrowPatch(arrow_tail, @@ -174,7 +171,7 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, transform=transform, **line_kw) if use_multicolor_lines: - lc.set_array(np.asarray(line_colors)) + lc.set_array(np.ma.hstack(line_colors)) lc.set_cmap(cmap) lc.set_norm(norm) axes.add_collection(lc)