Skip to content

Commit a786018

Browse files
committed
Apply unit decorator to more functions
Add some more unit decorators Add unit decorator to mplot3d
1 parent 50a8da5 commit a786018

File tree

4 files changed

+39
-73
lines changed

4 files changed

+39
-73
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def legend(self, *args, **kwargs):
385385
def _remove_legend(self, legend):
386386
self.legend_ = None
387387

388+
@munits._accepts_units(convert_x=['x'], convert_y=['y'])
388389
def text(self, x, y, s, fontdict=None, withdash=False, **kwargs):
389390
"""
390391
Add text to the axes.
@@ -619,6 +620,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
619620
self.autoscale_view(scalex=scalex, scaley=False)
620621
return l
621622

623+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
624+
convert_y=['ymin', 'ymax'])
622625
@docstring.dedent_interpd
623626
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
624627
"""
@@ -660,21 +663,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
660663
axvspan : Add a vertical span across the axes.
661664
"""
662665
trans = self.get_yaxis_transform(which='grid')
663-
664-
# process the unit information
665-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
666-
667-
# first we need to strip away the units
668-
xmin, xmax = self.convert_xunits([xmin, xmax])
669-
ymin, ymax = self.convert_yunits([ymin, ymax])
670-
671666
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
672667
p = mpatches.Polygon(verts, **kwargs)
673668
p.set_transform(trans)
674669
self.add_patch(p)
675670
self.autoscale_view(scalex=False)
676671
return p
677672

673+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
674+
convert_y=['ymin', 'ymax'])
678675
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
679676
"""
680677
Add a vertical span (rectangle) across the axes.
@@ -725,21 +722,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
725722
726723
"""
727724
trans = self.get_xaxis_transform(which='grid')
728-
729-
# process the unit information
730-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
731-
732-
# first we need to strip away the units
733-
xmin, xmax = self.convert_xunits([xmin, xmax])
734-
ymin, ymax = self.convert_yunits([ymin, ymax])
735-
736725
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
737726
p = mpatches.Polygon(verts, **kwargs)
738727
p.set_transform(trans)
739728
self.add_patch(p)
740729
self.autoscale_view(scaley=False)
741730
return p
742731

732+
@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
743733
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
744734
label_namer="y")
745735
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
@@ -775,14 +765,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
775765
vlines : vertical lines
776766
axhline: horizontal line across the axes
777767
"""
778-
779-
# We do the conversion first since not all unitized data is uniform
780-
# process the unit information
781-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
782-
y = self.convert_yunits(y)
783-
xmin = self.convert_xunits(xmin)
784-
xmax = self.convert_xunits(xmax)
785-
786768
if not iterable(y):
787769
y = [y]
788770
if not iterable(xmin):
@@ -816,6 +798,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
816798

817799
return lines
818800

801+
@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
819802
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
820803
label_namer="x")
821804
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
@@ -853,14 +836,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
853836
hlines : horizontal lines
854837
axvline: vertical line across the axes
855838
"""
856-
857-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
858-
859-
# We do the conversion first since not all unitized data is uniform
860-
x = self.convert_xunits(x)
861-
ymin = self.convert_yunits(ymin)
862-
ymax = self.convert_yunits(ymax)
863-
864839
if not iterable(x):
865840
x = [x]
866841
if not iterable(ymin):
@@ -893,6 +868,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
893868

894869
return lines
895870

871+
@munits._accepts_units(convert_x=['positions'],
872+
convert_y=['lineoffsets', 'linelengths'])
896873
@_preprocess_data(replace_names=["positions", "lineoffsets",
897874
"linelengths", "linewidths",
898875
"colors", "linestyles"],
@@ -982,15 +959,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
982959
983960
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
984961
"""
985-
self._process_unit_info(xdata=positions,
986-
ydata=[lineoffsets, linelengths],
987-
kwargs=kwargs)
988-
989-
# We do the conversion first since not all unitized data is uniform
990-
positions = self.convert_xunits(positions)
991-
lineoffsets = self.convert_yunits(lineoffsets)
992-
linelengths = self.convert_yunits(linelengths)
993-
994962
if not iterable(positions):
995963
positions = [positions]
996964
elif any(iterable(position) for position in positions):
@@ -4628,6 +4596,7 @@ def fill(self, *args, **kwargs):
46284596
self.autoscale_view()
46294597
return patches
46304598

