Skip to content

Commit 94b3b91

Browse files
committed
Apply unit decorator to more functions
Add some more unit decorators Add unit decorator to mplot3d
1 parent aeb67db commit 94b3b91

File tree

4 files changed

+42
-75
lines changed

4 files changed

+42
-75
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 17 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,7 @@ def legend(self, *args, **kwargs):
385385
def _remove_legend(self, legend):
386386
self.legend_ = None
387387

388+
@munits._accepts_units(convert_x=['x'], convert_y=['y'])
388389
def text(self, x, y, s, fontdict=None, withdash=False, **kwargs):
389390
"""
390391
Add text to the axes.
@@ -619,6 +620,8 @@ def axvline(self, x=0, ymin=0, ymax=1, **kwargs):
619620
self.autoscale_view(scalex=scalex, scaley=False)
620621
return l
621622

623+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
624+
convert_y=['ymin', 'ymax'])
622625
@docstring.dedent_interpd
623626
def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
624627
"""
@@ -660,21 +663,15 @@ def axhspan(self, ymin, ymax, xmin=0, xmax=1, **kwargs):
660663
axvspan : Add a vertical span across the axes.
661664
"""
662665
trans = self.get_yaxis_transform(which='grid')
663-
664-
# process the unit information
665-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
666-
667-
# first we need to strip away the units
668-
xmin, xmax = self.convert_xunits([xmin, xmax])
669-
ymin, ymax = self.convert_yunits([ymin, ymax])
670-
671666
verts = (xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)
672667
p = mpatches.Polygon(verts, **kwargs)
673668
p.set_transform(trans)
674669
self.add_patch(p)
675670
self.autoscale_view(scalex=False)
676671
return p
677672

673+
@munits._accepts_units(convert_x=['xmin', 'xmax'],
674+
convert_y=['ymin', 'ymax'])
678675
def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
679676
"""
680677
Add a vertical span (rectangle) across the axes.
@@ -725,21 +722,14 @@ def axvspan(self, xmin, xmax, ymin=0, ymax=1, **kwargs):
725722
726723
"""
727724
trans = self.get_xaxis_transform(which='grid')
728-
729-
# process the unit information
730-
self._process_unit_info([xmin, xmax], [ymin, ymax], kwargs=kwargs)
731-
732-
# first we need to strip away the units
733-
xmin, xmax = self.convert_xunits([xmin, xmax])
734-
ymin, ymax = self.convert_yunits([ymin, ymax])
735-
736725
verts = [(xmin, ymin), (xmin, ymax), (xmax, ymax), (xmax, ymin)]
737726
p = mpatches.Polygon(verts, **kwargs)
738727
p.set_transform(trans)
739728
self.add_patch(p)
740729
self.autoscale_view(scaley=False)
741730
return p
742731

732+
@munits._accepts_units(convert_x=['xmin', 'xmax'], convert_y=['y'])
743733
@_preprocess_data(replace_names=["y", "xmin", "xmax", "colors"],
744734
label_namer="y")
745735
def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
@@ -775,14 +765,6 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
775765
vlines : vertical lines
776766
axhline: horizontal line across the axes
777767
"""
778-
779-
# We do the conversion first since not all unitized data is uniform
780-
# process the unit information
781-
self._process_unit_info([xmin, xmax], y, kwargs=kwargs)
782-
y = self.convert_yunits(y)
783-
xmin = self.convert_xunits(xmin)
784-
xmax = self.convert_xunits(xmax)
785-
786768
if not iterable(y):
787769
y = [y]
788770
if not iterable(xmin):
@@ -816,6 +798,7 @@ def hlines(self, y, xmin, xmax, colors='k', linestyles='solid',
816798

817799
return lines
818800

801+
@munits._accepts_units(convert_x=['x'], convert_y=['ymin', 'ymax'])
819802
@_preprocess_data(replace_names=["x", "ymin", "ymax", "colors"],
820803
label_namer="x")
821804
def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
@@ -853,14 +836,6 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
853836
hlines : horizontal lines
854837
axvline: vertical line across the axes
855838
"""
856-
857-
self._process_unit_info(xdata=x, ydata=[ymin, ymax], kwargs=kwargs)
858-
859-
# We do the conversion first since not all unitized data is uniform
860-
x = self.convert_xunits(x)
861-
ymin = self.convert_yunits(ymin)
862-
ymax = self.convert_yunits(ymax)
863-
864839
if not iterable(x):
865840
x = [x]
866841
if not iterable(ymin):
@@ -893,6 +868,8 @@ def vlines(self, x, ymin, ymax, colors='k', linestyles='solid',
893868

894869
return lines
895870

871+
@munits._accepts_units(convert_x=['positions'],
872+
convert_y=['lineoffsets', 'linelengths'])
896873
@_preprocess_data(replace_names=["positions", "lineoffsets",
897874
"linelengths", "linewidths",
898875
"colors", "linestyles"],
@@ -982,15 +959,6 @@ def eventplot(self, positions, orientation='horizontal', lineoffsets=1,
982959
983960
.. plot:: gallery/lines_bars_and_markers/eventplot_demo.py
984961
"""
985-
self._process_unit_info(xdata=positions,
986-
ydata=[lineoffsets, linelengths],
987-
kwargs=kwargs)
988-
989-
# We do the conversion first since not all unitized data is uniform
990-
positions = self.convert_xunits(positions)
991-
lineoffsets = self.convert_yunits(lineoffsets)
992-
linelengths = self.convert_yunits(linelengths)
993-
994962
if not iterable(positions):
995963
positions = [positions]
996964
elif any(iterable(position) for position in positions):
@@ -4627,6 +4595,7 @@ def fill(self, *args, **kwargs):
46274595
self.autoscale_view()
46284596
return patches
46294597

4598+
@munits._accepts_units(convert_x=['x'], convert_y=['y1', 'y2'])
46304599
@_preprocess_data(replace_names=["x", "y1", "y2", "where"],
46314600
label_namer=None)
46324601
@docstring.dedent_interpd
@@ -4720,14 +4689,10 @@ def fill_between(self, x, y1, y2=0, where=None, interpolate=False,
47204689
kwargs['facecolor'] = \
47214690
self._get_patches_for_fill.get_next_color()
47224691

4723-
# Handle united data, such as dates
4724-
self._process_unit_info(xdata=x, ydata=y1, kwargs=kwargs)
4725-
self._process_unit_info(ydata=y2)
4726-
47274692
# Convert the arrays so we can work with them
4728-
x = ma.masked_invalid(self.convert_xunits(x))
4729-
y1 = ma.masked_invalid(self.convert_yunits(y1))
4730-
y2 = ma.masked_invalid(self.convert_yunits(y2))
4693+
x = ma.masked_invalid(x)
4694+
y1 = ma.masked_invalid(y1)
4695+
y2 = ma.masked_invalid(y2)
47314696

47324697
for name, array in [('x', x), ('y1', y1), ('y2', y2)]:
47334698
if array.ndim > 1:
@@ -4810,6 +4775,7 @@ def get_interp_point(ind):
48104775
self.autoscale_view()
48114776
return collection
48124777

4778+
@munits._accepts_units(convert_x=['x1', 'x2'], convert_y=['y'])
48134779
@_preprocess_data(replace_names=["y", "x1", "x2", "where"],
48144780
label_namer=None)
48154781
@docstring.dedent_interpd
@@ -4903,14 +4869,10 @@ def fill_betweenx(self, y, x1, x2=0, where=None,
49034869
kwargs['facecolor'] = \
49044870
self._get_patches_for_fill.get_next_color()
49054871

4906-
# Handle united data, such as dates
4907-
self._process_unit_info(ydata=y, xdata=x1, kwargs=kwargs)
4908-
self._process_unit_info(xdata=x2)
4909-
49104872
# Convert the arrays so we can work with them
4911-
y = ma.masked_invalid(self.convert_yunits(y))
4912-
x1 = ma.masked_invalid(self.convert_xunits(x1))
4913-
x2 = ma.masked_invalid(self.convert_xunits(x2))
4873+
y = ma.masked_invalid(y)
4874+
x1 = ma.masked_invalid(x1)
4875+
x2 = ma.masked_invalid(x2)
49144876

49154877
for name, array in [('y', y), ('x1', x1), ('x2', x2)]:
49164878
if array.ndim > 1:

lib/matplotlib/axes/_base.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import matplotlib.font_manager as font_manager
2828
import matplotlib.text as mtext
2929
import matplotlib.image as mimage
30+
import matplotlib.units as munits
3031
from matplotlib.offsetbox import OffsetBox
3132
from matplotlib.artist import allow_rasterization
3233
from matplotlib.legend import Legend
@@ -3001,19 +3002,19 @@ def get_xlim(self):
30013002
"""
30023003
return tuple(self.viewLim.intervalx)
30033004

3004-
def _validate_converted_limits(self, limit, convert):
3005+
def _validate_converted_limits(self, converted_limit):
30053006
"""
30063007
Raise ValueError if converted limits are non-finite.
30073008
30083009
Note that this function also accepts None as a limit argument.
30093010
"""
3010-
if limit is not None:
3011-
converted_limit = convert(limit)
3012-
if (isinstance(converted_limit, Real)
3013-
and not np.isfinite(converted_limit)):
3011+
if converted_limit is not None:
3012+
if (isinstance(converted_limit, float) and
3013+
(not np.isreal(converted_limit) or
3014+
not np.isfinite(converted_limit))):
30143015
raise ValueError("Axis limits cannot be NaN or Inf")
3015-
return converted_limit
30163016

3017+
@munits._accepts_units(convert_x=['left', 'right'])
30173018
def set_xlim(self, left=None, right=None, emit=True, auto=False,
30183019
*, xmin=None, xmax=None):
30193020
"""
@@ -3091,9 +3092,8 @@ def set_xlim(self, left=None, right=None, emit=True, auto=False,
30913092
raise TypeError('Cannot pass both `xmax` and `right`')
30923093
right = xmax
30933094

3094-
self._process_unit_info(xdata=(left, right))
3095-
left = self._validate_converted_limits(left, self.convert_xunits)
3096-
right = self._validate_converted_limits(right, self.convert_xunits)
3095+
self._validate_converted_limits(left)
3096+
self._validate_converted_limits(right)
30973097

30983098
old_left, old_right = self.get_xlim()
30993099
if left is None:
@@ -3348,6 +3348,7 @@ def get_ylim(self):
33483348
"""
33493349
return tuple(self.viewLim.intervaly)
33503350

3351+
@munits._accepts_units(convert_y=['bottom', 'top'])
33513352
def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
33523353
*, ymin=None, ymax=None):
33533354
"""
@@ -3424,8 +3425,8 @@ def set_ylim(self, bottom=None, top=None, emit=True, auto=False,
34243425
raise TypeError('Cannot pass both `ymax` and `top`')
34253426
top = ymax
34263427

3427-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
3428-
top = self._validate_converted_limits(top, self.convert_yunits)
3428+
self._validate_converted_limits(bottom)
3429+
self._validate_converted_limits(top)
34293430

34303431
old_bottom, old_top = self.get_ylim()
34313432

lib/matplotlib/units.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def default_units(x, axis):
5151
from matplotlib.cbook import iterable, safe_first_element
5252

5353

54-
def _accepts_units(convert_x, convert_y):
54+
def _accepts_units(convert_x=[], convert_y=[]):
5555
"""
5656
A decorator for functions and methods that accept units. The parameters
5757
indicated in *convert_x* and *convert_y* are used to update the axis
@@ -69,6 +69,7 @@ def wrapper(*args, **kwargs):
6969
axes = args[0]
7070
# Bind the incoming arguments to the function signature
7171
bound_args = inspect.signature(func).bind(*args, **kwargs)
72+
bound_args.apply_defaults()
7273
# Get the original arguments - these will be modified later
7374
arguments = bound_args.arguments
7475
# Check for data kwarg

lib/mpl_toolkits/mplot3d/axes3d.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import matplotlib.docstring as docstring
2525
import matplotlib.scale as mscale
2626
import matplotlib.transforms as mtransforms
27+
import matplotlib.units as munits
2728
from matplotlib.axes import Axes, rcParams
2829
from matplotlib.colors import Normalize, LightSource
2930
from matplotlib.transforms import Bbox
@@ -593,6 +594,7 @@ def _determine_lims(self, xmin=None, xmax=None, *args, **kwargs):
593594
xmax += 0.05
594595
return (xmin, xmax)
595596

597+
@munits._accepts_units(convert_x=['left', 'right'])
596598
def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
597599
*, xmin=None, xmax=None):
598600
"""
@@ -616,9 +618,8 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
616618
raise TypeError('Cannot pass both `xmax` and `right`')
617619
right = xmax
618620

619-
self._process_unit_info(xdata=(left, right))
620-
left = self._validate_converted_limits(left, self.convert_xunits)
621-
right = self._validate_converted_limits(right, self.convert_xunits)
621+
self._validate_converted_limits(left)
622+
self._validate_converted_limits(right)
622623

623624
old_left, old_right = self.get_xlim()
624625
if left is None:
@@ -651,6 +652,7 @@ def set_xlim3d(self, left=None, right=None, emit=True, auto=False,
651652
return left, right
652653
set_xlim = set_xlim3d
653654

655+
@munits._accepts_units(convert_y=['bottom', 'top'])
654656
def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
655657
*, ymin=None, ymax=None):
656658
"""
@@ -674,9 +676,8 @@ def set_ylim3d(self, bottom=None, top=None, emit=True, auto=False,
674676
raise TypeError('Cannot pass both `ymax` and `top`')
675677
top = ymax
676678

677-
self._process_unit_info(ydata=(bottom, top))
678-
bottom = self._validate_converted_limits(bottom, self.convert_yunits)
679-
top = self._validate_converted_limits(top, self.convert_yunits)
679+
self._validate_converted_limits(bottom)
680+
self._validate_converted_limits(top)
680681

681682
old_bottom, old_top = self.get_ylim()
682683
if bottom is None:
@@ -733,8 +734,10 @@ def set_zlim3d(self, bottom=None, top=None, emit=True, auto=False,
733734
top = zmax
734735

735736
self._process_unit_info(zdata=(bottom, top))
736-
bottom = self._validate_converted_limits(bottom, self.convert_zunits)
737-
top = self._validate_converted_limits(top, self.convert_zunits)
737+
bottom = self.convert_zunits(bottom)
738+
top = self.convert_zunits(top)
739+
self._validate_converted_limits(bottom)
740+
self._validate_converted_limits(top)
738741

739742
old_bottom, old_top = self.get_zlim()
740743
if bottom is None:

0 commit comments

Comments
 (0)