Functions select and where don't preserve subclasses #10933

Closed
fpagnoux opened this issue Apr 19, 2018 · 4 comments

Comments

@fpagnoux

Hey there 👋 !

I'm using numpy 1.14.2 with python 2.7.10.

I'm trying to subclass numpy.ndarray, and it seems that, unlike most other operators and functions, numpy.select and numpy.where don't preserve the subclass of an array. Here is an example:

import numpy as np

# Create a subclass of np.ndarray that carries a custom attribute
class CustomArray(np.ndarray):
    def __new__(cls, matrix, custom_attribute):
        obj = np.asarray(matrix).view(cls)
        obj.custom_attribute = custom_attribute
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.custom_attribute = getattr(obj, 'custom_attribute', None)

x = CustomArray(np.asarray([1,2,3]), 'some_value')  # a custom array
y = np.asarray([4,5,6])  # a regular array

# The subclass is preserved by most operations
assert type(x + y) == CustomArray
assert type(x * y) == CustomArray
assert type(np.maximum(x, y)) == CustomArray
assert type(np.logical_not((x > 1) * (y <= 5))) == CustomArray

# But not by where and select, even if the only arrays involved are all CustomArrays

where_result = np.where(x <= 1, x + 1, x)

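select_result = np.select(
    [x <= 1, x <= 2],  # condlist: one boolean CustomArray per condition
    [x + 1, x + 2],    # choicelist: one CustomArray per condition
    x                  # default where no condition holds
)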

# The 2 following assertions fail
assert type(where_result) == CustomArray
assert type(select_result) == CustomArray

Is this intended behavior?

Thanks a lot!

@eric-wieser
Member

Related: #8994

@mhvk
Contributor

mhvk commented Apr 20, 2018

Longer answer: no, not really intended. See also an earlier report by myself, #5095 (part of what got me into trying to fix pieces of numpy, better be careful!)

@cournape
Member

We looked into this at the MAN/AHL pydata hackathon. For where, I am not sure how to go about it, as I am not so familiar w/ the relationship between subclasses, __array_priority__ and nditer.

To make a long story short:

  1. PyArray_Where uses PyArray_Nonzero for where(cond), and the nditer C API for the form where(cond, x, y).
  2. Fixing PyArray_Nonzero to preserve subclasses is easy.
  3. It sounds like nditer uses __array_priority__ to decide the output class.

If point 3 is true, then one would need to define __array_priority__ in the CustomArray class to get the "expected" behavior. I am not sure there is anything else to do, short of reimplementing PyArray_Where as a ufunc. Is that the right long-term solution?
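Until something like that lands, a user-side workaround is possible. The following is only a minimal sketch, not a fix in numpy itself: `where_like` is a hypothetical helper (not a numpy function) that calls np.where and then re-views the result as the subclass of a chosen template array, copying `custom_attribute` from it. It assumes the CustomArray class from the original report and that inheriting the attribute from a single template is what the caller wants.

```python
import numpy as np


class CustomArray(np.ndarray):
    """Subclass from the original report, repeated so the sketch is self-contained."""

    def __new__(cls, matrix, custom_attribute):
        obj = np.asarray(matrix).view(cls)
        obj.custom_attribute = custom_attribute
        return obj

    def __array_finalize__(self, obj):
        if obj is None:
            return
        self.custom_attribute = getattr(obj, 'custom_attribute', None)


def where_like(template, condition, a, b):
    """Hypothetical helper: np.where, with the result viewed as type(template)."""
    result = np.where(condition, a, b)
    # Re-view the result as the template's class, whatever type np.where returned.
    out = np.asarray(result).view(type(template))
    # The view above resets the attribute to None, so copy it explicitly.
    out.custom_attribute = getattr(template, 'custom_attribute', None)
    return out


x = CustomArray([1, 2, 3], 'some_value')
fixed = where_like(x, x <= 1, x + 1, x)
```

Note that this only restores the class and the attribute after the fact; it does not let the subclass take part in how the elements are selected.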

@mhvk
Contributor

mhvk commented Apr 22, 2018

@cournape - the ideal solution would indeed be to implement the three-argument form as a ufunc, since then the new override mechanism (__array_ufunc__) could be used. See #8994. But note that just getting the right output class is not enough: the output class also has to be able to intercept how the elements are set (see original issue #5095).
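To illustrate what the override mechanism gives a subclass for functions that already are ufuncs, here is a minimal sketch. `LoggingArray` is a hypothetical class made up for this example; it records which ufunc was called and re-wraps array results. It deliberately ignores the `out=` keyword, which a real implementation would also have to unwrap.

```python
import numpy as np


class LoggingArray(np.ndarray):
    """Hypothetical subclass that records every ufunc applied to it."""

    calls = []  # names of intercepted ufuncs, shared across instances

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        LoggingArray.calls.append(ufunc.__name__)
        # Strip the subclass from the inputs so the inner call does not recurse.
        # Assumption: no `out=` argument is passed; it is not handled here.
        plain = tuple(
            np.asarray(i) if isinstance(i, LoggingArray) else i
            for i in inputs
        )
        result = getattr(ufunc, method)(*plain, **kwargs)
        # Re-wrap plain array results; scalars and tuples pass through unchanged.
        if isinstance(result, np.ndarray):
            return result.view(LoggingArray)
        return result


a = np.arange(3).view(LoggingArray)
r = np.maximum(a, 1)  # np.maximum is a ufunc, so the call goes through the override
```

Because np.where is not a ufunc, a call such as np.where(cond, a, 0) never reaches `__array_ufunc__`, which is why turning it into a ufunc would open the door to this kind of interception.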

I'm a bit worried that we're now spreading discussion over many different issues, so I think I'll close this one. Please do add anything relevant to #8994.

@mhvk mhvk closed this as completed Apr 22, 2018