|
11 | 11 | """
|
12 | 12 | from __future__ import (absolute_import, division, print_function,
|
13 | 13 | unicode_literals)
|
| 14 | +import math |
14 | 15 |
|
15 | 16 | import six
|
16 | 17 | from six.moves import map, xrange, zip
|
|
33 | 34 | from . import art3d
|
34 | 35 | from . import proj3d
|
35 | 36 | from . import axis3d
|
| 37 | +from mpl_toolkits.mplot3d.art3d import Line3DCollection |
36 | 38 |
|
37 | 39 | def unit_bbox():
|
38 | 40 | box = Bbox(np.array([[0, 0], [1, 1]]))
|
@@ -2413,6 +2415,191 @@ def set_title(self, label, fontdict=None, loc='center', **kwargs):
|
2413 | 2415 | return ret
|
2414 | 2416 | set_title.__doc__ = maxes.Axes.set_title.__doc__
|
2415 | 2417 |
|
| 2418 | + def quiver(self, *args, **kwargs): |
| 2419 | + """ |
| 2420 | + Plot a 3D field of arrows. |
| 2421 | +
|
| 2422 | + call signatures:: |
| 2423 | +
|
| 2424 | + quiver(X, Y, Z, U, V, W, **kwargs) |
| 2425 | +
|
| 2426 | + Arguments: |
| 2427 | +
|
| 2428 | + *X*, *Y*, *Z*: |
| 2429 | + The x, y and z coordinates of the arrow locations |
| 2430 | +
|
| 2431 | + *U*, *V*, *W*: |
| 2432 | + The direction vector that the arrow is pointing |
| 2433 | +
|
| 2434 | + The arguments could be iterable or scalars they will be broadcast together. The arguments can |
| 2435 | + also be masked arrays, if a position in any of argument is masked, then the corresponding |
| 2436 | + quiver will not be plotted. |
| 2437 | +
|
| 2438 | + Keyword arguments: |
| 2439 | +
|
| 2440 | + *length*: [1.0 | float] |
| 2441 | + The length of each quiver, default to 1.0, the unit is the same with the axes |
| 2442 | +
|
| 2443 | + *arrow_length_ratio*: [0.3 | float] |
| 2444 | + The ratio of the arrow head with respect to the quiver, default to 0.3 |
| 2445 | +
|
| 2446 | + Any additional keyword arguments are delegated to :class:`~matplotlib.collections.LineCollection` |
| 2447 | +
|
| 2448 | + """ |
| 2449 | + def calc_arrow(u, v, w, angle=15): |
| 2450 | + """ |
| 2451 | + To calculate the arrow head. (u, v, w) should be unit vector. |
| 2452 | + """ |
| 2453 | + |
| 2454 | + # this part figures out the axis of rotation to use |
| 2455 | + |
| 2456 | + # use unit vector perpendicular to (u,v,w) when |w|=1, by default |
| 2457 | + x, y, z = 0, 1, 0 |
| 2458 | + |
| 2459 | + # get the norm |
| 2460 | + norm = math.sqrt(v**2 + u**2) |
| 2461 | + # normalize it if it is safe |
| 2462 | + if norm > 0: |
| 2463 | + # get unit direction vector perpendicular to (u,v,w) |
| 2464 | + x, y = v/norm, -u/norm |
| 2465 | + |
| 2466 | + # this function takes an angle, and rotates the (u,v,w) |
| 2467 | + # angle degrees around (x,y,z) |
| 2468 | + def rotatefunction(angle): |
| 2469 | + ra = math.radians(angle) |
| 2470 | + c = math.cos(ra) |
| 2471 | + s = math.sin(ra) |
| 2472 | + |
| 2473 | + # construct the rotation matrix |
| 2474 | + R = np.matrix([[c+(x**2)*(1-c), x*y*(1-c)-z*s, x*z*(1-c)+y*s], |
| 2475 | + [y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s], |
| 2476 | + [z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]]) |
| 2477 | + |
| 2478 | + # construct the column vector for (u,v,w) |
| 2479 | + line = np.matrix([[u],[v],[w]]) |
| 2480 | + |
| 2481 | + # use numpy to multiply them to get the rotated vector |
| 2482 | + rotatedline = R*line |
| 2483 | + |
| 2484 | + # return the rotated (u,v,w) from the computed matrix |
| 2485 | + return (rotatedline[0,0], rotatedline[1,0], rotatedline[2,0]) |
| 2486 | + |
| 2487 | + # compute and return the two arrowhead direction unit vectors |
| 2488 | + return rotatefunction(angle), rotatefunction(-angle) |
| 2489 | + |
| 2490 | + def point_vector_to_line(point, vector, length): |
| 2491 | + """ |
| 2492 | + use a point and vector to generate lines |
| 2493 | + """ |
| 2494 | + lines = [] |
| 2495 | + for var in np.linspace(0, length, num=2): |
| 2496 | + lines.append(list(zip(*(point - var * vector)))) |
| 2497 | + lines = np.array(lines).swapaxes(0, 1) |
| 2498 | + return lines.tolist() |
| 2499 | + |
| 2500 | + had_data = self.has_data() |
| 2501 | + |
| 2502 | + # handle kwargs |
| 2503 | + # shaft length |
| 2504 | + length = kwargs.pop('length', 1) |
| 2505 | + # arrow length ratio to the shaft length |
| 2506 | + arrow_length_ratio = kwargs.pop('arrow_length_ratio', 0.3) |
| 2507 | + |
| 2508 | + # handle args |
| 2509 | + if len(args) < 6: |
| 2510 | + ValueError('Wrong number of arguments') |
| 2511 | + argi = 6 |
| 2512 | + # first 6 arguments are X, Y, Z, U, V, W |
| 2513 | + input_args = args[:argi] |
| 2514 | + # if any of the args are scalar, convert into list |
| 2515 | + input_args = [[k] if isinstance(k, (int, float)) else k for k in input_args] |
| 2516 | + # extract the masks, if any |
| 2517 | + masks = [k.mask for k in input_args if isinstance(k, np.ma.MaskedArray)] |
| 2518 | + # broadcast to match the shape |
| 2519 | + bcast = np.broadcast_arrays(*(input_args + masks)) |
| 2520 | + input_args = bcast[:argi] |
| 2521 | + masks = bcast[argi:] |
| 2522 | + if masks: |
| 2523 | + # combine the masks into one |
| 2524 | + mask = reduce(np.logical_or, masks) |
| 2525 | + # put mask on and compress |
| 2526 | + input_args = [np.ma.array(k, mask=mask).compressed() for k in input_args] |
| 2527 | + else: |
| 2528 | + input_args = [k.flatten() for k in input_args] |
| 2529 | + |
| 2530 | + points = input_args[:3] |
| 2531 | + vectors = input_args[3:] |
| 2532 | + |
| 2533 | + # Below assertions must be true before proceed |
| 2534 | + # must all be ndarray |
| 2535 | + assert all([isinstance(k, np.ndarray) for k in input_args]) |
| 2536 | + # must all in same shape |
| 2537 | + assert len(set([k.shape for k in input_args])) == 1 |
| 2538 | + |
| 2539 | + |
| 2540 | + # X, Y, Z, U, V, W |
| 2541 | + coords = list(map(lambda k: np.array(k) if not isinstance(k, np.ndarray) else k, args)) |
| 2542 | + coords = [k.flatten() for k in coords] |
| 2543 | + xs, ys, zs, us, vs, ws = coords |
| 2544 | + lines = [] |
| 2545 | + |
| 2546 | + # for each arrow |
| 2547 | + for i in xrange(xs.shape[0]): |
| 2548 | + # calulate body |
| 2549 | + x = xs[i] |
| 2550 | + y = ys[i] |
| 2551 | + z = zs[i] |
| 2552 | + u = us[i] |
| 2553 | + v = vs[i] |
| 2554 | + w = ws[i] |
| 2555 | + if any([k is np.ma.masked for k in [x, y, z, u, v, w]]): |
| 2556 | + continue |
| 2557 | + |
| 2558 | + # (u,v,w) expected to be normalized, recursive to fix A=0 scenario. |
| 2559 | + if u == 0 and v == 0 and w == 0: |
| 2560 | + raise ValueError("u,v,w can't be all zero") |
| 2561 | + |
| 2562 | + # normalize |
| 2563 | + norm = math.sqrt(u ** 2 + v ** 2 + w ** 2) |
| 2564 | + u /= norm |
| 2565 | + v /= norm |
| 2566 | + w /= norm |
| 2567 | + |
| 2568 | + # draw main line |
| 2569 | + t = np.linspace(0, length, num=20) |
| 2570 | + lx = x - t * u |
| 2571 | + ly = y - t * v |
| 2572 | + lz = z - t * w |
| 2573 | + line = list(zip(lx, ly, lz)) |
| 2574 | + lines.append(line) |
| 2575 | + |
| 2576 | + d1, d2 = calc_arrow(u, v, w) |
| 2577 | + ua1, va1, wa1 = d1[0], d1[1], d1[2] |
| 2578 | + ua2, va2, wa2 = d2[0], d2[1], d2[2] |
| 2579 | + |
| 2580 | + t = np.linspace(0, length * arrow_length_ratio, num=20) |
| 2581 | + la1x = x - t * ua1 |
| 2582 | + la1y = y - t * va1 |
| 2583 | + la1z = z - t * wa1 |
| 2584 | + la2x = x - t * ua2 |
| 2585 | + la2y = y - t * va2 |
| 2586 | + la2z = z - t * wa2 |
| 2587 | + |
| 2588 | + line = list(zip(la1x, la1y, la1z)) |
| 2589 | + lines.append(line) |
| 2590 | + line = list(zip(la2x, la2y, la2z)) |
| 2591 | + lines.append(line) |
| 2592 | + |
| 2593 | + linec = Line3DCollection(lines, *args[6:], **kwargs) |
| 2594 | + self.add_collection(linec) |
| 2595 | + |
| 2596 | + self.auto_scale_xyz(xs, ys, zs, had_data) |
| 2597 | + |
| 2598 | + return linec |
| 2599 | + |
| 2600 | + quiver3D = quiver |
| 2601 | + |
| 2602 | + |
2416 | 2603 | def get_test_data(delta=0.05):
|
2417 | 2604 | '''
|
2418 | 2605 | Return a tuple X, Y, Z with a test data set.
|
|
0 commit comments