Skip to content

Commit c3d6d26

Browse files
committed
Numpify quiver.py
svn path=/trunk/matplotlib/; revision=3396
1 parent 7b0f040 commit c3d6d26

File tree

1 file changed

+76
-72
lines changed

1 file changed

+76
-72
lines changed

lib/matplotlib/quiver.py

+76-72
Original file line numberDiff line numberDiff line change
@@ -140,41 +140,43 @@
140140
of the arrow+label key object.
141141
"""
142142

143-
from matplotlib.collections import PolyCollection
144-
from matplotlib.mlab import meshgrid
145-
from matplotlib import numerix as nx
146-
from matplotlib import transforms as T
147-
from matplotlib.text import Text
148-
from matplotlib.artist import Artist
149-
from matplotlib.font_manager import FontProperties
143+
import numpy as npy
144+
import matplotlib.numerix.npyma as ma
145+
import matplotlib.collections as collections
146+
import matplotlib.transforms as transforms
147+
import matplotlib.text as text
148+
import matplotlib.artist as artist
149+
import matplotlib.font_manager as font_manager
150150
import math
151151

152152

153-
class QuiverKey(Artist):
153+
class QuiverKey(artist.Artist):
154154
""" Labelled arrow for use as a quiver plot scale key.
155155
"""
156156
halign = {'N': 'center', 'S': 'center', 'E': 'left', 'W': 'right'}
157157
valign = {'N': 'bottom', 'S': 'top', 'E': 'center', 'W': 'center'}
158158
pivot = {'N': 'mid', 'S': 'mid', 'E': 'tip', 'W': 'tail'}
159159

160160
def __init__(self, Q, X, Y, U, label, **kw):
161-
Artist.__init__(self)
161+
artist.Artist.__init__(self)
162162
self.Q = Q
163163
self.X = X
164164
self.Y = Y
165165
self.U = U
166166
self.coord = kw.pop('coordinates', 'axes')
167167
self.color = kw.pop('color', None)
168168
self.label = label
169-
self.labelsep = T.Value(kw.pop('labelsep', 0.1)) * Q.ax.figure.dpi
169+
self.labelsep = (transforms.Value(kw.pop('labelsep', 0.1))
170+
* Q.ax.figure.dpi)
170171
self.labelpos = kw.pop('labelpos', 'N')
171172
self.labelcolor = kw.pop('labelcolor', None)
172173
self.fontproperties = kw.pop('fontproperties', dict())
173174
self.kw = kw
174-
self.text = Text(text=label,
175-
horizontalalignment=self.halign[self.labelpos],
176-
verticalalignment=self.valign[self.labelpos],
177-
fontproperties=FontProperties(**self.fontproperties))
175+
_fp = self.fontproperties
176+
self.text = text.Text(text=label,
177+
horizontalalignment=self.halign[self.labelpos],
178+
verticalalignment=self.valign[self.labelpos],
179+
fontproperties=font_manager.FontProperties(**_fp))
178180
if self.labelcolor is not None:
179181
self.text.set_color(self.labelcolor)
180182
self._initialized = False
@@ -187,11 +189,12 @@ def _init(self):
187189
self._set_transform()
188190
_pivot = self.Q.pivot
189191
self.Q.pivot = self.pivot[self.labelpos]
190-
self.verts = self.Q._make_verts(nx.array([self.U]), nx.zeros((1,)))
192+
self.verts = self.Q._make_verts(npy.array([self.U]),
193+
npy.zeros((1,)))
191194
self.Q.pivot = _pivot
192195
kw = self.Q.polykw
193196
kw.update(self.kw)
194-
self.vector = PolyCollection(self.verts,
197+
self.vector = collections.PolyCollection(self.verts,
195198
offsets=[(self.X,self.Y)],
196199
transOffset=self.get_transform(),
197200
**kw)
@@ -234,14 +237,14 @@ def _set_transform(self):
234237
self.set_transform(self.Q.ax.figure.transFigure)
235238
elif self.coord == 'inches':
236239
dx = ax.figure.dpi
237-
bb = T.Bbox(T.origin(), T.Point(dx, dx))
238-
trans = T.get_bbox_transform(T.unit_bbox(), bb)
240+
bb = transforms.Bbox(transforms.origin(), transforms.Point(dx, dx))
241+
trans = transforms.get_bbox_transform(transforms.unit_bbox(), bb)
239242
self.set_transform(trans)
240243
else:
241244
raise ValueError('unrecognized coordinates')
242245
quiverkey_doc = _quiverkey_doc
243246

244-
class Quiver(PolyCollection):
247+
class Quiver(collections.PolyCollection):
245248
"""
246249
Specialized PolyCollection for arrows.
247250
@@ -276,7 +279,7 @@ def __init__(self, ax, *args, **kw):
276279
self.pivot = kw.pop('pivot', 'tail')
277280
kw.setdefault('facecolors', self.color)
278281
kw.setdefault('linewidths', (0,))
279-
PolyCollection.__init__(self, None, offsets=zip(X, Y),
282+
collections.PolyCollection.__init__(self, None, offsets=zip(X, Y),
280283
transOffset=ax.transData, **kw)
281284
self.polykw = kw
282285
self.set_UVC(U, V, C)
@@ -295,21 +298,22 @@ def _parse_args(self, *args):
295298
X, Y, U, V, C = [None]*5
296299
args = list(args)
297300
if len(args) == 3 or len(args) == 5:
298-
C = nx.ravel(args.pop(-1))
301+
C = npy.ravel(args.pop(-1))
299302
#print 'in parse_args, C:', C
300-
V = nx.ma.asarray(args.pop(-1))
301-
U = nx.ma.asarray(args.pop(-1))
302-
nn = nx.shape(U)
303+
V = ma.asarray(args.pop(-1))
304+
U = ma.asarray(args.pop(-1))
305+
nn = npy.shape(U)
303306
nc = nn[0]
304307
nr = 1
305308
if len(nn) > 1:
306309
nr = nn[1]
307310
if len(args) == 2:
308-
X, Y = [nx.ravel(a) for a in args]
311+
X, Y = [npy.ravel(a) for a in args]
309312
if len(X) == nc and len(Y) == nr:
310-
X, Y = [nx.ravel(a) for a in meshgrid(X, Y)]
313+
X, Y = [npy.ravel(a) for a in npy.meshgrid(X, Y)]
311314
else:
312-
X, Y = [nx.ravel(a) for a in meshgrid(nx.arange(nc), nx.arange(nr))]
315+
indexgrid = npy.meshgrid(npy.arange(nc), npy.arange(nr))
316+
X, Y = [npy.ravel(a) for a in indexgrid]
313317
return X, Y, U, V, C
314318

315319
def _init(self):
@@ -331,13 +335,13 @@ def draw(self, renderer):
331335
verts = self._make_verts(self.U, self.V)
332336
self.set_verts(verts)
333337
self._new_UV = False
334-
PolyCollection.draw(self, renderer)
338+
collections.PolyCollection.draw(self, renderer)
335339

336340
def set_UVC(self, U, V, C=None):
337-
self.U = nx.ma.ravel(U)
338-
self.V = nx.ma.ravel(V)
341+
self.U = ma.ravel(U)
342+
self.V = ma.ravel(V)
339343
if C is not None:
340-
self.set_array(nx.ravel(C))
344+
self.set_array(npy.ravel(C))
341345
self._new_UV = True
342346

343347
def _set_transform(self):
@@ -356,86 +360,86 @@ def _set_transform(self):
356360
elif self.units == 'height':
357361
dx = ax.bbox.ur().y() - ax.bbox.ll().y()
358362
elif self.units == 'dots':
359-
dx = T.Value(1)
363+
dx = transforms.Value(1)
360364
elif self.units == 'inches':
361365
dx = ax.figure.dpi
362366
else:
363367
raise ValueError('unrecognized units')
364-
bb = T.Bbox(T.origin(), T.Point(dx, dx))
365-
trans = T.get_bbox_transform(T.unit_bbox(), bb)
368+
bb = transforms.Bbox(transforms.origin(), transforms.Point(dx, dx))
369+
trans = transforms.get_bbox_transform(transforms.unit_bbox(), bb)
366370
self.set_transform(trans)
367371
return trans
368372

369373
def _make_verts(self, U, V):
370374
uv = U+V*1j
371-
uv = nx.ravel(nx.ma.filled(uv,nx.nan))
372-
a = nx.absolute(uv)
375+
uv = npy.ravel(ma.filled(uv,npy.nan))
376+
a = npy.absolute(uv)
373377
if self.scale is None:
374378
sn = max(10, math.sqrt(self.N))
375379

376380
# get valid values for average
377381
# (complicated by support for 3 array packages)
378-
a_valid_cond = ~nx.isnan(a)
379-
a_valid_idx = nx.nonzero(a_valid_cond)
382+
a_valid_cond = ~npy.isnan(a)
383+
a_valid_idx = npy.nonzero(a_valid_cond)
380384
if isinstance(a_valid_idx,tuple):
381385
# numpy.nonzero returns tuple
382386
a_valid_idx = a_valid_idx[0]
383-
valid_a = nx.take(a,a_valid_idx)
387+
valid_a = npy.take(a,a_valid_idx)
384388

385-
scale = 1.8 * nx.average(valid_a) * sn # crude auto-scaling
389+
scale = 1.8 * npy.average(valid_a) * sn # crude auto-scaling
386390
scale = scale/self.span
387391
self.scale = scale
388392
length = a/(self.scale*self.width)
389393
X, Y = self._h_arrows(length)
390-
xy = (X+Y*1j) * nx.exp(1j*nx.angle(uv[...,nx.newaxis]))*self.width
391-
xy = xy[:,:,nx.newaxis]
392-
XY = nx.concatenate((xy.real, xy.imag), axis=2)
394+
xy = (X+Y*1j) * npy.exp(1j*npy.angle(uv[...,npy.newaxis]))*self.width
395+
xy = xy[:,:,npy.newaxis]
396+
XY = npy.concatenate((xy.real, xy.imag), axis=2)
393397
return XY
394398

395399

396400
def _h_arrows(self, length):
397401
""" length is in arrow width units """
398402
minsh = self.minshaft * self.headlength
399403
N = len(length)
400-
length = nx.reshape(length, (N,1))
401-
x = nx.array([0, -self.headaxislength,
402-
-self.headlength, 0], nx.Float64)
403-
x = x + nx.array([0,1,1,1]) * length
404-
y = 0.5 * nx.array([1, 1, self.headwidth, 0], nx.Float64)
405-
y = nx.repeat(y[nx.newaxis,:], N)
406-
x0 = nx.array([0, minsh-self.headaxislength,
407-
minsh-self.headlength, minsh], nx.Float64)
408-
y0 = 0.5 * nx.array([1, 1, self.headwidth, 0], nx.Float64)
404+
length = npy.reshape(length, (N,1))
405+
x = npy.array([0, -self.headaxislength,
406+
-self.headlength, 0], npy.float64)
407+
x = x + npy.array([0,1,1,1]) * length
408+
y = 0.5 * npy.array([1, 1, self.headwidth, 0], npy.float64)
409+
y = npy.repeat(y[npy.newaxis,:], N, axis=0)
410+
x0 = npy.array([0, minsh-self.headaxislength,
411+
minsh-self.headlength, minsh], npy.float64)
412+
y0 = 0.5 * npy.array([1, 1, self.headwidth, 0], npy.float64)
409413
ii = [0,1,2,3,2,1,0]
410-
X = nx.take(x, ii, 1)
411-
Y = nx.take(y, ii, 1)
414+
X = npy.take(x, ii, 1)
415+
Y = npy.take(y, ii, 1)
412416
Y[:, 3:] *= -1
413-
X0 = nx.take(x0, ii)
414-
Y0 = nx.take(y0, ii)
417+
X0 = npy.take(x0, ii)
418+
Y0 = npy.take(y0, ii)
415419
Y0[3:] *= -1
416420
shrink = length/minsh
417-
X0 = shrink * X0[nx.newaxis,:]
418-
Y0 = shrink * Y0[nx.newaxis,:]
419-
short = nx.repeat(length < minsh, 7, 1)
421+
X0 = shrink * X0[npy.newaxis,:]
422+
Y0 = shrink * Y0[npy.newaxis,:]
423+
short = npy.repeat(length < minsh, 7, axis=1)
420424
#print 'short', length < minsh
421-
X = nx.where(short, X0, X)
422-
Y = nx.where(short, Y0, Y)
425+
X = npy.where(short, X0, X)
426+
Y = npy.where(short, Y0, Y)
423427
if self.pivot[:3] == 'mid':
424-
X -= 0.5 * X[:,3, nx.newaxis]
428+
X -= 0.5 * X[:,3, npy.newaxis]
425429
elif self.pivot[:3] == 'tip':
426-
X = X - X[:,3, nx.newaxis] #numpy bug? using -= does not
430+
X = X - X[:,3, npy.newaxis] #numpy bug? using -= does not
427431
# work here unless we multiply
428432
# by a float first, as with 'mid'.
429433
tooshort = length < self.minlength
430-
if nx.any(tooshort):
431-
th = nx.arange(0,7,1, nx.Float64) * (nx.pi/3.0)
432-
x1 = nx.cos(th) * self.minlength * 0.5
433-
y1 = nx.sin(th) * self.minlength * 0.5
434-
X1 = nx.repeat(x1[nx.newaxis, :], N, 0)
435-
Y1 = nx.repeat(y1[nx.newaxis, :], N, 0)
436-
tooshort = nx.repeat(tooshort, 7, 1)
437-
X = nx.where(tooshort, X1, X)
438-
Y = nx.where(tooshort, Y1, Y)
434+
if npy.any(tooshort):
435+
th = npy.arange(0,7,1, npy.float64) * (npy.pi/3.0)
436+
x1 = npy.cos(th) * self.minlength * 0.5
437+
y1 = npy.sin(th) * self.minlength * 0.5
438+
X1 = npy.repeat(x1[npy.newaxis, :], N, axis=0)
439+
Y1 = npy.repeat(y1[npy.newaxis, :], N, axis=0)
440+
tooshort = npy.repeat(tooshort, 7, 1)
441+
X = npy.where(tooshort, X1, X)
442+
Y = npy.where(tooshort, Y1, Y)
439443
return X, Y
440444

441445
quiver_doc = _quiver_doc

0 commit comments

Comments
 (0)