Skip to content

New "accepts units" decorator #10411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 28 additions & 80 deletions lib/matplotlib/axes/_axes.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import matplotlib.ticker as mticker
import matplotlib.transforms as mtransforms
import matplotlib.tri as mtri
import matplotlib.units as munits
from matplotlib.container import BarContainer, ErrorbarContainer, StemContainer
from matplotlib.axes._base import _AxesBase, _process_plot_format

Expand Down Expand Up @@ -636,6 +637,7 @@ def indicate_inset_zoom(self, inset_ax, **kwargs):

return rectpatch, connects

@munits._accepts_units(convert_x=['x'], convert_y=['y'])
def text(self, x, y, s, fontdict=None, withdash=False, **kwargs):
"""
Add text to the axes.
Expand Down Expand Up @@ -731,6 +733,7 @@ def annotate(self, text, xy, *args, **kwargs):
annotate.__doc__ = mtext.Annotation.__init__.__doc__
#### Lines and spans

@munits._accepts_units(convert_y=['y'])
@docstring.dedent_interpd
def axhline(self, y=0, xmin=0, xmax=1, **kwargs):
"""
Expand Down Expand Up @@ -786,21 +789,17 @@ def axhline(self, y=0, xmin=0, xmax=1, **kwargs):
if "transform" in kwargs:
raise ValueError(
"'transform' is not allowed as a kwarg;"
+ "axhline generates its own transform.")
"axhline generates its own transform.")
ymin, ymax = self.get_ybound()

# We need to strip away the units for comparison with
# non-unitized bounds
self._process_unit_info(ydata=y, kwargs=kwargs)
yy = self.convert_yunits(y)
scaley = (yy < ymin) or (yy > ymax)
scaley = (y < ymin) or (y > ymax)

trans = self.get_yaxis_transform(which='grid')
l = mlines.Line2D([xmin, xmax], [y, y], transform=trans, **kwargs)
self.add_line(l)
self.autoscale_view(scalex=False, scaley=scaley)
return l

@munits._accepts_units(convert_x=['x'])
@docstring.dedent_interpd
def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
"""
Expand Down Expand Up @@ -855,21 +854,18 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
if "transform" in kwargs:
raise ValueError(
"'transform' is not allowed as a kwarg;"
+ "axvline generates its own transform.")
"axvline generates its own transform.")
xmin, xmax = self.get_xbound()

# We need to strip away the units for comparison with
# non-unitized bounds
self._process_unit_info(xdata=x, kwargs=kwargs)
xx = self.convert_xunits(x)
scalex = (xx < xmin) or (xx > xmax)
scalex = (x < xmin) or (x > xmax)

trans = self.get_xaxis_transform(which='grid')
l = mlines.Line2D([x, x], [ymin, ymax], transform=trans, **kwargs)
self.add_line(l)
self.autoscale_view(scalex=scalex, scaley=False)
return l

@munits._accepts_units(convert_x=['xmin', 'xmax'],
convert_y=['ymin', 'ymax'])
@docstring.dedent_interpd
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
"""
Expand Down Expand Up @@ -911,21 +907,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
axvspan : Add a vertical span across the axes.
"""
trans = self.get_yaxis_transform(which='grid')

# process the unit information
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)

# first we need to strip away the units
xmin, xmax = self.convert_xunits([xmin, xmax])
ymin, ymax = self.convert_yunits([ymin, ymax])

verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
p = mpatches.Polygon(verts, **kwargs)
p.set_transform(trans)
self.add_patch(p)
self.autoscale_view(scalex=False)
return p

@munits._accepts_units(convert_x=['xmin', 'xmax'],
convert_y=['ymin', 'ymax'])
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
"""
Add a vertical span (rectangle) across the axes.
Expand Down Expand Up @@ -976,21 +966,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):

"""
trans = self.get_xaxis_transform(which='grid')

# process the unit information
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)

# first we need to strip away the units
xmin, xmax = self.convert_xunits([xmin, xmax])
ymin, ymax = self.convert_yunits([ymin, ymax])

verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
p = mpatches.Polygon(verts, **kwargs)
p.set_transform(trans)
self.add_patch(p)
self.autoscale_view(scaley=False)
return p

@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
label_namer="y")
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
Expand Down Expand Up @@ -1026,14 +1009,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
vlines : vertical lines
axhline: horizontal line across the axes
"""

