Skip to content

Interpret data to normalize as ndarrays #8696

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions lib/matplotlib/colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -907,7 +907,11 @@ def process_value(value):
if np.issubdtype(dtype, np.integer) or dtype.type is np.bool_:
# bool_/int8/int16 -> float32; int32/int64 -> float64
dtype = np.promote_types(dtype, np.float32)
result = np.ma.array(value, dtype=dtype, copy=True)
# ensure data passed in as an ndarray subclass are interpreted as
# an ndarray. See issue #6622.
mask = np.ma.getmask(value)
data = np.asarray(np.ma.getdata(value))
result = np.ma.array(data, mask=mask, dtype=dtype, copy=True)
return result, is_scalar

def __call__(self, value, clip=None):
Expand Down Expand Up @@ -937,9 +941,7 @@ def __call__(self, value, clip=None):
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
mask=mask)
# ma division is very slow; we can take a shortcut
# use np.asarray so data passed in as an ndarray subclass are
# interpreted as an ndarray. See issue #6622.
resdat = np.asarray(result.data)
resdat = result.data
resdat -= vmin
resdat /= (vmax - vmin)
result = np.ma.array(resdat, mask=result.mask, copy=False)
Expand Down Expand Up @@ -1007,7 +1009,7 @@ def __call__(self, value, clip=None):
if clip:
mask = np.ma.getmask(result)
result = np.ma.array(np.clip(result.filled(vmax), vmin, vmax),
mask=mask)
mask=mask)
# in-place equivalent of above can be much faster
resdat = result.data
mask = result.mask
Expand Down
19 changes: 19 additions & 0 deletions lib/matplotlib/tests/test_colors.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,3 +689,22 @@ def test_tableau_order():
'#bcbd22', '#17becf']

assert list(mcolors.TABLEAU_COLORS.values()) == dflt_cycle


def test_ndarray_subclass_norm():
# Emulate an ndarray subclass that handles units
# which objects when adding or subtracting with other
# arrays. See #6622 and #8696
class MyArray(np.ndarray):
def __isub__(self, other):
raise RuntimeError

def __add__(self, other):
raise RuntimeError

data = np.arange(-10, 10, 1, dtype=float)

for norm in [mcolors.Normalize(), mcolors.LogNorm(),
mcolors.SymLogNorm(3, vmax=5, linscale=1),
mcolors.PowerNorm(1)]:
assert_array_equal(norm(data.view(MyArray)), norm(data))