Skip to content

BUG: Fix weird behavior with mask and units (Fixes #8908) #9049

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 29, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,12 +556,12 @@ def __init__(self, fig, rect,
self.update(kwargs)

if self.xaxis is not None:
self._xcid = self.xaxis.callbacks.connect('units finalize',
self.relim)
self._xcid = self.xaxis.callbacks.connect(
'units finalize', lambda: self._on_units_changed(scalex=True))

if self.yaxis is not None:
self._ycid = self.yaxis.callbacks.connect('units finalize',
self.relim)
self._ycid = self.yaxis.callbacks.connect(
'units finalize', lambda: self._on_units_changed(scaley=True))

self.tick_params(
top=rcParams['xtick.top'] and rcParams['xtick.minor.top'],
Expand Down Expand Up @@ -1891,6 +1891,15 @@ def add_container(self, container):
container.set_remove_method(lambda h: self.containers.remove(h))
return container

def _on_units_changed(self, scalex=False, scaley=False):
"""
Callback for processing changes to axis units.

Currently forces updates of data limits and view limits.
"""
self.relim()
self.autoscale_view(scalex=scalex, scaley=scaley)

def relim(self, visible_only=False):
"""
Recompute the data limits based on current artists. If you want to
Expand Down
13 changes: 12 additions & 1 deletion lib/matplotlib/cbook/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1966,6 +1966,17 @@ def is_math_text(s):
return even_dollars


def _to_unmasked_float_array(x):
"""
Convert a sequence to a float array; if input was a masked array, masked
values are converted to nans.
"""
if hasattr(x, 'mask'):
return np.ma.asarray(x, float).filled(np.nan)
else:
return np.asarray(x, float)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is asanyarray better here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The existing code was always going to a straight-up numpy array, and since that part of the behavior didn't seem to be causing any unit-related problems, it seemed best to be conservative and keep it as asarray.



def _check_1d(x):
'''
Converts a sequence of less than 1 dimension, to an array of 1
Expand Down Expand Up @@ -2252,7 +2263,7 @@ def index_of(y):
try:
return y.index.values, y.values
except AttributeError:
y = np.atleast_1d(y)
y = _check_1d(y)
return np.arange(y.shape[0], dtype=float), y


Expand Down
15 changes: 4 additions & 11 deletions lib/matplotlib/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
from . import artist, colors as mcolors, docstring, rcParams
from .artist import Artist, allow_rasterization
from .cbook import (
iterable, is_numlike, ls_mapper, ls_mapper_r, STEP_LOOKUP_MAP)
_to_unmasked_float_array, iterable, is_numlike, ls_mapper, ls_mapper_r,
STEP_LOOKUP_MAP)
from .markers import MarkerStyle
from .path import Path
from .transforms import Bbox, TransformedPath, IdentityTransform
Expand Down Expand Up @@ -648,20 +649,12 @@ def recache_always(self):
def recache(self, always=False):
if always or self._invalidx:
xconv = self.convert_xunits(self._xorig)
if isinstance(self._xorig, np.ma.MaskedArray):
x = np.ma.asarray(xconv, float).filled(np.nan)
else:
x = np.asarray(xconv, float)
x = x.ravel()
x = _to_unmasked_float_array(xconv).ravel()
else:
x = self._x
if always or self._invalidy:
yconv = self.convert_yunits(self._yorig)
if isinstance(self._yorig, np.ma.MaskedArray):
y = np.ma.asarray(yconv, float).filled(np.nan)
else:
y = np.asarray(yconv, float)
y = y.ravel()
y = _to_unmasked_float_array(yconv).ravel()
else:
y = self._y

Expand Down
15 changes: 4 additions & 11 deletions lib/matplotlib/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@
import numpy as np

from . import _path, rcParams
from .cbook import simple_linear_interpolation, maxdict
from .cbook import (_to_unmasked_float_array, simple_linear_interpolation,
maxdict)


class Path(object):
Expand Down Expand Up @@ -129,11 +130,7 @@ def __init__(self, vertices, codes=None, _interpolation_steps=1,
Makes the path behave in an immutable way and sets the vertices
and codes as read-only arrays.
"""
if isinstance(vertices, np.ma.MaskedArray):
vertices = vertices.astype(float).filled(np.nan)
else:
vertices = np.asarray(vertices, float)

vertices = _to_unmasked_float_array(vertices)
if (vertices.ndim != 2) or (vertices.shape[1] != 2):
msg = "'vertices' must be a 2D list or array with shape Nx2"
raise ValueError(msg)
Expand Down Expand Up @@ -185,11 +182,7 @@ def _fast_from_codes_and_verts(cls, verts, codes, internals=None):
"""
internals = internals or {}
pth = cls.__new__(cls)
if isinstance(verts, np.ma.MaskedArray):
verts = verts.astype(float).filled(np.nan)
else:
verts = np.asarray(verts, float)
pth._vertices = verts
pth._vertices = _to_unmasked_float_array(verts)
pth._codes = codes
pth._readonly = internals.pop('readonly', False)
pth.should_simplify = internals.pop('should_simplify', True)
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
79 changes: 58 additions & 21 deletions lib/matplotlib/tests/test_units.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from matplotlib.cbook import iterable
import matplotlib.pyplot as plt
from matplotlib.testing.decorators import image_comparison
import matplotlib.units as munits
import numpy as np

Expand All @@ -9,49 +11,84 @@
from mock import MagicMock


# Tests that the conversion machinery works properly for classes that
# work as a facade over numpy arrays (like pint)
def test_numpy_facade():
# Basic class that wraps numpy array and has units
class Quantity(object):
def __init__(self, data, units):
self.magnitude = data
self.units = units
# Basic class that wraps numpy array and has units
class Quantity(object):
def __init__(self, data, units):
self.magnitude = data
self.units = units

def to(self, new_units):
factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}
if self.units != new_units:
mult = factors[self.units, new_units]
return Quantity(mult * self.magnitude, new_units)
else:
return Quantity(self.magnitude, self.units)

def __getattr__(self, attr):
return getattr(self.magnitude, attr)

def to(self, new_units):
return Quantity(self.magnitude, new_units)
def __getitem__(self, item):
return Quantity(self.magnitude[item], self.units)

def __getattr__(self, attr):
return getattr(self.magnitude, attr)
def __array__(self):
return np.asarray(self.magnitude)

def __getitem__(self, item):
return self.magnitude[item]

# Tests that the conversion machinery works properly for classes that
# work as a facade over numpy arrays (like pint)
@image_comparison(baseline_images=['plot_pint'],
extensions=['png'], remove_text=False, style='mpl20')
def test_numpy_facade():
# Create an instance of the conversion interface and
# mock so we can check methods called
qc = munits.ConversionInterface()

def convert(value, unit, axis):
if hasattr(value, 'units'):
return value.to(unit)
return value.to(unit).magnitude
elif iterable(value):
try:
return [v.to(unit).magnitude for v in value]
except AttributeError:
return [Quantity(v, axis.get_units()).to(unit).magnitude
for v in value]
else:
return Quantity(value, axis.get_units()).to(unit).magnitude

qc.convert = MagicMock(side_effect=convert)
qc.axisinfo = MagicMock(return_value=None)
qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
qc.default_units = MagicMock(side_effect=lambda x, a: x.units)

# Register the class
munits.registry[Quantity] = qc

# Simple test
t = Quantity(np.linspace(0, 10), 'sec')
d = Quantity(30 * np.linspace(0, 10), 'm/s')
y = Quantity(np.linspace(0, 30), 'miles')
x = Quantity(np.linspace(0, 5), 'hours')

fig, ax = plt.subplots(1, 1)
l, = plt.plot(t, d)
ax.yaxis.set_units('inch')
fig, ax = plt.subplots()
fig.subplots_adjust(left=0.15) # Make space for label
ax.plot(x, y, 'tab:blue')
ax.axhline(Quantity(26400, 'feet'), color='tab:red')
ax.axvline(Quantity(120, 'minutes'), color='tab:green')
ax.yaxis.set_units('inches')
ax.xaxis.set_units('seconds')

assert qc.convert.called
assert qc.axisinfo.called
assert qc.default_units.called


# Tests gh-8908
@image_comparison(baseline_images=['plot_masked_units'],
extensions=['png'], remove_text=True, style='mpl20')
def test_plot_masked_units():
data = np.linspace(-5, 5)
data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
data_masked_units = Quantity(data_masked, 'meters')

fig, ax = plt.subplots()
ax.plot(data_masked_units)
19 changes: 14 additions & 5 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,13 +110,13 @@ def __init__(self, fig, rect=None, *args, **kwargs):
# func used to format z -- fall back on major formatters
self.fmt_zdata = None

if zscale is not None :
if zscale is not None:
self.set_zscale(zscale)

if self.zaxis is not None :
self._zcid = self.zaxis.callbacks.connect('units finalize',
self.relim)
else :
if self.zaxis is not None:
self._zcid = self.zaxis.callbacks.connect(
'units finalize', lambda: self._on_units_changed(scalez=True))
else:
self._zcid = None

self._ready = 1
Expand Down Expand Up @@ -307,6 +307,15 @@ def get_axis_position(self):
zhigh = tc[0][2] > tc[2][2]
return xhigh, yhigh, zhigh

def _on_units_changed(self, scalex=False, scaley=False, scalez=False):
"""
Callback for processing changes to axis units.

Currently forces updates of data limits and view limits.
"""
self.relim()
self.autoscale_view(scalex=scalex, scaley=scaley, scalez=scalez)

def update_datalim(self, xys, **kwargs):
pass

Expand Down