Skip to content

Commit 1229099

Browse files
committed
Merge branch 'fix-1026' of git://github.com/xbtsw/matplotlib into 3dquiver
Conflicts: doc/users/whats_new.rst
2 parents f1689bb + 1ee5161 commit 1229099

File tree

8 files changed

+18216
-0
lines changed

8 files changed

+18216
-0
lines changed

doc/mpl_toolkits/mplot3d/tutorial.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,14 @@ Bar plots
112112

113113
.. plot:: mpl_examples/mplot3d/bars3d_demo.py
114114

115+
.. _quiver3d:
116+
117+
Quiver
118+
====================
119+
.. automethod:: Axes3D.quiver
120+
121+
.. plot:: mpl_examples/mplot3d/quiver3d_demo.py
122+
115123
.. _2dcollections3d:
116124

117125
2D plots in 3D

doc/users/whats_new.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -224,6 +224,16 @@ to ensure that the calculated theta value was between the range of 0 and 2 * pi
224224
since the problem was that the value can become negative after applying the
225225
direction and rotation to the theta calculation.
226226

227+
Simple quiver plot for mplot3d toolkit
228+
``````````````````````````````````````
229+
A team of students in an *Engineering Large Software Systems* course, taught
230+
by Prof. Anya Tafliovich at the University of Toronto, implemented a simple
231+
version of a quiver plot in 3D space for the mplot3d toolkit as one of their
232+
term project. This feature is documented in :func:`~mpl_toolkits.mplot3d.Axes3D.quiver`.
233+
The team members are: Ryan Steve D'Souza, Victor B, xbtsw, Yang Wang, David,
234+
Caradec Bisesar and Vlad Vassilovski.
235+
236+
.. plot:: mpl_examples/quiver3d_demo.py
227237

228238
Date handling
229239
-------------

