diff --git a/lib/matplotlib/axes/_base.py b/lib/matplotlib/axes/_base.py index 9cf460407511..eeadfc87a1e0 100644 --- a/lib/matplotlib/axes/_base.py +++ b/lib/matplotlib/axes/_base.py @@ -556,12 +556,12 @@ def __init__(self, fig, rect, self.update(kwargs) if self.xaxis is not None: - self._xcid = self.xaxis.callbacks.connect('units finalize', - self.relim) + self._xcid = self.xaxis.callbacks.connect( + 'units finalize', lambda: self._on_units_changed(scalex=True)) if self.yaxis is not None: - self._ycid = self.yaxis.callbacks.connect('units finalize', - self.relim) + self._ycid = self.yaxis.callbacks.connect( + 'units finalize', lambda: self._on_units_changed(scaley=True)) self.tick_params( top=rcParams['xtick.top'] and rcParams['xtick.minor.top'], @@ -1891,6 +1891,15 @@ def add_container(self, container): container.set_remove_method(lambda h: self.containers.remove(h)) return container + def _on_units_changed(self, scalex=False, scaley=False): + """ + Callback for processing changes to axis units. + + Currently forces updates of data limits and view limits. + """ + self.relim() + self.autoscale_view(scalex=scalex, scaley=scaley) + def relim(self, visible_only=False): """ Recompute the data limits based on current artists. If you want to diff --git a/lib/matplotlib/cbook/__init__.py b/lib/matplotlib/cbook/__init__.py index 838081e33f66..36986e1b4210 100644 --- a/lib/matplotlib/cbook/__init__.py +++ b/lib/matplotlib/cbook/__init__.py @@ -1966,6 +1966,17 @@ def is_math_text(s): return even_dollars +def _to_unmasked_float_array(x): + """ + Convert a sequence to a float array; if input was a masked array, masked + values are converted to nans. + """ + if hasattr(x, 'mask'): + return np.ma.asarray(x, float).filled(np.nan) + else: + return np.asarray(x, float) + + def _check_1d(x): ''' Converts a sequence of less than 1 dimension, to an array of 1 @@ -2252,7 +2263,7 @@ def index_of(y): try: return y.index.values, y.values except AttributeError: - y = np.atleast_1d(y) + y = _check_1d(y) return np.arange(y.shape[0], dtype=float), y diff --git a/lib/matplotlib/lines.py b/lib/matplotlib/lines.py index 256e59c2eb2e..9367eb8fc853 100644 --- a/lib/matplotlib/lines.py +++ b/lib/matplotlib/lines.py @@ -16,7 +16,8 @@ from . import artist, colors as mcolors, docstring, rcParams from .artist import Artist, allow_rasterization from .cbook import ( - iterable, is_numlike, ls_mapper, ls_mapper_r, STEP_LOOKUP_MAP) + _to_unmasked_float_array, iterable, is_numlike, ls_mapper, ls_mapper_r, + STEP_LOOKUP_MAP) from .markers import MarkerStyle from .path import Path from .transforms import Bbox, TransformedPath, IdentityTransform @@ -648,20 +649,12 @@ def recache_always(self): def recache(self, always=False): if always or self._invalidx: xconv = self.convert_xunits(self._xorig) - if isinstance(self._xorig, np.ma.MaskedArray): - x = np.ma.asarray(xconv, float).filled(np.nan) - else: - x = np.asarray(xconv, float) - x = x.ravel() + x = _to_unmasked_float_array(xconv).ravel() else: x = self._x if always or self._invalidy: yconv = self.convert_yunits(self._yorig) - if isinstance(self._yorig, np.ma.MaskedArray): - y = np.ma.asarray(yconv, float).filled(np.nan) - else: - y = np.asarray(yconv, float) - y = y.ravel() + y = _to_unmasked_float_array(yconv).ravel() else: y = self._y diff --git a/lib/matplotlib/path.py b/lib/matplotlib/path.py index fcaf191dc04b..54ec0b2fa1fc 100644 --- a/lib/matplotlib/path.py +++ b/lib/matplotlib/path.py @@ -23,7 +23,8 @@ import numpy as np from . import _path, rcParams -from .cbook import simple_linear_interpolation, maxdict +from .cbook import (_to_unmasked_float_array, simple_linear_interpolation, + maxdict) class Path(object): @@ -129,11 +130,7 @@ def __init__(self, vertices, codes=None, _interpolation_steps=1, Makes the path behave in an immutable way and sets the vertices and codes as read-only arrays. """ - if isinstance(vertices, np.ma.MaskedArray): - vertices = vertices.astype(float).filled(np.nan) - else: - vertices = np.asarray(vertices, float) - + vertices = _to_unmasked_float_array(vertices) if (vertices.ndim != 2) or (vertices.shape[1] != 2): msg = "'vertices' must be a 2D list or array with shape Nx2" raise ValueError(msg) @@ -185,11 +182,7 @@ def _fast_from_codes_and_verts(cls, verts, codes, internals=None): """ internals = internals or {} pth = cls.__new__(cls) - if isinstance(verts, np.ma.MaskedArray): - verts = verts.astype(float).filled(np.nan) - else: - verts = np.asarray(verts, float) - pth._vertices = verts + pth._vertices = _to_unmasked_float_array(verts) pth._codes = codes pth._readonly = internals.pop('readonly', False) pth.should_simplify = internals.pop('should_simplify', True) diff --git a/lib/matplotlib/tests/baseline_images/test_units/plot_masked_units.png b/lib/matplotlib/tests/baseline_images/test_units/plot_masked_units.png new file mode 100644 index 000000000000..98a07b2654fe Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_units/plot_masked_units.png differ diff --git a/lib/matplotlib/tests/baseline_images/test_units/plot_pint.png b/lib/matplotlib/tests/baseline_images/test_units/plot_pint.png new file mode 100644 index 000000000000..f15f81fda6f6 Binary files /dev/null and b/lib/matplotlib/tests/baseline_images/test_units/plot_pint.png differ diff --git a/lib/matplotlib/tests/test_units.py b/lib/matplotlib/tests/test_units.py index 1112b0f22574..f72ac2c60476 100644 --- a/lib/matplotlib/tests/test_units.py +++ b/lib/matplotlib/tests/test_units.py @@ -1,4 +1,6 @@ +from matplotlib.cbook import iterable import matplotlib.pyplot as plt +from matplotlib.testing.decorators import image_comparison import matplotlib.units as munits import numpy as np @@ -9,49 +11,84 @@ from mock import MagicMock -# Tests that the conversion machinery works properly for classes that -# work as a facade over numpy arrays (like pint) -def test_numpy_facade(): - # Basic class that wraps numpy array and has units - class Quantity(object): - def __init__(self, data, units): - self.magnitude = data - self.units = units +# Basic class that wraps numpy array and has units +class Quantity(object): + def __init__(self, data, units): + self.magnitude = data + self.units = units + + def to(self, new_units): + factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60, + ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280., + ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280} + if self.units != new_units: + mult = factors[self.units, new_units] + return Quantity(mult * self.magnitude, new_units) + else: + return Quantity(self.magnitude, self.units) + + def __getattr__(self, attr): + return getattr(self.magnitude, attr) - def to(self, new_units): - return Quantity(self.magnitude, new_units) + def __getitem__(self, item): + return Quantity(self.magnitude[item], self.units) - def __getattr__(self, attr): - return getattr(self.magnitude, attr) + def __array__(self): + return np.asarray(self.magnitude) - def __getitem__(self, item): - return self.magnitude[item] +# Tests that the conversion machinery works properly for classes that +# work as a facade over numpy arrays (like pint) +@image_comparison(baseline_images=['plot_pint'], + extensions=['png'], remove_text=False, style='mpl20') +def test_numpy_facade(): # Create an instance of the conversion interface and # mock so we can check methods called qc = munits.ConversionInterface() def convert(value, unit, axis): if hasattr(value, 'units'): - return value.to(unit) + return value.to(unit).magnitude + elif iterable(value): + try: + return [v.to(unit).magnitude for v in value] + except AttributeError: + return [Quantity(v, axis.get_units()).to(unit).magnitude + for v in value] else: return Quantity(value, axis.get_units()).to(unit).magnitude qc.convert = MagicMock(side_effect=convert) - qc.axisinfo = MagicMock(return_value=None) + qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u)) qc.default_units = MagicMock(side_effect=lambda x, a: x.units) # Register the class munits.registry[Quantity] = qc # Simple test - t = Quantity(np.linspace(0, 10), 'sec') - d = Quantity(30 * np.linspace(0, 10), 'm/s') + y = Quantity(np.linspace(0, 30), 'miles') + x = Quantity(np.linspace(0, 5), 'hours') - fig, ax = plt.subplots(1, 1) - l, = plt.plot(t, d) - ax.yaxis.set_units('inch') + fig, ax = plt.subplots() + fig.subplots_adjust(left=0.15) # Make space for label + ax.plot(x, y, 'tab:blue') + ax.axhline(Quantity(26400, 'feet'), color='tab:red') + ax.axvline(Quantity(120, 'minutes'), color='tab:green') + ax.yaxis.set_units('inches') + ax.xaxis.set_units('seconds') assert qc.convert.called assert qc.axisinfo.called assert qc.default_units.called + + +# Tests gh-8908 +@image_comparison(baseline_images=['plot_masked_units'], + extensions=['png'], remove_text=True, style='mpl20') +def test_plot_masked_units(): + data = np.linspace(-5, 5) + data_masked = np.ma.array(data, mask=(data > -2) & (data < 2)) + data_masked_units = Quantity(data_masked, 'meters') + + fig, ax = plt.subplots() + ax.plot(data_masked_units) diff --git a/lib/mpl_toolkits/mplot3d/axes3d.py b/lib/mpl_toolkits/mplot3d/axes3d.py index f387e5a5582e..4738be075f2f 100644 --- a/lib/mpl_toolkits/mplot3d/axes3d.py +++ b/lib/mpl_toolkits/mplot3d/axes3d.py @@ -110,13 +110,13 @@ def __init__(self, fig, rect=None, *args, **kwargs): # func used to format z -- fall back on major formatters self.fmt_zdata = None - if zscale is not None : + if zscale is not None: self.set_zscale(zscale) - if self.zaxis is not None : - self._zcid = self.zaxis.callbacks.connect('units finalize', - self.relim) - else : + if self.zaxis is not None: + self._zcid = self.zaxis.callbacks.connect( + 'units finalize', lambda: self._on_units_changed(scalez=True)) + else: self._zcid = None self._ready = 1 @@ -307,6 +307,15 @@ def get_axis_position(self): zhigh = tc[0][2] > tc[2][2] return xhigh, yhigh, zhigh + def _on_units_changed(self, scalex=False, scaley=False, scalez=False): + """ + Callback for processing changes to axis units. + + Currently forces updates of data limits and view limits. + """ + self.relim() + self.autoscale_view(scalex=scalex, scaley=scaley, scalez=scalez) + def update_datalim(self, xys, **kwargs): pass