diff --git a/lib/matplotlib/axes/_secondary_axes.py b/lib/matplotlib/axes/_secondary_axes.py index 4c757f61c7b6..7ee3c8b13c07 100644 --- a/lib/matplotlib/axes/_secondary_axes.py +++ b/lib/matplotlib/axes/_secondary_axes.py @@ -6,6 +6,7 @@ import matplotlib.ticker as mticker from matplotlib.axes._base import _AxesBase, _TransformedBoundsLocator from matplotlib.axis import Axis +from matplotlib.transforms import Transform class SecondaryAxis(_AxesBase): @@ -144,11 +145,17 @@ def set_functions(self, functions): If a transform is supplied, then the transform must have an inverse. """ + if (isinstance(functions, tuple) and len(functions) == 2 and callable(functions[0]) and callable(functions[1])): # make an arbitrary convert from a two-tuple of functions # forward and inverse. self._functions = functions + elif isinstance(functions, Transform): + self._functions = ( + functions.transform, + lambda x: functions.inverted().transform(x) + ) elif functions is None: self._functions = (lambda x: x, lambda x: x) else: diff --git a/lib/matplotlib/tests/baseline_images/test_axes/secondary_xy.png b/lib/matplotlib/tests/baseline_images/test_axes/secondary_xy.png index bbf9f9e13211..b69241a06bc6 100644 Binary files a/lib/matplotlib/tests/baseline_images/test_axes/secondary_xy.png and b/lib/matplotlib/tests/baseline_images/test_axes/secondary_xy.png differ diff --git a/lib/matplotlib/tests/test_axes.py b/lib/matplotlib/tests/test_axes.py index 799bf5ddd774..3ae1a9376ceb 100644 --- a/lib/matplotlib/tests/test_axes.py +++ b/lib/matplotlib/tests/test_axes.py @@ -7494,6 +7494,20 @@ def test_annotate_across_transforms(): arrowprops=dict(arrowstyle="->")) +class _Translation(mtransforms.Transform): + input_dims = 1 + output_dims = 1 + + def __init__(self, dx): + self.dx = dx + + def transform(self, values): + return values + self.dx + + def inverted(self): + return _Translation(-self.dx) + + @image_comparison(['secondary_xy.png'], style='mpl20') def test_secondary_xy(): fig, axs = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True) @@ -7513,6 +7527,7 @@ def invert(x): secax(0.4, functions=(lambda x: 2 * x, lambda x: x / 2)) secax(0.6, functions=(lambda x: x**2, lambda x: x**(1/2))) secax(0.8) + secax("top" if nn == 0 else "right", functions=_Translation(2)) def test_secondary_fail():