Skip to content

ENH: __array_function__ support for np.lib, part 2/2 #12119

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
Oct 23, 2018

Conversation

shoyer
Copy link
Member

@shoyer shoyer commented Oct 8, 2018

xref #12028

np.lib.npyio through np.lib.ufunclike

xref GH12028

np.lib.npyio through np.lib.ufunclike
@shoyer shoyer changed the title ENH: __array_function__ support for np.lib, part 2 ENH: __array_function__ support for np.lib, part 2/2 Oct 8, 2018
@shoyer
Copy link
Member Author

shoyer commented Oct 8, 2018

There's some weird interaction between iscomplexobj and assert_equal (which internally calls iscomplexobj) that is causing the observed test failures, e.g.,


self = <numpy.core.tests.test_overrides.TestArrayFunctionDispatch object at 0x7f8910047ba8>
    def test_interface(self):
    
        class MyArray(object):
            def __array_function__(self, func, types, args, kwargs):
                return (self, func, types, args, kwargs)
    
        original = MyArray()
        (obj, func, types, args, kwargs) = dispatched_one_arg(original)
        assert_(obj is original)
        assert_(func is dispatched_one_arg)
        assert_equal(set(types), {MyArray})
>       assert_equal(args, (original,))
MyArray    = <class 'numpy.core.tests.test_overrides.TestArrayFunctionDispatch.test_interface.<locals>.MyArray'>
args       = (<numpy.core.tests.test_overrides.TestArrayFunctionDispatch.test_interface.<locals>.MyArray object at 0x7f8910047898>,)
func       = <function dispatched_one_arg at 0x7f8913644e18>
kwargs     = {}
obj        = <numpy.core.tests.test_overrides.TestArrayFunctionDispatch.test_interface.<locals>.MyArray object at 0x7f8910047898>
original   = <numpy.core.tests.test_overrides.TestArrayFunctionDispatch.test_interface.<locals>.MyArray object at 0x7f8910047898>
self       = <numpy.core.tests.test_overrides.TestArrayFunctionDispatch object at 0x7f8910047ba8>
types      = [<class 'numpy.core.tests.test_overrides.TestArrayFunctionDispatch.test_interface.<locals>.MyArray'>]
../builds/venv/lib/python3.6/site-packages/numpy/core/tests/test_overrides.py:190: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
../builds/venv/lib/python3.6/site-packages/numpy/core/overrides.py:151: in public_api
    implementation, public_api, relevant_args, args, kwargs)