4599+
@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
46314600
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
46324601
label_namer=None)
46334602
@docstring.dedent_interpd
@@ -4721,14 +4690,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
47214690
kwargs['facecolor'] = \
47224691
self._get_patches_for_fill.get_next_color()
47234692

4724-
# Handle united data, such as dates
4725-
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
4726-
self._process_unit_info(ydata=y2)
4727-
47284693
# Convert the arrays so we can work with them
4729-
x = ma.masked_invalid(self.convert_xunits(x))
4730-
y1 = ma.masked_invalid(self.convert_yunits(y1))
4731-
y2 = ma.masked_invalid(self.convert_yunits(y2))
4694+
x = ma.masked_invalid(x)
4695+
y1 = ma.masked_invalid(y1)
4696+
y2 = ma.masked_invalid(y2)
47324697

47334698
for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
47344699
if array.ndim > 1:
@@ -4811,6 +4776,7 @@ def get_interp_point(ind):
48114776
self.autoscale_view()
48124777
return collection
48134778

4779+
@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
48144780
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
48154781
label_namer=None)
48164782
@docstring.dedent_interpd
@@ -4904,14 +4870,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
49044870
kwargs['facecolor'] = \
49054871
self._get_patches_for_fill.get_next_color()
49064872

4907-
# Handle united data, such as dates
4908-
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
4909-
self._process_unit_info(xdata=x2)
4910-
49114873
# Convert the arrays so we can work with them
4912-
y = ma.masked_invalid(self.convert_yunits(y))
4913-
x1 = ma.masked_invalid(self.convert_xunits(x1))
4914-
x2 = ma.masked_invalid(self.convert_xunits(x2))
4874+
y = ma.masked_invalid(y)
4875+
x1 = ma.masked_invalid(x1)
4876+
x2 = ma.masked_invalid(x2)
49154877

49164878
for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
49174879
if array.ndim > 1:

lib/matplotlib/axes/_base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.font_manager as font_manager
2727
import matplotlib.text as mtext
2828
import matplotlib.image as mimage
29+
import matplotlib.units as munits
2930
from matplotlib.offsetbox import OffsetBox
3031
from matplotlib.artist import allow_rasterization
3132
from matplotlib.legend import Legend
@@ -3007,20 +3008,19 @@ def get_xlim(self):
30073008
"""
30083009
return tuple(self.viewLim.intervalx)
30093010

3010-
def _validate_converted_limits(self, limit, convert):
3011+
def _validate_converted_limits(self, converted_limit):
30113012
"""
30123013
Raise ValueError if converted limits are non-finite.
30133014
30143015
Note that this function also accepts None as a limit argument.
30153016
"""
3016-
if limit is not None:
3017-
converted_limit = convert(limit)
3017+
if converted_limit is not None:
30183018
if (isinstance(converted_limit, float) and
30193019
(not np.isreal(converted_limit) or
30203020
not np.isfinite(converted_limit))):
30213021
raise ValueError("Axis limits cannot be NaN or Inf")
3022-
return converted_limit
30233022

3023+
@munits._accepts_units(convert_x=['left', 'right'])
30243024
def set_xlim(self, left=None, right=None, emit=True, auto=False, **kw):
30253025
"""
30263026
Set the data limits for the x-axis
@@ -3088,9 +3088,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False, **kw):
30883088
if right is None and iterable(left):
30893089
left, right = left
30903090

3091-
self._process_unit_info(xdata=(left, right))
3092-
left = self._validate_converted_limits(left, self.convert_xunits)
3093-
right = self._validate_converted_limits(right, self.convert_xunits)
3091+
self._validate_converted_limits(left)
3092+
self._validate_converted_limits(right)
30943093

