diff --git a/doc/api/next_api_changes/deprecations/22314-AL.rst b/doc/api/next_api_changes/deprecations/22314-AL.rst new file mode 100644 index 000000000000..bea929019865 --- /dev/null +++ b/doc/api/next_api_changes/deprecations/22314-AL.rst @@ -0,0 +1,3 @@ +``passthru_pt`` +~~~~~~~~~~~~~~~ +This attribute of ``AxisArtistHelper``\s is deprecated. diff --git a/lib/mpl_toolkits/axisartist/axislines.py b/lib/mpl_toolkits/axisartist/axislines.py index c52fb347abd2..081fdadb0835 100644 --- a/lib/mpl_toolkits/axisartist/axislines.py +++ b/lib/mpl_toolkits/axisartist/axislines.py @@ -95,38 +95,46 @@ def update_lim(self, axes): delta2 = _api.deprecated("3.6")( property(lambda self: 0.00001, lambda self, value: None)) + def _to_xy(self, values, const): + """ + Create a (*values.shape, 2)-shape array representing (x, y) pairs. + + *values* go into the coordinate determined by ``self.nth_coord``. + The other coordinate is filled with the constant *const*. + + Example:: + + >>> self.nth_coord = 0 + >>> self._to_xy([1, 2, 3], const=0) + array([[1, 0], + [2, 0], + [3, 0]]) + """ + if self.nth_coord == 0: + return np.stack(np.broadcast_arrays(values, const), axis=-1) + elif self.nth_coord == 1: + return np.stack(np.broadcast_arrays(const, values), axis=-1) + else: + raise ValueError("Unexpected nth_coord") + class Fixed(_Base): """Helper class for a fixed (in the axes coordinate) axis.""" - _default_passthru_pt = dict(left=(0, 0), - right=(1, 0), - bottom=(0, 0), - top=(0, 1)) + passthru_pt = _api.deprecated("3.7")(property( + lambda self: {"left": (0, 0), "right": (1, 0), + "bottom": (0, 0), "top": (0, 1)}[self._loc])) def __init__(self, loc, nth_coord=None): """``nth_coord = 0``: x-axis; ``nth_coord = 1``: y-axis.""" _api.check_in_list(["left", "right", "bottom", "top"], loc=loc) self._loc = loc - - if nth_coord is None: - if loc in ["left", "right"]: - nth_coord = 1 - else: # "bottom", "top" - nth_coord = 0 - - self.nth_coord = nth_coord - + self._pos = {"bottom": 0, "top": 1, "left": 0, "right": 1}[loc] + self.nth_coord = ( + nth_coord if nth_coord is not None else + {"bottom": 0, "top": 0, "left": 1, "right": 1}[loc]) super().__init__() - - self.passthru_pt = self._default_passthru_pt[loc] - - _verts = np.array([[0., 0.], - [1., 1.]]) - fixed_coord = 1 - nth_coord - _verts[:, fixed_coord] = self.passthru_pt[fixed_coord] - # axis line in transAxes - self._path = Path(_verts) + self._path = Path(self._to_xy((0, 1), const=self._pos)) def get_nth_coord(self): return self.nth_coord @@ -208,14 +216,13 @@ def get_tick_iterators(self, axes): tick_to_axes = self.get_tick_transform(axes) - axes.transAxes def _f(locs, labels): - for x, l in zip(locs, labels): - c = list(self.passthru_pt) # copy - c[self.nth_coord] = x + for loc, label in zip(locs, labels): + c = self._to_xy(loc, const=self._pos) # check if the tick point is inside axes c2 = tick_to_axes.transform(c) if mpl.transforms._interval_contains_close( (0, 1), c2[self.nth_coord]): - yield c, angle_normal, angle_tangent, l + yield c, angle_normal, angle_tangent, label return _f(major_locs, major_labels), _f(minor_locs, minor_labels) @@ -227,15 +234,10 @@ def __init__(self, axes, nth_coord, self.axis = [axes.xaxis, axes.yaxis][self.nth_coord] def get_line(self, axes): - _verts = np.array([[0., 0.], - [1., 1.]]) - fixed_coord = 1 - self.nth_coord data_to_axes = axes.transData - axes.transAxes p = data_to_axes.transform([self._value, self._value]) - _verts[:, fixed_coord] = p[fixed_coord] - - return Path(_verts) + return Path(self._to_xy((0, 1), const=p[fixed_coord])) def get_line_transform(self, axes): return axes.transAxes @@ -250,13 +252,12 @@ def get_axislabel_pos_angle(self, axes): get_label_transform() returns a transform of (transAxes+offset) """ angle = [0, 90][self.nth_coord] - _verts = [0.5, 0.5] fixed_coord = 1 - self.nth_coord data_to_axes = axes.transData - axes.transAxes p = data_to_axes.transform([self._value, self._value]) - _verts[fixed_coord] = p[fixed_coord] - if 0 <= _verts[fixed_coord] <= 1: - return _verts, angle + verts = self._to_xy(0.5, const=p[fixed_coord]) + if 0 <= verts[fixed_coord] <= 1: + return verts, angle else: return None, None @@ -281,12 +282,11 @@ def get_tick_iterators(self, axes): data_to_axes = axes.transData - axes.transAxes def _f(locs, labels): - for x, l in zip(locs, labels): - c = [self._value, self._value] - c[self.nth_coord] = x + for loc, label in zip(locs, labels): + c = self._to_xy(loc, const=self._value) c1, c2 = data_to_axes.transform(c) if 0 <= c1 <= 1 and 0 <= c2 <= 1: - yield c, angle_normal, angle_tangent, l + yield c, angle_normal, angle_tangent, label return _f(major_locs, major_labels), _f(minor_locs, minor_labels)