From 8840d6e59cfdc1836a702e94d8ccda6ee342475b Mon Sep 17 00:00:00 2001 From: Antony Lee Date: Tue, 11 May 2021 22:38:15 +0200 Subject: [PATCH 1/2] Move axisartist towards using standard Transforms. axisartist can generate "slanted" or "curved" axes defined using either a custom Transform, or a custom pair of callables (forward, backward) which define the Transform. Instead of repeatedly testing the two cases in `transform_xy` and `inv_transform_xy`, just wrap the pair-of-callables into a custom Transform subclass (which should probably not be moved to the main library, as it uses the transposed convention compared to the usual one), and expose it through a getter. This allows later combining this transform using "standard" transform addition (and get benefits such as single multiplication of matrices in affine transforms, exact cancellation of inverses, or (later) more accurate transforming of arcs in polar transforms), instead of having to manually call successive transforms. See e.g. the two changes in the local transform_xy definitions. --- lib/mpl_toolkits/axisartist/floating_axes.py | 5 +- lib/mpl_toolkits/axisartist/grid_finder.py | 63 +++++++++++++------ .../axisartist/grid_helper_curvelinear.py | 6 +- 3 files changed, 49 insertions(+), 25 deletions(-) diff --git a/lib/mpl_toolkits/axisartist/floating_axes.py b/lib/mpl_toolkits/axisartist/floating_axes.py index e1958939c7d5..761b26d4c01c 100644 --- a/lib/mpl_toolkits/axisartist/floating_axes.py +++ b/lib/mpl_toolkits/axisartist/floating_axes.py @@ -74,9 +74,8 @@ def get_tick_iterators(self, axes): ymin, ymax = sorted(extremes[2:]) def transform_xy(x, y): - x1, y1 = grid_finder.transform_xy(x, y) - x2, y2 = axes.transData.transform(np.array([x1, y1]).T).T - return x2, y2 + trf = grid_finder.get_transform() + axes.transData + return trf.transform(np.column_stack([x, y])).T if self.nth_coord == 0: mask = (ymin <= yy0) & (yy0 <= ymax) diff --git a/lib/mpl_toolkits/axisartist/grid_finder.py b/lib/mpl_toolkits/axisartist/grid_finder.py index 3ddc303093a3..207bf396df6d 100644 --- a/lib/mpl_toolkits/axisartist/grid_finder.py +++ b/lib/mpl_toolkits/axisartist/grid_finder.py @@ -92,6 +92,34 @@ def _add_pad(self, x_min, x_max, y_min, y_max): return x_min - dx, x_max + dx, y_min - dy, y_max + dy +class _User2DTransform(Transform): + """A transform defined by two user-set functions.""" + + input_dims = output_dims = 2 + + def __init__(self, forward, backward): + """ + Parameters + ---------- + forward, backward : callable + The forward and backward transforms, as taking ``x`` and ``y`` as + separate arguments and returning ``(tr_x, tr_y)``. + """ + # The normal Matplotlib convention would be to take and return an + # (N, 2) array but axisartist uses the transposed version. + super().__init__() + self._forward = forward + self._backward = backward + + def transform_non_affine(self, values): + # docstring inherited + return np.transpose(self._forward(*np.transpose(values))) + + def inverted(self): + # docstring inherited + return type(self)(self._backward, self._forward) + + class GridFinder: def __init__(self, transform, @@ -123,7 +151,7 @@ def __init__(self, self.grid_locator2 = grid_locator2 self.tick_formatter1 = tick_formatter1 self.tick_formatter2 = tick_formatter2 - self.update_transform(transform) + self.set_transform(transform) def get_grid_info(self, x1, y1, x2, y2): """ @@ -214,27 +242,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb): return gi - def update_transform(self, aux_trans): - if not isinstance(aux_trans, Transform) and len(aux_trans) != 2: - raise TypeError("'aux_trans' must be either a Transform instance " - "or a pair of callables") - self._aux_transform = aux_trans + def set_transform(self, aux_trans): + if isinstance(aux_trans, Transform): + self._aux_transform = aux_trans + elif len(aux_trans) == 2 and all(map(callable, aux_trans)): + self._aux_transform = _User2DTransform(*aux_trans) + else: + raise TypeError("'aux_trans' must be either a Transform " + "instance or a pair of callables") + + def get_transform(self): + return self._aux_transform + + update_transform = set_transform # backcompat alias. def transform_xy(self, x, y): - aux_trf = self._aux_transform - if isinstance(aux_trf, Transform): - return aux_trf.transform(np.column_stack([x, y])).T - else: - transform_xy, inv_transform_xy = aux_trf - return transform_xy(x, y) + return self._aux_transform.transform(np.column_stack([x, y])).T def inv_transform_xy(self, x, y): - aux_trf = self._aux_transform - if isinstance(aux_trf, Transform): - return aux_trf.inverted().transform(np.column_stack([x, y])).T - else: - transform_xy, inv_transform_xy = aux_trf - return inv_transform_xy(x, y) + return self._aux_transform.inverted().transform( + np.column_stack([x, y])).T def update(self, **kw): for k in kw: diff --git a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py index bb67dd0d78c8..6598c9a84d49 100644 --- a/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py +++ b/lib/mpl_toolkits/axisartist/grid_helper_curvelinear.py @@ -202,10 +202,8 @@ def get_tick_iterators(self, axes): xx0 = xx0[mask] def transform_xy(x, y): - x1, y1 = grid_finder.transform_xy(x, y) - x2y2 = axes.transData.transform(np.array([x1, y1]).transpose()) - x2, y2 = x2y2.transpose() - return x2, y2 + trf = grid_finder.get_transform() + axes.transData + return trf.transform(np.column_stack([x, y])).T # find angles if self.nth_coord == 0: From 7ec424af68e6b5e8b4bfd81d0715b93e0d5b79bb Mon Sep 17 00:00:00 2001 From: Jody Klymak Date: Thu, 13 May 2021 07:01:12 -0700 Subject: [PATCH 2/2] Update lib/mpl_toolkits/axisartist/grid_finder.py Co-authored-by: Elliott Sales de Andrade --- lib/mpl_toolkits/axisartist/grid_finder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/mpl_toolkits/axisartist/grid_finder.py b/lib/mpl_toolkits/axisartist/grid_finder.py index 207bf396df6d..22a139e40e85 100644 --- a/lib/mpl_toolkits/axisartist/grid_finder.py +++ b/lib/mpl_toolkits/axisartist/grid_finder.py @@ -102,7 +102,7 @@ def __init__(self, forward, backward): Parameters ---------- forward, backward : callable - The forward and backward transforms, as taking ``x`` and ``y`` as + The forward and backward transforms, taking ``x`` and ``y`` as separate arguments and returning ``(tr_x, tr_y)``. """ # The normal Matplotlib convention would be to take and return an