Skip to content

Commit baf6f83

Browse files
authored
Merge pull request #25007 from oscargus/3daxesrefactor
2 parents 9c48e4b + 3a5c93d commit baf6f83

File tree

1 file changed

+23
-26
lines changed

1 file changed

+23
-26
lines changed

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 23 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -959,15 +959,11 @@ def disable_mouse_rotation(self):
959959
self.mouse_init(rotate_btn=[], pan_btn=[], zoom_btn=[])
960960

961961
def can_zoom(self):
962-
"""
963-
Return whether this Axes supports the zoom box button functionality.
964-
"""
962+
# doc-string inherited
965963
return True
966964

967965
def can_pan(self):
968-
"""
969-
Return whether this Axes supports the pan button functionality.
970-
"""
966+
# doc-string inherited
971967
return True
972968

973969
def sharez(self, other):
@@ -1002,17 +998,17 @@ def _button_press(self, event):
1002998
if event.inaxes == self:
1003999
self.button_pressed = event.button
10041000
self._sx, self._sy = event.xdata, event.ydata
1005-
toolbar = getattr(self.figure.canvas, "toolbar")
1001+
toolbar = self.figure.canvas.toolbar
10061002
if toolbar and toolbar._nav_stack() is None:
1007-
self.figure.canvas.toolbar.push_current()
1003+
toolbar.push_current()
10081004

10091005
def _button_release(self, event):
10101006
self.button_pressed = None
1011-
toolbar = getattr(self.figure.canvas, "toolbar")
1007+
toolbar = self.figure.canvas.toolbar
10121008
# backend_bases.release_zoom and backend_bases.release_pan call
10131009
# push_current, so check the navigation mode so we don't call it twice
10141010
if toolbar and self.get_navigate_mode() is None:
1015-
self.figure.canvas.toolbar.push_current()
1011+
toolbar.push_current()
10161012

10171013
def _get_view(self):
10181014
# docstring inherited
@@ -1597,12 +1593,6 @@ def plot_surface(self, X, Y, Z, *, norm=None, vmin=None,
15971593

15981594
fcolors = kwargs.pop('facecolors', None)
15991595

1600-
if fcolors is None:
1601-
color = kwargs.pop('color', None)
1602-
if color is None:
1603-
color = self._get_lines.get_next_color()
1604-
color = np.array(mcolors.to_rgba(color))
1605-
16061596
cmap = kwargs.get('cmap', None)
16071597
shade = kwargs.pop('shade', cmap is None)
16081598
if shade is None:
@@ -1677,6 +1667,11 @@ def plot_surface(self, X, Y, Z, *, norm=None, vmin=None,
16771667
if norm is not None:
16781668
polyc.set_norm(norm)
16791669
else:
1670+
color = kwargs.pop('color', None)
1671+
if color is None:
1672+
color = self._get_lines.get_next_color()
1673+
color = np.array(mcolors.to_rgba(color))
1674+
16801675
polyc = art3d.Poly3DCollection(
16811676
polys, facecolors=color, shade=shade,
16821677
lightsource=lightsource, **kwargs)
@@ -2548,31 +2543,33 @@ def quiver(self, X, Y, Z, U, V, W, *,
25482543
:class:`.Line3DCollection`
25492544
"""
25502545

2551-
def calc_arrows(UVW, angle=15):
2546+
def calc_arrows(UVW):
25522547
# get unit direction vector perpendicular to (u, v, w)
25532548
x = UVW[:, 0]
25542549
y = UVW[:, 1]
25552550
norm = np.linalg.norm(UVW[:, :2], axis=1)
25562551
x_p = np.divide(y, norm, where=norm != 0, out=np.zeros_like(x))
25572552
y_p = np.divide(-x, norm, where=norm != 0, out=np.ones_like(x))
25582553
# compute the two arrowhead direction unit vectors
2559-
ra = math.radians(angle)
2560-
c = math.cos(ra)
2561-
s = math.sin(ra)
2554+
rangle = math.radians(15)
2555+
c = math.cos(rangle)
2556+
s = math.sin(rangle)
25622557
# construct the rotation matrices of shape (3, 3, n)
2558+
r13 = y_p * s
2559+
r32 = x_p * s
2560+
r12 = x_p * y_p * (1 - c)
25632561
Rpos = np.array(
2564-
[[c + (x_p ** 2) * (1 - c), x_p * y_p * (1 - c), y_p * s],
2565-
[y_p * x_p * (1 - c), c + (y_p ** 2) * (1 - c), -x_p * s],
2566-
[-y_p * s, x_p * s, np.full_like(x_p, c)]])
2562+
[[c + (x_p ** 2) * (1 - c), r12, r13],
2563+
[r12, c + (y_p ** 2) * (1 - c), -r32],
2564+
[-r13, r32, np.full_like(x_p, c)]])
25672565
# opposite rotation negates all the sin terms
25682566
Rneg = Rpos.copy()
25692567
Rneg[[0, 1, 2, 2], [2, 2, 0, 1]] *= -1
25702568
# Batch n (3, 3) x (3) matrix multiplications ((3, 3, n) x (n, 3)).
25712569
Rpos_vecs = np.einsum("ij...,...j->...i", Rpos, UVW)
25722570
Rneg_vecs = np.einsum("ij...,...j->...i", Rneg, UVW)
25732571
# Stack into (n, 2, 3) result.
2574-
head_dirs = np.stack([Rpos_vecs, Rneg_vecs], axis=1)
2575-
return head_dirs
2572+
return np.stack([Rpos_vecs, Rneg_vecs], axis=1)
25762573

25772574
had_data = self.has_data()
25782575

@@ -2934,7 +2931,7 @@ def errorbar(self, x, y, z, zerr=None, yerr=None, xerr=None, fmt='',
29342931
draws error bars on a subset of the data. *errorevery* =N draws
29352932
error bars on the points (x[::N], y[::N], z[::N]).
29362933
*errorevery* =(start, N) draws error bars on the points
2937-
(x[start::N], y[start::N], z[start::N]). e.g. errorevery=(6, 3)
2934+
(x[start::N], y[start::N], z[start::N]). e.g. *errorevery* =(6, 3)
29382935
adds error bars to the data at (x[6], x[9], x[12], x[15], ...).
29392936
Used to avoid overlapping error bars when two series share x-axis
29402937
values.

0 commit comments

Comments
 (0)