# We do the conversion first since not all unitized data is uniform
# process the unit information
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
y = self.convert_yunits(y)
xmin = self.convert_xunits(xmin)
xmax = self.convert_xunits(xmax)

if not np.iterable(y):
y = [y]
if not np.iterable(xmin):
Expand Down Expand Up @@ -1067,6 +1042,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',

return lines

@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
label_namer="x")
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
Expand Down Expand Up @@ -1104,14 +1080,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
hlines : horizontal lines
axvline: vertical line across the axes
"""

self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)

# We do the conversion first since not all unitized data is uniform
x = self.convert_xunits(x)
ymin = self.convert_yunits(ymin)
ymax = self.convert_yunits(ymax)

if not np.iterable(x):
x = [x]
if not np.iterable(ymin):
Expand Down Expand Up @@ -1144,6 +1112,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',

return lines

@munits._accepts_units(convert_x=['positions'],
convert_y=['lineoffsets', 'linelengths'])
@_preprocess_data(replace_names=["positions", "lineoffsets",
"linelengths", "linewidths",
"colors", "linestyles"],
Expand Down Expand Up @@ -1233,15 +1203,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,

.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
"""
self._process_unit_info(xdata=positions,
ydata=[lineoffsets, linelengths],
kwargs=kwargs)

# We do the conversion first since not all unitized data is uniform
positions = self.convert_xunits(positions)
lineoffsets = self.convert_yunits(lineoffsets)
linelengths = self.convert_yunits(linelengths)

if not np.iterable(positions):
positions = [positions]
elif any(np.iterable(position) for position in positions):
Expand Down Expand Up @@ -1984,6 +1945,7 @@ def xcorr(self, x, y, normed=True, detrend=mlab.detrend_none,

#### Specialized plotting

@munits._accepts_units(convert_x=['x'], convert_y=['y'])
@_preprocess_data(replace_names=["x", "y"], label_namer="y")
def step(self, x, y, *args, where='pre', **kwargs):
"""
Expand Down Expand Up @@ -2453,6 +2415,7 @@ def barh(self, y, width, height=0.8, left=None, *, align="center",
align=align, **kwargs)
return patches

@munits._accepts_units(convert_x=['xranges'], convert_y=['yrange'])
@_preprocess_data(label_namer=None)
@docstring.dedent_interpd
def broken_barh(self, xranges, yrange, **kwargs):
Expand Down Expand Up @@ -2512,11 +2475,6 @@ def broken_barh(self, xranges, yrange, **kwargs):
ydata = cbook.safe_first_element(yrange)
else:
ydata = None
self._process_unit_info(xdata=xdata,
ydata=ydata,
kwargs=kwargs)
xranges = self.convert_xunits(xranges)
yrange = self.convert_yunits(yrange)

col = mcoll.BrokenBarHCollection(xranges, yrange, **kwargs)
self.add_collection(col, autolim=True)
Expand Down Expand Up @@ -4006,6 +3964,7 @@ def dopatch(xs, ys, **kwargs):
return dict(whiskers=whiskers, caps=caps, boxes=boxes,
medians=medians, fliers=fliers, means=means)

@munits._accepts_units(convert_x=['x'], convert_y=['y'])
@_preprocess_data(replace_names=["x", "y", "s", "linewidths",
"edgecolors", "c", "facecolor",
"facecolors", "color"],
Expand Down Expand Up @@ -4149,10 +4108,6 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,
if edgecolors is None and not rcParams['_internal.classic_mode']:
edgecolors = 'face'

self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)
x = self.convert_xunits(x)
y = self.convert_yunits(y)

# np.ma.ravel yields an ndarray, not a masked array,
# unless its argument is a masked array.
xy_shape = (np.shape(x), np.shape(y))
Expand Down Expand Up @@ -4303,6 +4258,7 @@ def scatter(self, x, y, s=None, c=None, marker=None, cmap=None, norm=None,

return collection

@munits._accepts_units(convert_x=['x'], convert_y=['y'])
@_preprocess_data(replace_names=["x", "y"], label_namer="y")
@docstring.dedent_interpd
def hexbin(self, x, y, C=None, gridsize=100, bins=None,
Expand Down Expand Up @@ -4431,8 +4387,6 @@ def hexbin(self, x, y, C=None, gridsize=100, bins=None,
%(Collection)s

"""
self._process_unit_info(xdata=x, ydata=y, kwargs=kwargs)

x, y, C = cbook.delete_masked_points(x, y, C)

# Set the size of the hexagon grid
Expand Down Expand Up @@ -4921,6 +4875,7 @@ def fill(self, *args, **kwargs):
self.autoscale_view()
return patches

@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
label_namer=None)
@docstring.dedent_interpd
Expand Down Expand Up @@ -5014,14 +4969,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
kwargs['facecolor'] = \
self._get_patches_for_fill.get_next_color()

