diff --git a/lib/mpl_toolkits/axisartist/floating_axes.py b/lib/mpl_toolkits/axisartist/floating_axes.py index e1958939c7d5..761b26d4c01c 100644 --- a/lib/mpl_toolkits/axisartist/floating_axes.py +++ b/lib/mpl_toolkits/axisartist/floating_axes.py @@ -74,9 +74,8 @@ def get_tick_iterators(self, axes): ymin, ymax = sorted(extremes[2:]) def transform_xy(x, y): - x1, y1 = grid_finder.transform_xy(x, y) - x2, y2 = axes.transData.transform(np.array([x1, y1]).T).T - return x2, y2 + trf = grid_finder.get_transform() + axes.transData + return trf.transform(np.column_stack([x, y])).T if self.nth_coord == 0: mask = (ymin <= yy0) & (yy0 <= ymax) diff --git a/lib/mpl_toolkits/axisartist/grid_finder.py b/lib/mpl_toolkits/axisartist/grid_finder.py index 3ddc303093a3..22a139e40e85 100644 --- a/lib/mpl_toolkits/axisartist/grid_finder.py +++ b/lib/mpl_toolkits/axisartist/grid_finder.py @@ -92,6 +92,34 @@ def _add_pad(self, x_min, x_max, y_min, y_max): return x_min - dx, x_max + dx, y_min - dy, y_max + dy +class _User2DTransform(Transform): + """A transform defined by two user-set functions.""" + + input_dims = output_dims = 2 + + def __init__(self, forward, backward): + """ + Parameters + ---------- + forward, backward : callable + The forward and backward transforms, taking ``x`` and ``y`` as + separate arguments and returning ``(tr_x, tr_y)``. + """ + # The normal Matplotlib convention would be to take and return an + # (N, 2) array but axisartist uses the transposed version. + super().__init__() + self._forward = forward + self._backward = backward + + def transform_non_affine(self, values): + # docstring inherited + return np.transpose(self._forward(*np.transpose(values))) + + def inverted(self): + # docstring inherited + return type(self)(self._backward, self._forward) + + class GridFinder: def __init__(self, transform, @@ -123,7 +151,7 @@ def __init__(self, self.grid_locator2 = grid_locator2 self.tick_formatter1 = tick_formatter1 self.tick_formatter2 = tick_formatter2 - self.update_transform(transform) + self.set_transform(transform) def get_grid_info(self, x1, y1, x2, y2): """ @@ -214,27 +242,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb): return gi - def update_transform(self, aux_trans): - if not isinstance(aux_trans, Transform) and len(aux_trans) != 2: - raise TypeError("'aux_trans' must be either a Transform instance " - "or a pair of callables") - self._aux_transform = aux_trans + def set_transform(self, aux_trans): + if isinstance(aux_trans, Transform): + self._aux_transform = aux_trans + elif len(aux_trans) == 2 and all(map(callable, aux_trans)): + self._aux_transform = _User2DTransform(*aux_trans) + else: + raise TypeError("'aux_trans' must be either a Transform " + "instance or a pair of callables") + + def get_transform(self): + return self._aux_transform + + update_transform = set_transform # backcompat alias. def transform_xy(self, x, y): - aux_trf = self._aux_transform - if isinstance(aux_trf, Transform): - return aux_trf.transform(np.column_stack([x, y])).T - else: - transform_xy, inv_transform_xy = aux_trf - return transform_xy(x, y) + return self._aux_transform.transform(np.column_stack([x, y])).T def inv_transform_xy(self, x, y): - aux_trf = self._aux_transform - if isinstance(aux_trf, Transform): - return aux_trf.inverted().transform(np.column_stack([x, y])).T - else: - transform_xy, inv_transform_xy = aux_trf - return inv_transform_xy(x, y) + return self._aux_transform.inverted().transform( + np.column_stack([x, y])).T def update(self, **kw): for k in kw: diff --git a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py index bb67dd0d78c8..6598c9a84d49 100644 --- a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py +++ b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py @@ -202,10 +202,8 @@ def get_tick_iterators(self, axes): xx0 = xx0[mask] def transform_xy(x, y): - x1, y1 = grid_finder.transform_xy(x, y) - x2y2 = axes.transData.transform(np.array([x1, y1]).transpose()) - x2, y2 = x2y2.transpose() - return x2, y2 + trf = grid_finder.get_transform() + axes.transData + return trf.transform(np.column_stack([x, y])).T # find angles if self.nth_coord == 0: