diff --git a/lib/matplotlib/quiver.py b/lib/matplotlib/quiver.py index 9b63e7734198..3f6ca4c2a374 100644 --- a/lib/matplotlib/quiver.py +++ b/lib/matplotlib/quiver.py @@ -381,27 +381,55 @@ def contains(self, mouseevent): # This is a helper function that parses out the various combination of # arguments for doing colored vector plots. Pulling it out here # allows both Quiver and Barbs to use it -def _parse_args(*args): +def _parse_args(*args, **kw): X, Y, U, V, C = [None] * 5 - args = list(args) - - # The use of atleast_1d allows for handling scalar arguments while also - # keeping masked arrays - if len(args) == 3 or len(args) == 5: - C = np.atleast_1d(args.pop(-1)) - V = np.atleast_1d(args.pop(-1)) - U = np.atleast_1d(args.pop(-1)) - if U.ndim == 1: - nr, nc = 1, U.shape[0] - else: - nr, nc = U.shape - if len(args) == 2: # remaining after removing U,V,C - X, Y = [np.array(a).ravel() for a in args] - if len(X) == nc and len(Y) == nr: - X, Y = [a.ravel() for a in np.meshgrid(X, Y)] - else: - indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) - X, Y = [np.ravel(a) for a in indexgrid] + if len(kw) != 0: # some keyword arguments + # The use of atleast_1d allows for handling scalar arguments while also + # keeping masked arrays + + if len(kw) == 3 or len(kw) == 5: + if (kw.get('C') is not None): + C = np.atleast_1d(kw.pop('C')) + if (kw.get('V') is not None): + V = np.atleast_1d(kw.pop('V')) + if (kw.get('U') is not None): + U = np.atleast_1d(kw.pop('U')) + if U.ndim == 1: + nr, nc = 1, U.shape[0] + else: + nr, nc = U.shape + if len(kw) == 2: # remaining after removing U,V,C. CASE 1 + if kw.get('X') is not None: + X = np.array(kw.get('X')).ravel() + if kw.get('Y') is not None: + Y = np.array(kw.get('Y')).ravel() + if len(X) == nc and len(Y) == nr: + X, Y = [a.ravel() for a in np.meshgrid(X, Y)] + elif len(kw) == 0: # CASE 2 + indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) + X, Y = [np.ravel(a) for a in indexgrid] + if len(args) != 0: + args = list(args) + # The use of atleast_1d allows for handling scalar arguments while also + # keeping masked arrays + if len(args) == 3 or len(args) == 5: + if C is not None: + C = np.atleast_1d(args.pop(-1)) + if V is None: + V = np.atleast_1d(args.pop(-1)) + if U is None: + U = np.atleast_1d(args.pop(-1)) + if U.ndim == 1: + nr, nc = 1, U.shape[0] + else: + nr, nc = U.shape + if len(args) == 2: # remaining after removing U,V,C. CASE 1 + X, Y = [np.array(a).ravel() for a in args] + if len(X) == nc and len(Y) == nr: + X, Y = [a.ravel() for a in np.meshgrid(X, Y)] + else: # CASE 2 + indexgrid = np.meshgrid(np.arange(nc), np.arange(nr)) + X, Y = [np.ravel(a) for a in indexgrid] return X, Y, U, V, C @@ -443,7 +471,17 @@ def __init__(self, ax, *args, %s """ self.ax = ax - X, Y, U, V, C = _parse_args(*args) + X, Y, U, V, C = _parse_args(*args, **kw) + if kw.get('U') is not None: # Resetting **kw to the way it was without these + kw.pop('U') + if kw.get('V') is not None: + kw.pop('V') + if kw.get('X') is not None: + kw.pop('X') + if kw.get('Y') is not None: + kw.pop('Y') + if kw.get('C') is not None: + kw.pop('C') self.X = X self.Y = Y self.XY = np.column_stack((X, Y)) diff --git a/lib/matplotlib/tests/test_collections.py b/lib/matplotlib/tests/test_collections.py index 6812ee1ad427..f230cd107295 100644 --- a/lib/matplotlib/tests/test_collections.py +++ b/lib/matplotlib/tests/test_collections.py @@ -416,6 +416,21 @@ def test_quiver_limits(): q = plt.quiver(x, y, u, v) assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + x, y = np.arange(8), np.arange(10) + u = v = np.linspace(0, 10, 80).reshape(10, 8) + q = plt.quiver(X = x, Y = y, U = u, V = v) + assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + + x, y = np.arange(8), np.arange(10) + u = v = np.linspace(0, 10, 80).reshape(10, 8) + q = plt.quiver(U = u, V = v) + assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + + x, y = np.arange(8), np.arange(10) + u = v = np.linspace(0, 10, 80).reshape(10, 8) + q = plt.quiver(x, y, U = u, V = v) + assert q.get_datalim(ax.transData).bounds == (0., 0., 7., 9.) + plt.figure() ax = plt.axes() x = np.linspace(-5, 10, 20)