# Handle united data, such as dates
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
self._process_unit_info(ydata=y2)

# Convert the arrays so we can work with them
x = ma.masked_invalid(self.convert_xunits(x))
y1 = ma.masked_invalid(self.convert_yunits(y1))
y2 = ma.masked_invalid(self.convert_yunits(y2))
x = ma.masked_invalid(x)
y1 = ma.masked_invalid(y1)
y2 = ma.masked_invalid(y2)

for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
if array.ndim > 1:
Expand Down Expand Up @@ -5104,6 +5055,7 @@ def get_interp_point(ind):
self.autoscale_view()
return collection

@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
label_namer=None)
@docstring.dedent_interpd
Expand Down Expand Up @@ -5197,14 +5149,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
kwargs['facecolor'] = \
self._get_patches_for_fill.get_next_color()

# Handle united data, such as dates
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
self._process_unit_info(xdata=x2)

# Convert the arrays so we can work with them
y = ma.masked_invalid(self.convert_yunits(y))
x1 = ma.masked_invalid(self.convert_xunits(x1))
x2 = ma.masked_invalid(self.convert_xunits(x2))
y = ma.masked_invalid(y)
x1 = ma.masked_invalid(x1)
x2 = ma.masked_invalid(x2)

for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
if array.ndim > 1:
Expand Down
28 changes: 12 additions & 16 deletions lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import matplotlib.font_manager as font_manager
import matplotlib.text as mtext
import matplotlib.image as mimage
import matplotlib.units as munits

from matplotlib.rcsetup import cycler, validate_axisbelow

Expand Down Expand Up @@ -3041,24 +3042,19 @@ def get_xlim(self):
"""
return tuple(self.viewLim.intervalx)

def _validate_converted_limits(self, limit, convert):
def _validate_converted_limits(self, converted_limit):
"""
Raise ValueError if converted limits are non-finite.

Note that this function also accepts None as a limit argument.

Returns
-------
The limit value after call to convert(), or None if limit is None.

"""
if limit is not None:
converted_limit = convert(limit)
if (isinstance(converted_limit, Real)
and not np.isfinite(converted_limit)):
if converted_limit is not None:
if (isinstance(converted_limit, float) and
(not np.isreal(converted_limit) or
not np.isfinite(converted_limit))):
raise ValueError("Axis limits cannot be NaN or Inf")
return converted_limit

@munits._accepts_units(convert_x=['left', 'right'])
def set_xlim(self, left=None, right=None, emit=True, auto=False,
*, xmin=None, xmax=None):
"""
Expand Down Expand Up @@ -3136,9 +3132,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
raise TypeError('Cannot pass both `xmax` and `right`')
right = xmax

self._process_unit_info(xdata=(left, right))
left = self._validate_converted_limits(left, self.convert_xunits)
right = self._validate_converted_limits(right, self.convert_xunits)
self._validate_converted_limits(left)
self._validate_converted_limits(right)

old_left, old_right = self.get_xlim()
if left is None:
Expand Down Expand Up @@ -3393,6 +3388,7 @@ def get_ylim(self):
"""
return tuple(self.viewLim.intervaly)

@munits._accepts_units(convert_y=['bottom', 'top'])
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanted to note that this fixes a bug for us where we ran into problems setting limits using units. (Note the lack of a call to _process_unit_info() in the original set_ylim() code.)

def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
*, ymin=None, ymax=None):
"""
Expand Down Expand Up @@ -3469,8 +3465,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
raise TypeError('Cannot pass both `ymax` and `top`')
top = ymax

bottom = self._validate_converted_limits(bottom, self.convert_yunits)
top = self._validate_converted_limits(top, self.convert_yunits)
self._validate_converted_limits(bottom)
self._validate_converted_limits(top)

old_bottom, old_top = self.get_ylim()

Expand Down
Loading