Skip to content

Commit cec78a7

Browse files
committed
Merge pull request #3149 from WeatherGod/3dquiver
3dquiver rebranch
2 parents bcb57af + 3112e67 commit cec78a7

File tree

14 files changed

+35216
-0
lines changed

14 files changed

+35216
-0
lines changed

doc/mpl_toolkits/mplot3d/tutorial.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ Bar plots
112112

113113
.. plot:: mpl_examples/mplot3d/bars3d_demo.py
114114

115+
.. _quiver3d:
116+
117+
Quiver
118+
====================
119+
.. automethod:: Axes3D.quiver
120+
121+
.. plot:: mpl_examples/mplot3d/quiver3d_demo.py
122+
115123
.. _2dcollections3d:
116124

117125
2D plots in 3D

doc/users/whats_new.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,16 @@ to ensure that the calculated theta value was between the range of 0 and 2 * pi
224224
since the problem was that the value can become negative after applying the
225225
direction and rotation to the theta calculation.
226226

227+
Simple quiver plot for mplot3d toolkit
228+
``````````````````````````````````````
229+
A team of students in an *Engineering Large Software Systems* course, taught
230+
by Prof. Anya Tafliovich at the University of Toronto, implemented a simple
231+
version of a quiver plot in 3D space for the mplot3d toolkit as one of their
232+
term project. This feature is documented in :func:`~mpl_toolkits.mplot3d.Axes3D.quiver`.
233+
The team members are: Ryan Steve D'Souza, Victor B, xbtsw, Yang Wang, David,
234+
Caradec Bisesar and Vlad Vassilovski.
235+
236+
.. plot:: mpl_examples/quiver3d_demo.py
227237

228238
Date handling
229239
-------------

examples/mplot3d/quiver3d_demo.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from mpl_toolkits.mplot3d import axes3d
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
fig = plt.figure()
6+
ax = fig.gca(projection='3d')
7+
8+
x, y, z = np.meshgrid(np.arange(-0.8, 1, 0.2),
9+
np.arange(-0.8, 1, 0.2),
10+
np.arange(-0.8, 1, 0.8))
11+
12+
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
13+
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
14+
w = (np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) *
15+
np.sin(np.pi * z))
16+
17+
ax.quiver(x, y, z, u, v, w, length=0.1)
18+
19+
plt.show()
20+

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 199 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212
from __future__ import (absolute_import, division, print_function,
1313
unicode_literals)
14+
import math
1415

1516
import six
1617
from six.moves import map, xrange, zip
@@ -2413,6 +2414,204 @@ def set_title(self, label, fontdict=None, loc='center', **kwargs):
24132414
return ret
24142415
set_title.__doc__ = maxes.Axes.set_title.__doc__
24152416

2417+
def quiver(self, *args, **kwargs):
2418+
"""
2419+
Plot a 3D field of arrows.
2420+
2421+
call signatures::
2422+
2423+
quiver(X, Y, Z, U, V, W, **kwargs)
2424+
2425+
Arguments:
2426+
2427+
*X*, *Y*, *Z*:
2428+
The x, y and z coordinates of the arrow locations
2429+
2430+
*U*, *V*, *W*:
2431+
The direction vector that the arrow is pointing
2432+
2433+
The arguments could be array-like or scalars, so long as they
2434+
they can be broadcast together. The arguments can also be
2435+
masked arrays. If an element in any of argument is masked, then
2436+
that corresponding quiver element will not be plotted.
2437+
2438+
Keyword arguments:
2439+
2440+
*length*: [1.0 | float]
2441+
The length of each quiver, default to 1.0, the unit is
2442+
the same with the axes
2443+
2444+
*arrow_length_ratio*: [0.3 | float]
2445+
The ratio of the arrow head with respect to the quiver,
2446+
default to 0.3
2447+
2448+
Any additional keyword arguments are delegated to
2449+
:class:`~matplotlib.collections.LineCollection`
2450+
2451+
"""
2452+
def calc_arrow(u, v, w, angle=15):
2453+
"""
2454+
To calculate the arrow head. (u, v, w) should be unit vector.
2455+
"""
2456+
2457+
# this part figures out the axis of rotation to use
2458+
2459+
# use unit vector perpendicular to (u,v,w) when |w|=1, by default
2460+
x, y, z = 0, 1, 0
2461+
2462+
# get the norm
2463+
norm = math.sqrt(v**2 + u**2)
2464+
# normalize it if it is safe
2465+
if norm > 0:
2466+
# get unit direction vector perpendicular to (u,v,w)
2467+
x, y = v/norm, -u/norm
2468+
2469+
# this function takes an angle, and rotates the (u,v,w)
2470+
# angle degrees around (x,y,z)
2471+
def rotatefunction(angle):
2472+
ra = math.radians(angle)
2473+
c = math.cos(ra)
2474+
s = math.sin(ra)
2475+
2476+
# construct the rotation matrix
2477+
R = np.matrix([[c+(x**2)*(1-c), x*y*(1-c)-z*s, x*z*(1-c)+y*s],
2478+
[y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s],
2479+
[z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]])
2480+
2481+
# construct the column vector for (u,v,w)
2482+
line = np.matrix([[u],[v],[w]])
2483+
2484+
# use numpy to multiply them to get the rotated vector
2485+
rotatedline = R*line
2486+
2487+
# return the rotated (u,v,w) from the computed matrix
2488+
return (rotatedline[0,0], rotatedline[1,0], rotatedline[2,0])
2489+
2490+
# compute and return the two arrowhead direction unit vectors
2491+
return rotatefunction(angle), rotatefunction(-angle)
2492+
2493+
def point_vector_to_line(point, vector, length):
2494+
"""
2495+
use a point and vector to generate lines
2496+
"""
2497+
lines = []
2498+
for var in np.linspace(0, length, num=2):
2499+
lines.append(list(zip(*(point - var * vector))))
2500+
lines = np.array(lines).swapaxes(0, 1)
2501+
return lines.tolist()
2502+
2503+
had_data = self.has_data()
2504+
2505+
# handle kwargs
2506+
# shaft length
2507+
length = kwargs.pop('length', 1)
2508+
# arrow length ratio to the shaft length
2509+
arrow_length_ratio = kwargs.pop('arrow_length_ratio', 0.3)
2510+
2511+
# handle args
2512+
if len(args) < 6:
2513+
ValueError('Wrong number of arguments')
2514+
argi = 6
2515+
# first 6 arguments are X, Y, Z, U, V, W
2516+
input_args = args[:argi]
2517+
# if any of the args are scalar, convert into list
2518+
input_args = [[k] if isinstance(k, (int, float)) else k
2519+
for k in input_args]
2520+
2521+
# extract the masks, if any
2522+
masks = [k.mask for k in input_args if isinstance(k, np.ma.MaskedArray)]
2523+
# broadcast to match the shape
2524+
bcast = np.broadcast_arrays(*(input_args + masks))
2525+
input_args = bcast[:argi]
2526+
masks = bcast[argi:]
2527+
if masks:
2528+
# combine the masks into one
2529+
mask = reduce(np.logical_or, masks)
2530+
# put mask on and compress
2531+
input_args = [np.ma.array(k, mask=mask).compressed()
2532+
for k in input_args]
2533+
else:
2534+
input_args = [k.flatten() for k in input_args]
2535+
2536+
if any(len(v) == 0 for v in input_args):
2537+
# No quivers, so just make an empty collection and return early
2538+
linec = art3d.Line3DCollection([], *args[6:], **kwargs)
2539+
self.add_collection(linec)
2540+
return linec
2541+
2542+
points = input_args[:3]
2543+
vectors = input_args[3:]
2544+
2545+
# Below assertions must be true before proceed
2546+
# must all be ndarray
2547+
assert all(isinstance(k, np.ndarray) for k in input_args)
2548+
# must all in same shape
2549+
assert len(set([k.shape for k in input_args])) == 1
2550+
2551+
# X, Y, Z, U, V, W
2552+
coords = (np.array(k) if not isinstance(k, np.ndarray) else k
2553+
for k in args)
2554+
coords = [k.flatten() for k in coords]
2555+
xs, ys, zs, us, vs, ws = coords
2556+
lines = []
2557+
2558+
# for each arrow
2559+
for i in range(xs.shape[0]):
2560+
# calulate body
2561+
x = xs[i]
2562+
y = ys[i]
2563+
z = zs[i]
2564+
u = us[i]
2565+
v = vs[i]
2566+
w = ws[i]
2567+
if any(k is np.ma.masked for k in [x, y, z, u, v, w]):
2568+
continue
2569+
2570+
# (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2571+
if u == 0 and v == 0 and w == 0:
2572+
raise ValueError("u,v,w can't be all zero")
2573+
2574+
# normalize
2575+
norm = math.sqrt(u ** 2 + v ** 2 + w ** 2)
2576+
u /= norm
2577+
v /= norm
2578+
w /= norm
2579+
2580+
# draw main line
2581+
t = np.linspace(0, length, num=20)
2582+
lx = x - t * u
2583+
ly = y - t * v
2584+
lz = z - t * w
2585+
line = list(zip(lx, ly, lz))
2586+
lines.append(line)
2587+
2588+
d1, d2 = calc_arrow(u, v, w)
2589+
ua1, va1, wa1 = d1[0], d1[1], d1[2]
2590+
ua2, va2, wa2 = d2[0], d2[1], d2[2]
2591+
2592+
t = np.linspace(0, length * arrow_length_ratio, num=20)
2593+
la1x = x - t * ua1
2594+
la1y = y - t * va1
2595+
la1z = z - t * wa1
2596+
la2x = x - t * ua2
2597+
la2y = y - t * va2
2598+
la2z = z - t * wa2
2599+
2600+
line = list(zip(la1x, la1y, la1z))
2601+
lines.append(line)
2602+
line = list(zip(la2x, la2y, la2z))
2603+
lines.append(line)
2604+
2605+
linec = art3d.Line3DCollection(lines, *args[6:], **kwargs)
2606+
self.add_collection(linec)
2607+
2608+
self.auto_scale_xyz(xs, ys, zs, had_data)
2609+
2610+
return linec
2611+
2612+
quiver3D = quiver
2613+
2614+
24162615
def get_test_data(delta=0.05):
24172616
'''
24182617
Return a tuple X, Y, Z with a test data set.
Binary file not shown.

0 commit comments

Comments
 (0)