../builds/venv/lib/python3.6/site-packages/numpy/core/overrides.py:96: in array_function_implementation_or_override
    return implementation(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
x = 5
    @array_function_dispatch(_iscomplexobj_dispatcher)
    def iscomplexobj(x):
        """
        Check for a complex type or an array of complex numbers.
    
        The type of the input is checked, not the value. Even if the input
        has an imaginary part equal to zero, `iscomplexobj` evaluates to True.
    
        Parameters
        ----------
        x : any
            The input can be of any type and shape.
    
        Returns
        -------
        iscomplexobj : bool
            The return value, True if `x` is of a complex type or has at least
            one complex element.
    
        See Also
        --------
        isrealobj, iscomplex
    
        Examples
        --------
        >>> np.iscomplexobj(1)
        False
        >>> np.iscomplexobj(1+0j)
        True
        >>> np.iscomplexobj([3, 1+0j, True])
        True
    
        """
        try:
>           dtype = x.dtype
E           RecursionError: maximum recursion depth exceeded
x          = 5
../builds/venv/lib/python3.6/site-packages/numpy/lib/type_check.py:311: RecursionError

@shoyer
Copy link
Member Author

shoyer commented Oct 8, 2018

The problem seems to be that my dummy __array_function__ implementation returns self for all functions, which then breaks assert_equal when iscomplexobj() is overloaded.

I see three possible ways to fix this:

  1. adjust my __array_function__ implementation in test_overrides.py to special case np.iscomplexobj
  2. use assert_(args == (original,)) instead of assert_equal()
  3. add more fallback logic to assert_equal() to handle cases where iscomplexobj() returns a nonsensical result (i.e., a non-boolean)

@shoyer
Copy link
Member Author

shoyer commented Oct 9, 2018

I decided to go for a mix of 2 and 3:

  • assert_equal should be OK with np.iscomplexobj raising TypeError.
  • for the one case where I define __array_function__ for everything, I use assert_ instead of assert_equal.

Copy link
Contributor

@hameerabbasi hameerabbasi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We perhaps should get rid of the hack sometime, but other than that I'm okay with this.

return (x, out)


@array_function_dispatch(_fix_dispatcher, verify=False)
@_deprecate_out_named_y
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this be added to the dispatcher?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is a better idea. That way the warning will be raised even if the function is dispatching.

@shoyer
Copy link
Member Author

shoyer commented Oct 11, 2018

Note: This one still needs more work before it's ready to merge (to fix the deprecation warnings).

@charris charris changed the title ENH: __array_function__ support for np.lib, part 2/2 WIP, ENH: __array_function__ support for np.lib, part 2/2 Oct 11, 2018
@shoyer
Copy link
Member Author

shoyer commented Oct 12, 2018

OK, inspired by @mhvk's and @hameerabbasi's comments I consolidated a bunch of unnecessarily repeated dispatcher functions here, too.

@shoyer shoyer changed the title WIP, ENH: __array_function__ support for np.lib, part 2/2 ENH: __array_function__ support for np.lib, part 2/2 Oct 12, 2018
@shoyer
Copy link
Member Author

shoyer commented Oct 16, 2018

This is ready for a final review.

Copy link
Contributor

@hameerabbasi hameerabbasi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One minor bug I spotted as I was going through this.

@@ -309,6 +323,12 @@ def log10(x):
x = _fix_real_lt_zero(x)
return nx.log10(x)


def _logn_dispatcher(n, x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should return both arguments -- Can have logs of different bases as well.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes -- I think I missed this because I only read the docstring, which says n is an int.

Copy link
Contributor

@mhvk mhvk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good but some suggestions for using decorators for a group.

Also, more important, the last comment about the change in testing utils.py - that was really tricky to get right; I would suggest not to change it unless truly needed.

@@ -712,7 +759,12 @@ def array_split(ary, indices_or_sections, axis=0):
return sub_arys


def split(ary,indices_or_sections,axis=0):
def _split_dispatcher(ary, indices_or_sections, axis=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here you could use a single _split_dispatcher with the function above as well as below (where the function currently gets redefined, but then it does get reused a few times).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch -- fixed

@@ -373,6 +386,11 @@ def tri(N, M=None, k=0, dtype=float):
return m


def _tril_dispatcher(m, k=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call it triul_dispatcher, since it is used for both?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -812,6 +842,11 @@ def tril_indices(n, k=0, m=None):
return nonzero(tri(n, m, k=k, dtype=bool))


def _tril_indices_form_dispatcher(arr, k=None):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same, either _tri or _triul

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

@@ -183,6 +194,11 @@ def imag(val):
return asanyarray(val).imag


def _iscomplex_dispatcher(x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Next few could be common _is_type_dispatcher?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done!

@@ -713,7 +713,7 @@ def func_assert_same_pos(x, y, func=isnan, hasval='nan'):
# such subclasses, but some used to work.
x_id = func(x)
y_id = func(y)
if npall(x_id == y_id) != True:
if (x_id == y_id).all() != True:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the comment above, I think this should not be changed... Alternatively, at least the comment should be adjusted.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. The problem is that now np.all() may likely not be implemented for an ndarray subclass, if it doesn't define np.all in __array_function__.

I think we can just delete this comment, because it's no longer true that np.all() can work when a subclass implements.all() differently -- we already use getattr() to pull out a all() method to call if one exists.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I worry is that this change is a very recent one, by @charris (#11756), which correctly something that .all() didn't catch. (Sorry, have to run, not sure what the real issue was...)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK, I see, thanks. I think the comment was a little misleading here -- I think the problem was classes that define equality differently (e.g., to return a boolean) rather than classes that don't define an .all() method.

if npall(x_id == y_id) != True:
result = x_id == y_id
all_equal = (result.all()
if isinstance(result, ndarray)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this pass through subclasses? And if so, will that be an issue?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree with the worry of @hameerabbasi - also, at least the comment reads a bit weird now: I don't think we have to worry about classed that define __array_function__ but do not override np.all - they're new and they can override easily as part of their trials; maybe just leave it as it was?

Copy link
Member Author

@shoyer shoyer Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this passes through subclasses -- which is the point.

I don't think we have to worry about classed that define array_function but do not override np.all - they're new and they can override easily as part of their trials; maybe just leave it as it was?

I don't exactly disagree with this, but this did come up with the ndarray subclass I made for unit testing purposes. With the arrival of __array_function__, it will be easier to depend on built-in methods like all() rather than np.all(). I guess we offer no guarantees for subclasses that don't implement the full numpy API, but it still feels like not a great user experience.

Two other ways to do this:

  1. Cast to a NumPy boolean scalar/array, which guarantees that we have an all() method: np.bool_(x_id == y_id).all(). This looks less hacky and would still work for all the identified use-cases.
  2. Give up on duck-typing for assert_array_equal, and instead decorate it with array_function_dispatch for explicit dispatching.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In that case, if this won't cause any problems, then I agree with this design.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Option 1 seems to cover all bases, and is shorter than what you have now.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. Testing out masked arrays, they no longer seem to return masked arrays from .all() or np.all():

In [16]: x = np.ma.array([1, 2, 3], mask=[False, True, False])

In [17]: x == x
Out[17]:
masked_array(data=[True, --, True],
             mask=[False,  True, False],
       fill_value=True)

In [18]: (x == x).all()
Out[18]: True

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should cause a test failure, and if there is no test like that we should add one. Also note there are test failures with datetime64

Copy link
Contributor

@hameerabbasi hameerabbasi Oct 17, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That should cause a test failure, and if there is no test like that we should add one. Also note there are test failures with datetime64

I'm not sure. What would you do with a mask at all on a reduction? I mean, it makes sense to only reduce over the non-masked objects, but what is the output mask? For zero nonmasked elements do you set it to the identity of the ufunc or to False? Or do you just mask everything where any element is masked? Likely not the right answer.

@mhvk
Copy link
Contributor

mhvk commented Oct 17, 2018

Wow, that should cover it! Thanks, @shoyer!

@shoyer
Copy link
Member Author

shoyer commented Oct 19, 2018

OK, I plan to merge this shortly unless I get more feedback.

@shoyer shoyer removed the 25 - WIP label Oct 23, 2018
@shoyer shoyer merged commit 7315145 into numpy:master Oct 23, 2018
@shoyer shoyer deleted the array-function-numpy-lib2 branch October 23, 2018 00:40
@charris charris changed the title ENH: __array_function__ support for np.lib, part 2/2 ENH: __array_function__ support for np.lib, part 2/2 Nov 10, 2018
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants