Skip to content

Commit 8840d6e

Browse files
committed
Move axisartist towards using standard Transforms.
axisartist can generate "slanted" or "curved" axes defined using either a custom Transform, or a custom pair of callables (forward, backward) which define the Transform. Instead of repeatedly testing the two cases in `transform_xy` and `inv_transform_xy`, just wrap the pair-of-callables into a custom Transform subclass (which should probably not be moved to the main library, as it uses the transposed convention compared to the usual one), and expose it through a getter. This allows later combining this transform using "standard" transform addition (and get benefits such as single multiplication of matrices in affine transforms, exact cancellation of inverses, or (later) more accurate transforming of arcs in polar transforms), instead of having to manually call successive transforms. See e.g. the two changes in the local transform_xy definitions.
1 parent a9c5224 commit 8840d6e

File tree

3 files changed

+49
-25
lines changed

3 files changed

+49
-25
lines changed

lib/mpl_toolkits/axisartist/floating_axes.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,8 @@ def get_tick_iterators(self, axes):
7474
ymin, ymax = sorted(extremes[2:])
7575

7676
def transform_xy(x, y):
77-
x1, y1 = grid_finder.transform_xy(x, y)
78-
x2, y2 = axes.transData.transform(np.array([x1, y1]).T).T
79-
return x2, y2
77+
trf = grid_finder.get_transform() + axes.transData
78+
return trf.transform(np.column_stack([x, y])).T
8079

8180
if self.nth_coord == 0:
8281
mask = (ymin <= yy0) & (yy0 <= ymax)

lib/mpl_toolkits/axisartist/grid_finder.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,34 @@ def _add_pad(self, x_min, x_max, y_min, y_max):
9292
return x_min - dx, x_max + dx, y_min - dy, y_max + dy
9393

9494

95+
class _User2DTransform(Transform):
96+
"""A transform defined by two user-set functions."""
97+
98+
input_dims = output_dims = 2
99+
100+
def __init__(self, forward, backward):
101+
"""
102+
Parameters
103+
----------
104+
forward, backward : callable
105+
The forward and backward transforms, as taking ``x`` and ``y`` as
106+
separate arguments and returning ``(tr_x, tr_y)``.
107+
"""
108+
# The normal Matplotlib convention would be to take and return an
109+
# (N, 2) array but axisartist uses the transposed version.
110+
super().__init__()
111+
self._forward = forward
112+
self._backward = backward
113+
114+
def transform_non_affine(self, values):
115+
# docstring inherited
116+
return np.transpose(self._forward(*np.transpose(values)))
117+
118+
def inverted(self):
119+
# docstring inherited
120+
return type(self)(self._backward, self._forward)
121+
122+
95123
class GridFinder:
96124
def __init__(self,
97125
transform,
@@ -123,7 +151,7 @@ def __init__(self,
123151
self.grid_locator2 = grid_locator2
124152
self.tick_formatter1 = tick_formatter1
125153
self.tick_formatter2 = tick_formatter2
126-
self.update_transform(transform)
154+
self.set_transform(transform)
127155

128156
def get_grid_info(self, x1, y1, x2, y2):
129157
"""
@@ -214,27 +242,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
214242

215243
return gi
216244

217-
def update_transform(self, aux_trans):
218-
if not isinstance(aux_trans, Transform) and len(aux_trans) != 2:
219-
raise TypeError("'aux_trans' must be either a Transform instance "
220-
"or a pair of callables")
221-
self._aux_transform = aux_trans
245+
def set_transform(self, aux_trans):
246+
if isinstance(aux_trans, Transform):
247+
self._aux_transform = aux_trans
248+
elif len(aux_trans) == 2 and all(map(callable, aux_trans)):
249+
self._aux_transform = _User2DTransform(*aux_trans)
250+
else:
251+
raise TypeError("'aux_trans' must be either a Transform "
252+
"instance or a pair of callables")
253+
254+
def get_transform(self):
255+
return self._aux_transform
256+
257+
update_transform = set_transform # backcompat alias.
222258

223259
def transform_xy(self, x, y):
224-
aux_trf = self._aux_transform
225-
if isinstance(aux_trf, Transform):
226-
return aux_trf.transform(np.column_stack([x, y])).T
227-
else:
228-
transform_xy, inv_transform_xy = aux_trf
229-
return transform_xy(x, y)
260+
return self._aux_transform.transform(np.column_stack([x, y])).T
230261

231262
def inv_transform_xy(self, x, y):
232-
aux_trf = self._aux_transform
233-
if isinstance(aux_trf, Transform):
234-
return aux_trf.inverted().transform(np.column_stack([x, y])).T
235-
else:
236-
transform_xy, inv_transform_xy = aux_trf
237-
return inv_transform_xy(x, y)
263+
return self._aux_transform.inverted().transform(
264+
np.column_stack([x, y])).T
238265

239266
def update(self, **kw):
240267
for k in kw:

lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,10 +202,8 @@ def get_tick_iterators(self, axes):
202202
xx0 = xx0[mask]
203203

204204
def transform_xy(x, y):
205-
x1, y1 = grid_finder.transform_xy(x, y)
206-
x2y2 = axes.transData.transform(np.array([x1, y1]).transpose())
207-
x2, y2 = x2y2.transpose()
208-
return x2, y2
205+
trf = grid_finder.get_transform() + axes.transData
206+
return trf.transform(np.column_stack([x, y])).T
209207

210208
# find angles
211209
if self.nth_coord == 0:

0 commit comments

Comments
 (0)