Skip to content

Commit aa7bded

Browse files
committed
Add a helper to generate xy coordinates for AxisArtistHelper.
AxisArtistHelper can generate either x or y ticks/gridlines depending on the value of self.nth_coord. The implementation often requires generating e.g. shape (2,) arrays such that the nth_coord column is set to a tick position, and the 1-nth_coord column has is set to 0. This is currently done using constructs like ``verts = [0, 0]; verts[self.nth_coord] = value`` where the mutation doesn't really help legibility. Instead, introduce a ``_to_xy`` helper that allows writing ``to_xy(variable=x, fixed=0)``.
1 parent a6da11e commit aa7bded

File tree

2 files changed

+30
-33
lines changed

2 files changed

+30
-33
lines changed
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
``passthru_pt``
2+
~~~~~~~~~~~~~~~
3+
This attribute of ``AxisArtistHelper``\s is deprecated.

lib/mpl_toolkits/axisartist/axislines.py

Lines changed: 27 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -108,41 +108,43 @@ def update_lim(self, axes):
108108
delta2 = _api.deprecated("3.6")(
109109
property(lambda self: 0.00001, lambda self, value: None))
110110

111+
def _to_xy(self, *, variable, fixed):
112+
"""
113+
Stack the *variable* and *fixed* array-likes along the last axis so
114+
that *variable* is at the ``self.nth_coord`` position.
115+
"""
116+
if self.nth_coord == 0:
117+
return np.stack(np.broadcast_arrays(variable, fixed), axis=-1)
118+
elif self.nth_coord == 1:
119+
return np.stack(np.broadcast_arrays(fixed, variable), axis=-1)
120+
else:
121+
raise ValueError("Unxpected nth_coord")
122+
111123
class Fixed(_Base):
112124
"""Helper class for a fixed (in the axes coordinate) axis."""
113125

126+
# deprecated with passthru_pt
114127
_default_passthru_pt = dict(left=(0, 0),
115128
right=(1, 0),
116129
bottom=(0, 0),
117130
top=(0, 1))
131+
passthru_pt = _api.deprecated("3.6")(property(
132+
lambda self: self._default_passthru_pt[self._loc]))
118133

119134
def __init__(self, loc, nth_coord=None):
120135
"""
121136
nth_coord = along which coordinate value varies
122137
in 2D, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
123138
"""
124-
_api.check_in_list(["left", "right", "bottom", "top"], loc=loc)
125139
self._loc = loc
126-
127-
if nth_coord is None:
128-
if loc in ["left", "right"]:
129-
nth_coord = 1
130-
elif loc in ["bottom", "top"]:
131-
nth_coord = 0
132-
133-
self.nth_coord = nth_coord
134-
140+
self._pos = _api.check_getitem(
141+
{"bottom": 0, "top": 1, "left": 0, "right": 1}, loc=loc)
142+
self.nth_coord = (
143+
nth_coord if nth_coord is not None else
144+
{"bottom": 0, "top": 0, "left": 1, "right": 1}[loc])
135145
super().__init__()
136-
137-
self.passthru_pt = self._default_passthru_pt[loc]
138-
139-
_verts = np.array([[0., 0.],
140-
[1., 1.]])
141-
fixed_coord = 1 - nth_coord
142-
_verts[:, fixed_coord] = self.passthru_pt[fixed_coord]
143-
144146
# axis line in transAxes
145-
self._path = Path(_verts)
147+
self._path = Path(self._to_xy(variable=(0, 1), fixed=self._pos))
146148

147149
def get_nth_coord(self):
148150
return self.nth_coord
@@ -225,8 +227,7 @@ def get_tick_iterators(self, axes):
225227

226228
def _f(locs, labels):
227229
for x, l in zip(locs, labels):
228-
c = list(self.passthru_pt) # copy
229-
c[self.nth_coord] = x
230+
c = self._to_xy(variable=x, fixed=self._pos)
230231
# check if the tick point is inside axes
231232
c2 = tick_to_axes.transform(c)
232233
if mpl.transforms._interval_contains_close(
@@ -243,15 +244,10 @@ def __init__(self, axes, nth_coord,
243244
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
244245

245246
def get_line(self, axes):
246-
_verts = np.array([[0., 0.],
247-
[1., 1.]])
248-
249247
fixed_coord = 1 - self.nth_coord
250248
data_to_axes = axes.transData - axes.transAxes
251249
p = data_to_axes.transform([self._value, self._value])
252-
_verts[:, fixed_coord] = p[fixed_coord]
253-
254-
return Path(_verts)
250+
return Path(self._to_xy(variable=(0, 1), fixed=p[fixed_coord]))
255251

256252
def get_line_transform(self, axes):
257253
return axes.transAxes
@@ -266,13 +262,12 @@ def get_axislabel_pos_angle(self, axes):
266262
get_label_transform() returns a transform of (transAxes+offset)
267263
"""
268264
angle = [0, 90][self.nth_coord]
269-
_verts = [0.5, 0.5]
270265
fixed_coord = 1 - self.nth_coord
271266
data_to_axes = axes.transData - axes.transAxes
272267
p = data_to_axes.transform([self._value, self._value])
273-
_verts[fixed_coord] = p[fixed_coord]
274-
if 0 <= _verts[fixed_coord] <= 1:
275-
return _verts, angle
268+
verts = self._to_xy(variable=0.5, fixed=p[fixed_coord])
269+
if 0 <= verts[fixed_coord] <= 1:
270+
return verts, angle
276271
else:
277272
return None, None
278273

@@ -298,8 +293,7 @@ def get_tick_iterators(self, axes):
298293

299294
def _f(locs, labels):
300295
for x, l in zip(locs, labels):
301-
c = [self._value, self._value]
302-
c[self.nth_coord] = x
296+
c = self._to_xy(variable=x, fixed=self._value)
303297
c1, c2 = data_to_axes.transform(c)
304298
if 0 <= c1 <= 1 and 0 <= c2 <= 1:
305299
yield c, angle_normal, angle_tangent, l

0 commit comments

Comments
 (0)