Skip to content

Commit 401da43

Browse files
committed
Add a helper to generate xy coordinates for AxisArtistHelper.
AxisArtistHelper can generate either x or y ticks/gridlines depending on the value of self.nth_coord. The implementation often requires generating e.g. shape (2,) arrays such that the nth_coord column is set to a tick position, and the 1-nth_coord column has is set to 0. This is currently done using constructs like ``verts = [0, 0]; verts[self.nth_coord] = value`` where the mutation doesn't really help legibility. Instead, introduce a ``_to_xy`` helper that allows writing ``to_xy(variable=x, fixed=0)``.
1 parent e9d1f9c commit 401da43

File tree

2 files changed

+40
-32
lines changed

2 files changed

+40
-32
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
``passthru_pt``
2+
~~~~~~~~~~~~~~~
3+
This attribute of ``AxisArtistHelper``\s is deprecated.

lib/mpl_toolkits/axisartist/axislines.py

+37-32
Original file line numberDiff line numberDiff line change
@@ -95,38 +95,51 @@ def update_lim(self, axes):
9595
delta2 = _api.deprecated("3.6")(
9696
property(lambda self: 0.00001, lambda self, value: None))
9797

98+
def _to_xy(self, values, const):
99+
"""
100+
Create a (len(values), 2)-shape array representing (x, y) pairs.
101+
102+
*values* go into the coordinate determined by ``self.nth_coord``.
103+
The other coordinate is filled with the constant *const*.
104+
105+
Example::
106+
107+
>>> self.nth_coord = 0
108+
>>> self._to_xy([1, 2, 3], const=0)
109+
array([[1, 0],
110+
[2, 0],
111+
[3, 0]])
112+
"""
113+
if self.nth_coord == 0:
114+
return np.stack(np.broadcast_arrays(values, const), axis=-1)
115+
elif self.nth_coord == 1:
116+
return np.stack(np.broadcast_arrays(const, values), axis=-1)
117+
else:
118+
raise ValueError("Unexpected nth_coord")
119+
98120
class Fixed(_Base):
99121
"""Helper class for a fixed (in the axes coordinate) axis."""
100122

123+
# deprecated with passthru_pt
101124
_default_passthru_pt = dict(left=(0, 0),
102125
right=(1, 0),
103126
bottom=(0, 0),
104127
top=(0, 1))
128+
passthru_pt = _api.deprecated("3.7")(property(
129+
lambda self: self._default_passthru_pt[self._loc]))
105130

106131
def __init__(self, loc, nth_coord=None):
107132
"""``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis."""
108133
_api.check_in_list(["left", "right", "bottom", "top"], loc=loc)
109134
self._loc = loc
110-
111-
if nth_coord is None:
112-
if loc in ["left", "right"]:
113-
nth_coord = 1
114-
else: # "bottom", "top"
115-
nth_coord = 0
116-
117-
self.nth_coord = nth_coord
118-
135+
self._pos = _api.check_getitem(
136+
{"bottom": 0, "top": 1, "left": 0, "right": 1}, loc=loc)
137+
self.nth_coord = (
138+
nth_coord if nth_coord is not None else
139+
{"bottom": 0, "top": 0, "left": 1, "right": 1}[loc])
119140
super().__init__()
120-
121-
self.passthru_pt = self._default_passthru_pt[loc]
122-
123-
_verts = np.array([[0., 0.],
124-
[1., 1.]])
125-
fixed_coord = 1 - nth_coord
126-
_verts[:, fixed_coord] = self.passthru_pt[fixed_coord]
127-
128141
# axis line in transAxes
129-
self._path = Path(_verts)
142+
self._path = Path(self._to_xy((0, 1), const=self._pos))
130143

131144
def get_nth_coord(self):
132145
return self.nth_coord
@@ -209,8 +222,7 @@ def get_tick_iterators(self, axes):
209222

210223
def _f(locs, labels):
211224
for x, l in zip(locs, labels):
212-
c = list(self.passthru_pt) # copy
213-
c[self.nth_coord] = x
225+
c = self._to_xy(x, const=self._pos)
214226
# check if the tick point is inside axes
215227
c2 = tick_to_axes.transform(c)
216228
if mpl.transforms._interval_contains_close(
@@ -227,15 +239,10 @@ def __init__(self, axes, nth_coord,
227239
self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
228240

229241
def get_line(self, axes):
230-
_verts = np.array([[0., 0.],
231-
[1., 1.]])
232-
233242
fixed_coord = 1 - self.nth_coord
234243
data_to_axes = axes.transData - axes.transAxes
235244
p = data_to_axes.transform([self._value, self._value])
236-
_verts[:, fixed_coord] = p[fixed_coord]
237-
238-
return Path(_verts)
245+
return Path(self._to_xy((0, 1), const=p[fixed_coord]))
239246

240247
def get_line_transform(self, axes):
241248
return axes.transAxes
@@ -250,13 +257,12 @@ def get_axislabel_pos_angle(self, axes):
250257
get_label_transform() returns a transform of (transAxes+offset)
251258
"""
252259
angle = [0, 90][self.nth_coord]
253-
_verts = [0.5, 0.5]
254260
fixed_coord = 1 - self.nth_coord
255261
data_to_axes = axes.transData - axes.transAxes
256262
p = data_to_axes.transform([self._value, self._value])
257-
_verts[fixed_coord] = p[fixed_coord]
258-
if 0 <= _verts[fixed_coord] <= 1:
259-
return _verts, angle
263+
verts = self._to_xy(0.5, const=p[fixed_coord])
264+
if 0 <= verts[fixed_coord] <= 1:
265+
return verts, angle
260266
else:
261267
return None, None
262268

@@ -282,8 +288,7 @@ def get_tick_iterators(self, axes):
282288

283289
def _f(locs, labels):
284290
for x, l in zip(locs, labels):
285-
c = [self._value, self._value]
286-
c[self.nth_coord] = x
291+
c = self._to_xy(x, const=self._value)
287292
c1, c2 = data_to_axes.transform(c)
288293
if 0 <= c1 <= 1 and 0 <= c2 <= 1:
289294
yield c, angle_normal, angle_tangent, l

0 commit comments

Comments
 (0)