From 0f055bbc4b1efa72a013a16b9d5d92f97b004cb4 Mon Sep 17 00:00:00 2001 From: Nathan Goldbaum Date: Wed, 31 May 2017 17:18:54 -0500 Subject: [PATCH] interpret data to normalize as ndarrays unless passed a masked array --- lib/matplotlib/colors.py | 12 +++++++----- lib/matplotlib/tests/test_colors.py | 19 +++++++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/lib/matplotlib/colors.py b/lib/matplotlib/colors.py index f8f62c90c1d1..5d2cbb3bd463 100644 --- a/lib/matplotlib/colors.py +++ b/lib/matplotlib/colors.py @@ -907,7 +907,11 @@ def process_value(value): if np.issubdtype(dtype, np.integer) or dtype.type is np.bool_: # bool_/int8/int16 -> float32; int32/int64 -> float64 dtype = np.promote_types(dtype, np.float32) - result = np.ma.array(value, dtype=dtype, copy=True) + # ensure data passed in as an ndarray subclass are interpreted as + # an ndarray. See issue #6622. + mask = np.ma.getmask(value) + data = np.asarray(np.ma.getdata(value)) + result = np.ma.array(data, mask=mask, dtype=dtype, copy=True) return result, is_scalar def __call__(self, value, clip=None): @@ -937,9 +941,7 @@ def __call__(self, value, clip=None): result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax), mask=mask) # ma division is very slow; we can take a shortcut - # use np.asarray so data passed in as an ndarray subclass are - # interpreted as an ndarray. See issue #6622. - resdat = np.asarray(result.data) + resdat = result.data resdat -= vmin resdat /= (vmax - vmin) result = np.ma.array(resdat, mask=result.mask, copy=False) @@ -1007,7 +1009,7 @@ def __call__(self, value, clip=None): if clip: mask = np.ma.getmask(result) result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax), - mask=mask) + mask=mask) # in-place equivalent of above can be much faster resdat = result.data mask = result.mask diff --git a/lib/matplotlib/tests/test_colors.py b/lib/matplotlib/tests/test_colors.py index 0bf30becdd10..37cb2039e35e 100644 --- a/lib/matplotlib/tests/test_colors.py +++ b/lib/matplotlib/tests/test_colors.py @@ -689,3 +689,22 @@ def test_tableau_order(): '#bcbd22', '#17becf'] assert list(mcolors.TABLEAU_COLORS.values()) == dflt_cycle + + +def test_ndarray_subclass_norm(): + # Emulate an ndarray subclass that handles units + # which objects when adding or subtracting with other + # arrays. See #6622 and #8696 + class MyArray(np.ndarray): + def __isub__(self, other): + raise RuntimeError + + def __add__(self, other): + raise RuntimeError + + data = np.arange(-10, 10, 1, dtype=float) + + for norm in [mcolors.Normalize(), mcolors.LogNorm(), + mcolors.SymLogNorm(3, vmax=5, linscale=1), + mcolors.PowerNorm(1)]: + assert_array_equal(norm(data.view(MyArray)), norm(data))