Skip to content

ENH: implement __skip_array_function__ attribute for NEP-18 #13389

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
May 18, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions numpy/core/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,8 @@ def {name}(*args, **kwargs):
if module is not None:
public_api.__module__ = module

public_api.__skip_array_function__ = implementation

return public_api
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This will not merge cleanly after #13529

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed


return decorator
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/arrayfunction_override.c
Original file line number Diff line number Diff line change
Expand Up @@ -173,7 +173,7 @@ array_function_method_impl(PyObject *func, PyObject *types, PyObject *args,
}
}

implementation = PyObject_GetAttr(func, npy_ma_str_wrapped);
implementation = PyObject_GetAttr(func, npy_ma_str_skip_array_function);
if (implementation == NULL) {
return NULL;
}
Expand Down
7 changes: 4 additions & 3 deletions numpy/core/src/multiarray/multiarraymodule.c
Original file line number Diff line number Diff line change
Expand Up @@ -4498,7 +4498,7 @@ NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_prepare = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_wrap = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_array_finalize = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_ufunc = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_wrapped = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_skip_array_function = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_order = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_copy = NULL;
NPY_VISIBILITY_HIDDEN PyObject * npy_ma_str_dtype = NULL;
Expand All @@ -4514,7 +4514,8 @@ intern_strings(void)
npy_ma_str_array_wrap = PyUString_InternFromString("__array_wrap__");
npy_ma_str_array_finalize = PyUString_InternFromString("__array_finalize__");
npy_ma_str_ufunc = PyUString_InternFromString("__array_ufunc__");
npy_ma_str_wrapped = PyUString_InternFromString("__wrapped__");
npy_ma_str_skip_array_function = PyUString_InternFromString(
"__skip_array_function__");
npy_ma_str_order = PyUString_InternFromString("order");
npy_ma_str_copy = PyUString_InternFromString("copy");
npy_ma_str_dtype = PyUString_InternFromString("dtype");
Expand All @@ -4524,7 +4525,7 @@ intern_strings(void)

return npy_ma_str_array && npy_ma_str_array_prepare &&
npy_ma_str_array_wrap && npy_ma_str_array_finalize &&
npy_ma_str_ufunc && npy_ma_str_wrapped &&
npy_ma_str_ufunc && npy_ma_str_skip_array_function &&
npy_ma_str_order && npy_ma_str_copy && npy_ma_str_dtype &&
npy_ma_str_ndmin && npy_ma_str_axis1 && npy_ma_str_axis2;
}
Expand Down
2 changes: 1 addition & 1 deletion numpy/core/src/multiarray/multiarraymodule.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_prepare;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_wrap;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_array_finalize;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_ufunc;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_wrapped;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_skip_array_function;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_order;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_copy;
NPY_VISIBILITY_HIDDEN extern PyObject * npy_ma_str_dtype;
Expand Down
62 changes: 57 additions & 5 deletions numpy/core/tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import sys
from unittest import mock

import numpy as np
from numpy.testing import (
Expand All @@ -10,7 +11,6 @@
_get_implementing_args, array_function_dispatch,
verify_matching_signatures)
from numpy.compat import pickle
import pytest


def _return_not_implemented(self, *args, **kwargs):
Expand Down Expand Up @@ -190,12 +190,18 @@ class OverrideSub(np.ndarray):
result = np.concatenate((array, override_sub))
assert_equal(result, expected.view(OverrideSub))

def test_skip_array_function(self):
assert_(dispatched_one_arg.__skip_array_function__
is dispatched_one_arg.__wrapped__)

def test_no_wrapper(self):
# This shouldn't happen unless a user intentionally calls
# __array_function__ with invalid arguments, but check that we raise
# an appropriate error all the same.
array = np.array(1)
func = dispatched_one_arg.__wrapped__
with assert_raises_regex(AttributeError, '__wrapped__'):
array.__array_function__(func=func,
types=(np.ndarray,),
func = dispatched_one_arg.__skip_array_function__
with assert_raises_regex(AttributeError, '__skip_array_function__'):
array.__array_function__(func=func, types=(np.ndarray,),
args=(array,), kwargs={})


Expand Down Expand Up @@ -378,3 +384,49 @@ def _(array):
return 'yes'

assert_equal(np.sum(MyArray()), 'yes')

def test_sum_implementation_on_list(self):
assert_equal(np.sum.__skip_array_function__([1, 2, 3]), 6)

def test_sum_on_mock_array(self):

# We need a proxy for mocks because __array_function__ is only looked
# up in the class dict
class ArrayProxy:
def __init__(self, value):
self.value = value
def __array_function__(self, *args, **kwargs):
return self.value.__array_function__(*args, **kwargs)
def __array__(self, *args, **kwargs):
return self.value.__array__(*args, **kwargs)

proxy = ArrayProxy(mock.Mock(spec=ArrayProxy))
proxy.value.__array_function__.return_value = 1
result = np.sum(proxy)
assert_equal(result, 1)
proxy.value.__array_function__.assert_called_once_with(
np.sum, (ArrayProxy,), (proxy,), {})
proxy.value.__array__.assert_not_called()

proxy = ArrayProxy(mock.Mock(spec=ArrayProxy))
proxy.value.__array__.return_value = np.array(2)
result = np.sum.__skip_array_function__(proxy)
assert_equal(result, 2)
# TODO: switch to proxy.value.__array__.assert_called() and
# proxy.value.__array_function__.assert_not_called() once we drop
# Python 3.5 support.
((called_method_name, _, _),) = proxy.value.mock_calls
assert_equal(called_method_name, '__array__')

def test_sum_forwarding_implementation(self):

class MyArray(object):

def sum(self, axis, out):
return 'summed'

def __array_function__(self, func, types, args, kwargs):
return func.__skip_array_function__(*args, **kwargs)

# note: the internal implementation of np.sum() calls the .sum() method
assert_equal(np.sum(MyArray()), 'summed')