diff --git a/numpy/core/overrides.py b/numpy/core/overrides.py index 17e3d475f86e..906292613b2f 100644 --- a/numpy/core/overrides.py +++ b/numpy/core/overrides.py @@ -2,8 +2,11 @@ TODO: rewrite this in C for performance. """ +import collections import functools + from numpy.core.multiarray import ndarray +from numpy.compat._inspect import getargspec _NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__ @@ -107,13 +110,45 @@ def array_function_implementation_or_override( .format(public_api, list(map(type, overloaded_args)))) -def array_function_dispatch(dispatcher): +ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') + + +def verify_matching_signatures(implementation, dispatcher): + """Verify that a dispatcher function has the right signature.""" + implementation_spec = ArgSpec(*getargspec(implementation)) + dispatcher_spec = ArgSpec(*getargspec(dispatcher)) + + if (implementation_spec.args != dispatcher_spec.args or + implementation_spec.varargs != dispatcher_spec.varargs or + implementation_spec.keywords != dispatcher_spec.keywords or + (bool(implementation_spec.defaults) != + bool(dispatcher_spec.defaults)) or + (implementation_spec.defaults is not None and + len(implementation_spec.defaults) != + len(dispatcher_spec.defaults))): + raise RuntimeError('implementation and dispatcher for %s have ' + 'different function signatures' % implementation) + + if implementation_spec.defaults is not None: + if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults): + raise RuntimeError('dispatcher functions can only use None for ' + 'default argument values') + + +def array_function_dispatch(dispatcher, verify=True): """Decorator for adding dispatch with the __array_function__ protocol.""" def decorator(implementation): + # TODO: only do this check when the appropriate flag is enabled or for + # a dev install. We want this check for testing but don't want to + # slow down all numpy imports. + if verify: + verify_matching_signatures(implementation, dispatcher) + @functools.wraps(implementation) def public_api(*args, **kwargs): relevant_args = dispatcher(*args, **kwargs) return array_function_implementation_or_override( implementation, public_api, relevant_args, args, kwargs) return public_api + return decorator diff --git a/numpy/core/tests/test_overrides.py b/numpy/core/tests/test_overrides.py index 7f6157a5bf09..895f221daf0b 100644 --- a/numpy/core/tests/test_overrides.py +++ b/numpy/core/tests/test_overrides.py @@ -7,7 +7,8 @@ from numpy.testing import ( assert_, assert_equal, assert_raises, assert_raises_regex) from numpy.core.overrides import ( - get_overloaded_types_and_args, array_function_dispatch) + get_overloaded_types_and_args, array_function_dispatch, + verify_matching_signatures) def _get_overloaded_args(relevant_args): @@ -200,6 +201,36 @@ def __array_function__(self, func, types, args, kwargs): dispatched_one_arg(array) +class TestVerifyMatchingSignatures(object): + + def test_verify_matching_signatures(self): + + verify_matching_signatures(lambda x: 0, lambda x: 0) + verify_matching_signatures(lambda x=None: 0, lambda x=None: 0) + verify_matching_signatures(lambda x=1: 0, lambda x=None: 0) + + with assert_raises(RuntimeError): + verify_matching_signatures(lambda a: 0, lambda b: 0) + with assert_raises(RuntimeError): + verify_matching_signatures(lambda x: 0, lambda x=None: 0) + with assert_raises(RuntimeError): + verify_matching_signatures(lambda x=None: 0, lambda y=None: 0) + with assert_raises(RuntimeError): + verify_matching_signatures(lambda x=1: 0, lambda y=1: 0) + + def test_array_function_dispatch(self): + + with assert_raises(RuntimeError): + @array_function_dispatch(lambda x: (x,)) + def f(y): + pass + + # should not raise + @array_function_dispatch(lambda x: (x,), verify=False) + def f(y): + pass + + def _new_duck_type_and_implements(): """Create a duck array type and implements functions.""" HANDLED_FUNCTIONS = {}