30953094
old_left, old_right = self.get_xlim()
30963095
if left is None:
@@ -3351,6 +3350,7 @@ def get_ylim(self):
33513350
"""
33523351
return tuple(self.viewLim.intervaly)
33533352

3353+
@munits._accepts_units(convert_y=['bottom', 'top'])
33543354
def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
33553355
"""
33563356
Set the data limits for the y-axis
@@ -3417,8 +3417,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False, **kw):
34173417
if top is None and iterable(bottom):
34183418
bottom, top = bottom
34193419

3420-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
3421-
top = self._validate_converted_limits(top, self.convert_yunits)
3420+
self._validate_converted_limits(bottom)
3421+
self._validate_converted_limits(top)
34223422

34233423
old_bottom, old_top = self.get_ylim()
34243424

lib/matplotlib/units.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def default_units(x, axis):
5151
from matplotlib.cbook import iterable, safe_first_element
5252

5353

54-
def _accepts_units(convert_x, convert_y):
54+
def _accepts_units(convert_x=[], convert_y=[]):
5555
"""
5656
A decorator for functions and methods that accept units. The parameters
5757
indicated in *convert_x* and *convert_y* are used to update the axis
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
axes = args[0]
7070
# Bind the incoming arguments to the function signature
7171
bound_args = inspect.signature(func).bind(*args, **kwargs)
72+
bound_args.apply_defaults()
7273
# Get the original arguments - these will be modified later
7374
arguments = bound_args.arguments
7475
# Check for data kwarg

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import matplotlib.docstring as docstring
2727
import matplotlib.scale as mscale
2828
import matplotlib.transforms as mtransforms
29+
import matplotlib.units as munits
2930
from matplotlib.axes import Axes, rcParams
3031
from matplotlib.colors import Normalize, LightSource
3132
from matplotlib.transforms import Bbox
@@ -603,6 +604,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs):
603604
xmax += 0.05
604605
return (xmin, xmax)
605606

607+
@munits._accepts_units(convert_x=['left', 'right'])
606608
def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
607609
"""
608610
Set 3D x limits.
@@ -620,9 +622,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
620622
if right is None and cbook.iterable(left):
621623
left, right = left
622624

623-
self._process_unit_info(xdata=(left, right))
624-
left = self._validate_converted_limits(left, self.convert_xunits)
625-
right = self._validate_converted_limits(right, self.convert_xunits)
625+
self._validate_converted_limits(left)
626+
self._validate_converted_limits(right)
626627

627628
old_left, old_right = self.get_xlim()
628629
if left is None:
@@ -655,6 +656,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False, **kw):
655656
return left, right
656657
set_xlim = set_xlim3d
657658

659+
@munits._accepts_units(convert_y=['bottom', 'top'])
658660
def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
659661
"""
660662
Set 3D y limits.
@@ -672,9 +674,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
672674
if top is None and cbook.iterable(bottom):
673675
bottom, top = bottom
674676

675-
self._process_unit_info(ydata=(bottom, top))
676-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
677-
top = self._validate_converted_limits(top, self.convert_yunits)
677+
self._validate_converted_limits(bottom)
678+
self._validate_converted_limits(top)
678679

679680
old_bottom, old_top = self.get_ylim()
680681
if bottom is None:
@@ -725,8 +726,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False, **kw):
725726
bottom, top = bottom
726727

727728
self._process_unit_info(zdata=(bottom, top))
728-
bottom = self._validate_converted_limits(bottom, self.convert_zunits)
729-
top = self._validate_converted_limits(top, self.convert_zunits)
729+
bottom = self.convert_zunits(bottom)
730+
top = self.convert_zunits(top)
731+
self._validate_converted_limits(bottom)
732+
self._validate_converted_limits(top)
730733

731734
old_bottom, old_top = self.get_zlim()
732735
if bottom is None:

0 commit comments

Comments
 (0)