examples/mplot3d/quiver3d_demo.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from mpl_toolkits.mplot3d import axes3d
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
5+
fig = plt.figure()
6+
ax = fig.gca(projection='3d')
7+
8+
x, y, z = np.meshgrid(np.arange(-0.8, 1, 0.2),
9+
np.arange(-0.8, 1, 0.2),
10+
np.arange(-0.8, 1, 0.8))
11+
12+
u = np.sin(np.pi * x) * np.cos(np.pi * y) * np.cos(np.pi * z)
13+
v = -np.cos(np.pi * x) * np.sin(np.pi * y) * np.cos(np.pi * z)
14+
w = np.sqrt(2.0 / 3.0) * np.cos(np.pi * x) * np.cos(np.pi * y) * \
15+
np.sin(np.pi * z)
16+
17+
ax.quiver(x, y, z, u, v, w, length=0.1)
18+
19+
plt.show()

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
"""
1212
from __future__ import (absolute_import, division, print_function,
1313
unicode_literals)
14+
import math
1415

1516
import six
1617
from six.moves import map, xrange, zip
@@ -33,6 +34,7 @@
3334
from . import art3d
3435
from . import proj3d
3536
from . import axis3d
37+
from mpl_toolkits.mplot3d.art3d import Line3DCollection
3638

3739
def unit_bbox():
3840
box = Bbox(np.array([[0, 0], [1, 1]]))
@@ -2413,6 +2415,191 @@ def set_title(self, label, fontdict=None, loc='center', **kwargs):
24132415
return ret
24142416
set_title.__doc__ = maxes.Axes.set_title.__doc__
24152417

2418+
def quiver(self, *args, **kwargs):
2419+
"""
2420+
Plot a 3D field of arrows.
2421+
2422+
call signatures::
2423+
2424+
quiver(X, Y, Z, U, V, W, **kwargs)
2425+
2426+
Arguments:
2427+
2428+
*X*, *Y*, *Z*:
2429+
The x, y and z coordinates of the arrow locations
2430+
2431+
*U*, *V*, *W*:
2432+
The direction vector that the arrow is pointing
2433+
2434+
The arguments could be iterable or scalars they will be broadcast together. The arguments can
2435+
also be masked arrays, if a position in any of argument is masked, then the corresponding
2436+
quiver will not be plotted.
2437+
2438+
Keyword arguments:
2439+
2440+
*length*: [1.0 | float]
2441+
The length of each quiver, default to 1.0, the unit is the same with the axes
2442+
2443+
*arrow_length_ratio*: [0.3 | float]
2444+
The ratio of the arrow head with respect to the quiver, default to 0.3
2445+
2446+
Any additional keyword arguments are delegated to :class:`~matplotlib.collections.LineCollection`
2447+
2448+
"""
2449+
def calc_arrow(u, v, w, angle=15):
2450+
"""
2451+
To calculate the arrow head. (u, v, w) should be unit vector.
2452+
"""
2453+
2454+
# this part figures out the axis of rotation to use
2455+
2456+
# use unit vector perpendicular to (u,v,w) when |w|=1, by default
2457+
x, y, z = 0, 1, 0
2458+
2459+
# get the norm
2460+
norm = math.sqrt(v**2 + u**2)
2461+
# normalize it if it is safe
2462+
if norm > 0:
2463+
# get unit direction vector perpendicular to (u,v,w)
2464+
x, y = v/norm, -u/norm
2465+
2466+
# this function takes an angle, and rotates the (u,v,w)
2467+
# angle degrees around (x,y,z)
2468+
def rotatefunction(angle):
2469+
ra = math.radians(angle)
2470+
c = math.cos(ra)
2471+
s = math.sin(ra)
2472+
2473+
# construct the rotation matrix
2474+
R = np.matrix([[c+(x**2)*(1-c), x*y*(1-c)-z*s, x*z*(1-c)+y*s],
2475+
[y*x*(1-c)+z*s, c+(y**2)*(1-c), y*z*(1-c)-x*s],
2476+
[z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+(z**2)*(1-c)]])
2477+
2478+
# construct the column vector for (u,v,w)
2479+
line = np.matrix([[u],[v],[w]])
2480+
2481+
# use numpy to multiply them to get the rotated vector
2482+
rotatedline = R*line
2483+
2484+
# return the rotated (u,v,w) from the computed matrix
2485+
return (rotatedline[0,0], rotatedline[1,0], rotatedline[2,0])
2486+
2487+
# compute and return the two arrowhead direction unit vectors
2488+
return rotatefunction(angle), rotatefunction(-angle)
2489+
2490+
def point_vector_to_line(point, vector, length):
2491+
"""
2492+
use a point and vector to generate lines
2493+
"""
2494+
lines = []
2495+
for var in np.linspace(0, length, num=2):
2496+
lines.append(list(zip(*(point - var * vector))))
2497+
lines = np.array(lines).swapaxes(0, 1)
2498+
return lines.tolist()
2499+
2500+
had_data = self.has_data()
2501+
2502+
# handle kwargs
2503+
# shaft length
2504+
length = kwargs.pop('length', 1)
2505+
# arrow length ratio to the shaft length
2506+
arrow_length_ratio = kwargs.pop('arrow_length_ratio', 0.3)
2507+
2508+
# handle args
2509+
if len(args) < 6:
2510+
ValueError('Wrong number of arguments')
2511+
argi = 6
2512+
# first 6 arguments are X, Y, Z, U, V, W
2513+
input_args = args[:argi]
2514+
# if any of the args are scalar, convert into list
2515+
input_args = [[k] if isinstance(k, (int, float)) else k for k in input_args]
2516+
# extract the masks, if any
2517+
masks = [k.mask for k in input_args if isinstance(k, np.ma.MaskedArray)]
2518+
# broadcast to match the shape
2519+
bcast = np.broadcast_arrays(*(input_args + masks))
2520+
input_args = bcast[:argi]
2521+
masks = bcast[argi:]
2522+
if masks:
2523+
# combine the masks into one
2524+
mask = reduce(np.logical_or, masks)
2525+
# put mask on and compress
2526+
input_args = [np.ma.array(k, mask=mask).compressed() for k in input_args]
2527+
else:
2528+
input_args = [k.flatten() for k in input_args]
2529+
2530+
points = input_args[:3]
2531+
vectors = input_args[3:]
2532+
2533+
# Below assertions must be true before proceed
2534+
# must all be ndarray
2535+
assert all([isinstance(k, np.ndarray) for k in input_args])
2536+
# must all in same shape
2537+
assert len(set([k.shape for k in input_args])) == 1
2538+
2539+
2540+
# X, Y, Z, U, V, W
2541+
coords = list(map(lambda k: np.array(k) if not isinstance(k, np.ndarray) else k, args))
2542+
coords = [k.flatten() for k in coords]
2543+
xs, ys, zs, us, vs, ws = coords
2544+
lines = []
2545+
2546+
# for each arrow
2547+
for i in xrange(xs.shape[0]):
2548+
# calulate body
2549+
x = xs[i]
2550+
y = ys[i]
2551+
z = zs[i]
2552+
u = us[i]
2553+
v = vs[i]
2554+
w = ws[i]
2555+
if any([k is np.ma.masked for k in [x, y, z, u, v, w]]):
2556+
continue
2557+
2558+
# (u,v,w) expected to be normalized, recursive to fix A=0 scenario.
2559+
if u == 0 and v == 0 and w == 0:
2560+
raise ValueError("u,v,w can't be all zero")
2561+
2562+
# normalize
2563+
norm = math.sqrt(u ** 2 + v ** 2 + w ** 2)
2564+
u /= norm
2565+
v /= norm
2566+
w /= norm
2567+
2568+
# draw main line
2569+
t = np.linspace(0, length, num=20)
2570+
lx = x - t * u
2571+
ly = y - t * v
2572+
lz = z - t * w
2573+
line = list(zip(lx, ly, lz))
2574+
lines.append(line)
2575+
2576+
d1, d2 = calc_arrow(u, v, w)
2577+
ua1, va1, wa1 = d1[0], d1[1], d1[2]
2578+
ua2, va2, wa2 = d2[0], d2[1], d2[2]
2579+
2580+
t = np.linspace(0, length * arrow_length_ratio, num=20)
2581+
la1x = x - t * ua1
2582+
la1y = y - t * va1
2583+
la1z = z - t * wa1
2584+
la2x = x - t * ua2
2585+
la2y = y - t * va2
2586+
la2z = z - t * wa2
2587+
2588+
line = list(zip(la1x, la1y, la1z))
2589+
lines.append(line)
2590+
line = list(zip(la2x, la2y, la2z))
2591+
lines.append(line)
2592+
2593+
linec = Line3DCollection(lines, *args[6:], **kwargs)
2594+
self.add_collection(linec)
2595+
2596+
self.auto_scale_xyz(xs, ys, zs, had_data)
2597+
2598+
return linec
2599+
2600+
quiver3D = quiver
2601+
2602+
24162603
def get_test_data(delta=0.05):
24172604
'''
24182605
Return a tuple X, Y, Z with a test data set.
Binary file not shown.

0 commit comments

Comments
 (0)