Skip to content

Commit 9c9e0a0

Browse files
authored
Merge pull request #20207 from anntzer/aat
Move axisartist towards using standard Transforms.
2 parents 50c6fa9 + 7ec424a commit 9c9e0a0

File tree

3 files changed

+49
-25
lines changed

3 files changed

+49
-25
lines changed

lib/mpl_toolkits/axisartist/floating_axes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,8 @@ def get_tick_iterators(self, axes):
7474
ymin, ymax = sorted(extremes[2:])
7575

7676
def transform_xy(x, y):
77-
x1, y1 = grid_finder.transform_xy(x, y)
78-
x2, y2 = axes.transData.transform(np.array([x1, y1]).T).T
79-
return x2, y2
77+
trf = grid_finder.get_transform() + axes.transData
78+
return trf.transform(np.column_stack([x, y])).T
8079

8180
if self.nth_coord == 0:
8281
mask = (ymin <= yy0) & (yy0 <= ymax)

lib/mpl_toolkits/axisartist/grid_finder.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,34 @@ def _add_pad(self, x_min, x_max, y_min, y_max):
9292
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
9393

9494

95+
class _User2DTransform(Transform):
96+
"""A transform defined by two user-set functions."""
97+
98+
input_dims = output_dims = 2
99+
100+
def __init__(self, forward, backward):
101+
"""
102+
Parameters
103+
----------
104+
forward, backward : callable
105+
The forward and backward transforms, taking ``x`` and ``y`` as
106+
separate arguments and returning ``(tr_x, tr_y)``.
107+
"""
108+
# The normal Matplotlib convention would be to take and return an
109+
# (N, 2) array but axisartist uses the transposed version.
110+
super().__init__()
111+
self._forward = forward
112+
self._backward = backward
113+
114+
def transform_non_affine(self, values):
115+
# docstring inherited
116+
return np.transpose(self._forward(*np.transpose(values)))
117+
118+
def inverted(self):
119+
# docstring inherited
120+
return type(self)(self._backward, self._forward)
121+
122+
95123
class GridFinder:
96124
def __init__(self,
97125
transform,
@@ -123,7 +151,7 @@ def __init__(self,
123151
self.grid_locator2 = grid_locator2
124152
self.tick_formatter1 = tick_formatter1
125153
self.tick_formatter2 = tick_formatter2
126-
self.update_transform(transform)
154+
self.set_transform(transform)
127155

128156
def get_grid_info(self, x1, y1, x2, y2):
129157
"""
@@ -214,27 +242,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
214242

215243
return gi
216244

217-
def update_transform(self, aux_trans):
218-
if not isinstance(aux_trans, Transform) and len(aux_trans) != 2:
219-
raise TypeError("'aux_trans' must be either a Transform instance "
220-
"or a pair of callables")
221-
self._aux_transform = aux_trans
245+
def set_transform(self, aux_trans):
246+
if isinstance(aux_trans, Transform):
247+
self._aux_transform = aux_trans
248+
elif len(aux_trans) == 2 and all(map(callable, aux_trans)):
249+
self._aux_transform = _User2DTransform(*aux_trans)
250+
else:
251+
raise TypeError("'aux_trans' must be either a Transform "
252+
"instance or a pair of callables")
253+
254+
def get_transform(self):
255+
return self._aux_transform
256+
257+
update_transform = set_transform # backcompat alias.
222258

223259
def transform_xy(self, x, y):
224-
aux_trf = self._aux_transform
225-
if isinstance(aux_trf, Transform):
226-
return aux_trf.transform(np.column_stack([x, y])).T
227-
else:
228-
transform_xy, inv_transform_xy = aux_trf
229-
return transform_xy(x, y)
260+
return self._aux_transform.transform(np.column_stack([x, y])).T
230261

231262
def inv_transform_xy(self, x, y):
232-
aux_trf = self._aux_transform
233-
if isinstance(aux_trf, Transform):
234-
return aux_trf.inverted().transform(np.column_stack([x, y])).T
235-
else:
236-
transform_xy, inv_transform_xy = aux_trf
237-
return inv_transform_xy(x, y)
263+
return self._aux_transform.inverted().transform(
264+
np.column_stack([x, y])).T
238265

239266
def update(self, **kw):
240267
for k in kw:

lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,10 +201,8 @@ def get_tick_iterators(self, axes):
201201
xx0 = xx0[mask]
202202

203203
def transform_xy(x, y):
204-
x1, y1 = grid_finder.transform_xy(x, y)
205-
x2y2 = axes.transData.transform(np.array([x1, y1]).transpose())
206-
x2, y2 = x2y2.transpose()
207-
return x2, y2
204+
trf = grid_finder.get_transform() + axes.transData
205+
return trf.transform(np.column_stack([x, y])).T
208206

209207
# find angles
210208
if self.nth_coord == 0:

0 commit comments

Comments
 (0)