diff --git a/doc/users/whats_new/streamplot_set_maximum_length.rst b/doc/users/whats_new/streamplot_set_maximum_length.rst new file mode 100644 index 000000000000..434eb9ec9ecb --- /dev/null +++ b/doc/users/whats_new/streamplot_set_maximum_length.rst @@ -0,0 +1,5 @@ +Maximum streamline length and integration direction can now be specified +------------------------------------------------------------------------ + +This allows to follow the vector field for a longer time and can enhance the +visibility of the flow pattern in some use cases. diff --git a/lib/matplotlib/axes/_axes.py b/lib/matplotlib/axes/_axes.py index c9126c31f59a..60006a7ed545 100644 --- a/lib/matplotlib/axes/_axes.py +++ b/lib/matplotlib/axes/_axes.py @@ -4606,21 +4606,25 @@ def stackplot(self, x, *args, **kwargs): def streamplot(self, x, y, u, v, density=1, linewidth=None, color=None, cmap=None, norm=None, arrowsize=1, arrowstyle='-|>', minlength=0.1, transform=None, zorder=None, - start_points=None): + start_points=None, maxlength=4.0, + integration_direction='both'): if not self._hold: self.cla() - stream_container = mstream.streamplot(self, x, y, u, v, - density=density, - linewidth=linewidth, - color=color, - cmap=cmap, - norm=norm, - arrowsize=arrowsize, - arrowstyle=arrowstyle, - minlength=minlength, - start_points=start_points, - transform=transform, - zorder=zorder) + stream_container = mstream.streamplot( + self, x, y, u, v, + density=density, + linewidth=linewidth, + color=color, + cmap=cmap, + norm=norm, + arrowsize=arrowsize, + arrowstyle=arrowstyle, + minlength=minlength, + start_points=start_points, + transform=transform, + zorder=zorder, + maxlength=maxlength, + integration_direction=integration_direction) return stream_container streamplot.__doc__ = mstream.streamplot.__doc__ diff --git a/lib/matplotlib/pyplot.py b/lib/matplotlib/pyplot.py index b9782b7e3044..37bba9fcc42f 100644 --- a/lib/matplotlib/pyplot.py +++ b/lib/matplotlib/pyplot.py @@ -3288,8 +3288,8 @@ def step(x, y, *args, **kwargs): @_autogen_docstring(Axes.streamplot) def streamplot(x, y, u, v, density=1, linewidth=None, color=None, cmap=None, norm=None, arrowsize=1, arrowstyle='-|>', minlength=0.1, - transform=None, zorder=None, start_points=None, hold=None, - data=None): + transform=None, zorder=None, start_points=None, maxlength=4.0, + integration_direction='both', hold=None, data=None): ax = gca() # allow callers to override the hold state by passing hold=True|False washold = ax.ishold() @@ -3301,7 +3301,10 @@ def streamplot(x, y, u, v, density=1, linewidth=None, color=None, cmap=None, color=color, cmap=cmap, norm=norm, arrowsize=arrowsize, arrowstyle=arrowstyle, minlength=minlength, transform=transform, - zorder=zorder, start_points=start_points, data=data) + zorder=zorder, start_points=start_points, + maxlength=maxlength, + integration_direction=integration_direction, + data=data) finally: ax.hold(washold) sci(ret.lines) diff --git a/lib/matplotlib/streamplot.py b/lib/matplotlib/streamplot.py index cbc413b2312e..ccf068739cd9 100644 --- a/lib/matplotlib/streamplot.py +++ b/lib/matplotlib/streamplot.py @@ -22,7 +22,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, cmap=None, norm=None, arrowsize=1, arrowstyle='-|>', - minlength=0.1, transform=None, zorder=None, start_points=None): + minlength=0.1, transform=None, zorder=None, start_points=None, + maxlength=4.0, integration_direction='both'): """Draws streamlines of a vector flow. *x*, *y* : 1d arrays @@ -58,6 +59,10 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, In data coordinates, the same as the ``x`` and ``y`` arrays. *zorder* : int any number + *maxlength* : float + Maximum length of streamline in axes coordinates. + *integration_direction* : ['forward', 'backward', 'both'] + Integrate the streamline in forward, backward or both directions. Returns: @@ -95,6 +100,15 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, line_kw = {} arrow_kw = dict(arrowstyle=arrowstyle, mutation_scale=10 * arrowsize) + if integration_direction not in ['both', 'forward', 'backward']: + errstr = ("Integration direction '%s' not recognised. " + "Expected 'both', 'forward' or 'backward'." % + integration_direction) + raise ValueError(errstr) + + if integration_direction == 'both': + maxlength /= 2. + use_multicolor_lines = isinstance(color, np.ndarray) if use_multicolor_lines: if color.shape != grid.shape: @@ -126,7 +140,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None, u = np.ma.masked_invalid(u) v = np.ma.masked_invalid(v) - integrate = get_integrator(u, v, dmap, minlength) + integrate = get_integrator(u, v, dmap, minlength, maxlength, + integration_direction) trajectories = [] if start_points is None: @@ -401,7 +416,7 @@ class TerminateTrajectory(Exception): # Integrator definitions #======================== -def get_integrator(u, v, dmap, minlength): +def get_integrator(u, v, dmap, minlength, maxlength, integration_direction): # rescale velocity onto grid-coordinates for integrations. u, v = dmap.data2grid(u, v) @@ -435,17 +450,27 @@ def integrate(x0, y0): resulting trajectory is None if it is shorter than `minlength`. """ + stotal, x_traj, y_traj = 0., [], [] + try: dmap.start_trajectory(x0, y0) except InvalidIndexError: return None - sf, xf_traj, yf_traj = _integrate_rk12(x0, y0, dmap, forward_time) - dmap.reset_start_point(x0, y0) - sb, xb_traj, yb_traj = _integrate_rk12(x0, y0, dmap, backward_time) - # combine forward and backward trajectories - stotal = sf + sb - x_traj = xb_traj[::-1] + xf_traj[1:] - y_traj = yb_traj[::-1] + yf_traj[1:] + if integration_direction in ['both', 'backward']: + s, xt, yt = _integrate_rk12(x0, y0, dmap, backward_time, maxlength) + stotal += s + x_traj += xt[::-1] + y_traj += yt[::-1] + + if integration_direction in ['both', 'forward']: + dmap.reset_start_point(x0, y0) + s, xt, yt = _integrate_rk12(x0, y0, dmap, forward_time, maxlength) + if len(x_traj) > 0: + xt = xt[1:] + yt = yt[1:] + stotal += s + x_traj += xt + y_traj += yt if stotal > minlength: return x_traj, y_traj @@ -456,7 +481,7 @@ def integrate(x0, y0): return integrate -def _integrate_rk12(x0, y0, dmap, f): +def _integrate_rk12(x0, y0, dmap, f, maxlength): """2nd-order Runge-Kutta algorithm with adaptive step size. This method is also referred to as the improved Euler's method, or Heun's @@ -532,7 +557,7 @@ def _integrate_rk12(x0, y0, dmap, f): dmap.update_trajectory(xi, yi) except InvalidIndexError: break - if (stotal + ds) > 2: + if (stotal + ds) > maxlength: break stotal += ds diff --git a/lib/matplotlib/tests/baseline_images/test_streamplot/streamplot_direction.png b/lib/matplotlib/tests/baseline_images/test_streamplot/streamplot_direction.png new file mode 100644 index 000000000000..77786b3e4875 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_streamplot/streamplot_direction.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_streamplot/streamplot_maxlength.png b/lib/matplotlib/tests/baseline_images/test_streamplot/streamplot_maxlength.png new file mode 100644 index 000000000000..2ccb71c0581f Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_streamplot/streamplot_maxlength.png differ diff --git a/lib/matplotlib/tests/test_streamplot.py b/lib/matplotlib/tests/test_streamplot.py index 2f84121fac51..782a15b13713 100644 --- a/lib/matplotlib/tests/test_streamplot.py +++ b/lib/matplotlib/tests/test_streamplot.py @@ -16,6 +16,14 @@ def velocity_field(): V = 1 + X - Y**2 return X, Y, U, V +def swirl_velocity_field(): + x = np.linspace(-3., 3., 100) + y = np.linspace(-3., 3., 100) + X, Y = np.meshgrid(x, y) + a = 0.1 + U = np.cos(a) * (-Y) - np.sin(a) * X + V = np.sin(a) * (-Y) + np.cos(a) * X + return x, y, U, V @image_comparison(baseline_images=['streamplot_startpoints']) def test_startpoints(): @@ -57,6 +65,23 @@ def test_masks_and_nans(): plt.streamplot(X, Y, U, V, color=U, cmap=plt.cm.Blues) +@image_comparison(baseline_images=['streamplot_maxlength'], + extensions=['png']) +def test_maxlength(): + x, y, U, V = swirl_velocity_field() + plt.streamplot(x, y, U, V, maxlength=10., start_points=[[0., 1.5]], + linewidth=2, density=2) + + +@image_comparison(baseline_images=['streamplot_direction'], + extensions=['png']) +def test_direction(): + x, y, U, V = swirl_velocity_field() + plt.streamplot(x, y, U, V, integration_direction='backward', + maxlength=1.5, start_points=[[1.5, 0.]], + linewidth=2, density=2) + + @cleanup def test_streamplot_limits(): ax = plt.axes()