22
22
23
23
def streamplot (axes , x , y , u , v , density = 1 , linewidth = None , color = None ,
24
24
cmap = None , norm = None , arrowsize = 1 , arrowstyle = '-|>' ,
25
- minlength = 0.1 , transform = None , zorder = None , start_points = None ):
25
+ minlength = 0.1 , transform = None , zorder = None , start_points = None ,
26
+ maxlength = 4.0 , integration_direction = 'both' ):
26
27
"""Draws streamlines of a vector flow.
27
28
28
29
*x*, *y* : 1d arrays
@@ -58,6 +59,10 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
58
59
In data coordinates, the same as the ``x`` and ``y`` arrays.
59
60
*zorder* : int
60
61
any number
62
+ *maxlength* : float
63
+ Maximum length of streamline in axes coordinates.
64
+ *integration_direction* : ['forward', 'backward', 'both']
65
+ Integrate the streamline in forward, backward or both directions.
61
66
62
67
Returns:
63
68
@@ -95,6 +100,15 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
95
100
line_kw = {}
96
101
arrow_kw = dict (arrowstyle = arrowstyle , mutation_scale = 10 * arrowsize )
97
102
103
+ if integration_direction not in ['both' , 'forward' , 'backward' ]:
104
+ errstr = ("Integration direction '%s' not recognised. "
105
+ "Expected 'both', 'forward' or 'backward'." %
106
+ integration_direction )
107
+ raise ValueError (errstr )
108
+
109
+ if integration_direction == 'both' :
110
+ maxlength /= 2.
111
+
98
112
use_multicolor_lines = isinstance (color , np .ndarray )
99
113
if use_multicolor_lines :
100
114
if color .shape != grid .shape :
@@ -126,7 +140,8 @@ def streamplot(axes, x, y, u, v, density=1, linewidth=None, color=None,
126
140
u = np .ma .masked_invalid (u )
127
141
v = np .ma .masked_invalid (v )
128
142
129
- integrate = get_integrator (u , v , dmap , minlength )
143
+ integrate = get_integrator (u , v , dmap , minlength , maxlength ,
144
+ integration_direction )
130
145
131
146
trajectories = []
132
147
if start_points is None :
@@ -401,7 +416,7 @@ class TerminateTrajectory(Exception):
401
416
# Integrator definitions
402
417
#========================
403
418
404
- def get_integrator (u , v , dmap , minlength ):
419
+ def get_integrator (u , v , dmap , minlength , maxlength , integration_direction ):
405
420
406
421
# rescale velocity onto grid-coordinates for integrations.
407
422
u , v = dmap .data2grid (u , v )
@@ -435,17 +450,27 @@ def integrate(x0, y0):
435
450
resulting trajectory is None if it is shorter than `minlength`.
436
451
"""
437
452
453
+ stotal , x_traj , y_traj = 0. , [], []
454
+
438
455
try :
439
456
dmap .start_trajectory (x0 , y0 )
440
457
except InvalidIndexError :
441
458
return None
442
- sf , xf_traj , yf_traj = _integrate_rk12 (x0 , y0 , dmap , forward_time )
443
- dmap .reset_start_point (x0 , y0 )
444
- sb , xb_traj , yb_traj = _integrate_rk12 (x0 , y0 , dmap , backward_time )
445
- # combine forward and backward trajectories
446
- stotal = sf + sb
447
- x_traj = xb_traj [::- 1 ] + xf_traj [1 :]
448
- y_traj = yb_traj [::- 1 ] + yf_traj [1 :]
459
+ if integration_direction in ['both' , 'backward' ]:
460
+ s , xt , yt = _integrate_rk12 (x0 , y0 , dmap , backward_time , maxlength )
461
+ stotal += s
462
+ x_traj += xt [::- 1 ]
463
+ y_traj += yt [::- 1 ]
464
+
465
+ if integration_direction in ['both' , 'forward' ]:
466
+ dmap .reset_start_point (x0 , y0 )
467
+ s , xt , yt = _integrate_rk12 (x0 , y0 , dmap , forward_time , maxlength )
468
+ if len (x_traj ) > 0 :
469
+ xt = xt [1 :]
470
+ yt = yt [1 :]
471
+ stotal += s
472
+ x_traj += xt
473
+ y_traj += yt
449
474
450
475
if stotal > minlength :
451
476
return x_traj , y_traj
@@ -456,7 +481,7 @@ def integrate(x0, y0):
456
481
return integrate
457
482
458
483
459
- def _integrate_rk12 (x0 , y0 , dmap , f ):
484
+ def _integrate_rk12 (x0 , y0 , dmap , f , maxlength ):
460
485
"""2nd-order Runge-Kutta algorithm with adaptive step size.
461
486
462
487
This method is also referred to as the improved Euler's method, or Heun's
@@ -532,7 +557,7 @@ def _integrate_rk12(x0, y0, dmap, f):
532
557
dmap .update_trajectory (xi , yi )
533
558
except InvalidIndexError :
534
559
break
535
- if (stotal + ds ) > 2 :
560
+ if (stotal + ds ) > maxlength :
536
561
break
537
562
stotal += ds
538
563
0 commit comments