Skip to content

ENH: Validate dispatcher functions in array_function_dispatch #12099

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Oct 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 36 additions & 1 deletion numpy/core/overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@

TODO: rewrite this in C for performance.
"""
import collections
import functools

from numpy.core.multiarray import ndarray
from numpy.compat._inspect import getargspec


_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__
Expand Down Expand Up @@ -107,13 +110,45 @@ def array_function_implementation_or_override(
.format(public_api, list(map(type, overloaded_args))))


def array_function_dispatch(dispatcher):
ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults')


def verify_matching_signatures(implementation, dispatcher):
"""Verify that a dispatcher function has the right signature."""
implementation_spec = ArgSpec(*getargspec(implementation))
dispatcher_spec = ArgSpec(*getargspec(dispatcher))

if (implementation_spec.args != dispatcher_spec.args or
implementation_spec.varargs != dispatcher_spec.varargs or
implementation_spec.keywords != dispatcher_spec.keywords or
(bool(implementation_spec.defaults) !=
bool(dispatcher_spec.defaults)) or
(implementation_spec.defaults is not None and
len(implementation_spec.defaults) !=
len(dispatcher_spec.defaults))):
raise RuntimeError('implementation and dispatcher for %s have '
'different function signatures' % implementation)

if implementation_spec.defaults is not None:
if dispatcher_spec.defaults != (None,) * len(dispatcher_spec.defaults):
raise RuntimeError('dispatcher functions can only use None for '
'default argument values')


def array_function_dispatch(dispatcher, verify=True):
"""Decorator for adding dispatch with the __array_function__ protocol."""
def decorator(implementation):
# TODO: only do this check when the appropriate flag is enabled or for
# a dev install. We want this check for testing but don't want to
# slow down all numpy imports.
if verify:
verify_matching_signatures(implementation, dispatcher)

@functools.wraps(implementation)
def public_api(*args, **kwargs):
relevant_args = dispatcher(*args, **kwargs)
return array_function_implementation_or_override(
implementation, public_api, relevant_args, args, kwargs)
return public_api

return decorator
33 changes: 32 additions & 1 deletion numpy/core/tests/test_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from numpy.testing import (
assert_, assert_equal, assert_raises, assert_raises_regex)
from numpy.core.overrides import (
get_overloaded_types_and_args, array_function_dispatch)
get_overloaded_types_and_args, array_function_dispatch,
verify_matching_signatures)


def _get_overloaded_args(relevant_args):
Expand Down Expand Up @@ -200,6 +201,36 @@ def __array_function__(self, func, types, args, kwargs):
dispatched_one_arg(array)


class TestVerifyMatchingSignatures(object):

def test_verify_matching_signatures(self):

verify_matching_signatures(lambda x: 0, lambda x: 0)
verify_matching_signatures(lambda x=None: 0, lambda x=None: 0)
verify_matching_signatures(lambda x=1: 0, lambda x=None: 0)

with assert_raises(RuntimeError):
verify_matching_signatures(lambda a: 0, lambda b: 0)
with assert_raises(RuntimeError):
verify_matching_signatures(lambda x: 0, lambda x=None: 0)
with assert_raises(RuntimeError):
verify_matching_signatures(lambda x=None: 0, lambda y=None: 0)
with assert_raises(RuntimeError):
verify_matching_signatures(lambda x=1: 0, lambda y=1: 0)

def test_array_function_dispatch(self):

with assert_raises(RuntimeError):
@array_function_dispatch(lambda x: (x,))
def f(y):
pass

# should not raise
@array_function_dispatch(lambda x: (x,), verify=False)
def f(y):
pass


def _new_duck_type_and_implements():
"""Create a duck array type and implements functions."""
HANDLED_FUNCTIONS = {}
Expand Down