Skip to content

ENH: initial implementation of core __array_function__ machinery #12005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Sep 24, 2018
Merged
61 changes: 61 additions & 0 deletions benchmarks/benchmarks/bench_overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from __future__ import absolute_import, division, print_function

from .common import Benchmark

from numpy.core.overrides import array_function_dispatch
import numpy as np


def _broadcast_to_dispatcher(array, shape, subok=None):
return (array,)


@array_function_dispatch(_broadcast_to_dispatcher)
def mock_broadcast_to(array, shape, subok=False):
pass


def _concatenate_dispatcher(arrays, axis=None, out=None):
for array in arrays:
yield array
if out is not None:
yield out


@array_function_dispatch(_concatenate_dispatcher)
def mock_concatenate(arrays, axis=0, out=None):
pass


class DuckArray(object):
def __array_function__(self, func, types, args, kwargs):
pass


class ArrayFunction(Benchmark):

def setup(self):
self.numpy_array = np.array(1)
self.numpy_arrays = [np.array(1), np.array(2)]
self.many_arrays = 500 * self.numpy_arrays
self.duck_array = DuckArray()
self.duck_arrays = [DuckArray(), DuckArray()]
self.mixed_arrays = [np.array(1), DuckArray()]

def time_mock_broadcast_to_numpy(self):
mock_broadcast_to(self.numpy_array, ())

def time_mock_broadcast_to_duck(self):
mock_broadcast_to(self.duck_array, ())

def time_mock_concatenate_numpy(self):
mock_concatenate(self.numpy_arrays, axis=0)

def time_mock_concatenate_many(self):
mock_concatenate(self.many_arrays, axis=0)

def time_mock_concatenate_duck(self):
mock_concatenate(self.duck_arrays, axis=0)

def time_mock_concatenate_mixed(self):
mock_concatenate(self.mixed_arrays, axis=0)
15 changes: 15 additions & 0 deletions numpy/core/_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,3 +154,18 @@ def _ptp(a, axis=None, out=None, keepdims=False):
umr_minimum(a, axis, None, None, keepdims),
out
)

_NDARRAY_ARRAY_FUNCTION = mu.ndarray.__array_function__

def _array_function(self, func, types, args, kwargs):
# TODO: rewrite this in C
# Cannot handle items that have __array_function__ other than our own.
for t in types:
if t is not mu.ndarray:
method = getattr(t, '__array_function__', _NDARRAY_ARRAY_FUNCTION)
if method is not _NDARRAY_ARRAY_FUNCTION:
return NotImplemented

# Arguments contain no overrides, so we can safely call the
# overloaded function again.
return func(*args, **kwargs)
86 changes: 86 additions & 0 deletions numpy/core/overrides.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Preliminary implementation of NEP-18

TODO: rewrite this in C for performance.
"""
import functools
from numpy.core.multiarray import ndarray


_NDARRAY_ARRAY_FUNCTION = ndarray.__array_function__


def get_overloaded_types_and_args(relevant_args):
"""Returns a list of arguments on which to call __array_function__.

__array_function__ implementations should be called in order on the return
values from this function.
"""
# Runtime is O(num_arguments * num_unique_types)
overloaded_types = []
overloaded_args = []
for arg in relevant_args:
arg_type = type(arg)
if (arg_type not in overloaded_types and
hasattr(arg_type, '__array_function__')):

overloaded_types.append(arg_type)

# By default, insert this argument at the end, but if it is
# subclass of another argument, insert it before that argument.
# This ensures "subclasses before superclasses".
index = len(overloaded_args)
for i, old_arg in enumerate(overloaded_args):
if issubclass(arg_type, type(old_arg)):
index = i
break
overloaded_args.insert(index, arg)

# Special handling for ndarray.
overloaded_args = [
arg for arg in overloaded_args
if type(arg).__array_function__ is not _NDARRAY_ARRAY_FUNCTION
]

return overloaded_types, overloaded_args


def array_function_override(overloaded_args, func, types, args, kwargs):
"""Call __array_function__ implementations."""
for overloaded_arg in overloaded_args:
# Note that we're only calling __array_function__ on the *first*
# occurence of each argument type. This is necessary for reasonable
# performance with a possibly long list of overloaded arguments, for
# which each __array_function__ implementation might reasonably need to
# check all argument types.
result = overloaded_arg.__array_function__(func, types, args, kwargs)

if result is not NotImplemented:
return result

raise TypeError('no implementation found for {} on types that implement '
'__array_function__: {}'
.format(func, list(map(type, overloaded_args))))


def array_function_dispatch(dispatcher):
"""Wrap a function for dispatch with the __array_function__ protocol."""
def decorator(func):
@functools.wraps(func)
def new_func(*args, **kwargs):
# Collect array-like arguments.
relevant_arguments = dispatcher(*args, **kwargs)
# Check for __array_function__ methods.
types, overloaded_args = get_overloaded_types_and_args(
relevant_arguments)
# Call overrides, if necessary.
if overloaded_args:
# new_func is the function exposed in NumPy's public API. We
# use it instead of func so __array_function__ implementations
# can do equality/identity comparisons.
return array_function_override(
overloaded_args, new_func, types, args, kwargs)
else:
return func(*args, **kwargs)

return new_func
return decorator
10 changes: 10 additions & 0 deletions numpy/core/src/multiarray/methods.c
Original file line number Diff line number Diff line change
Expand Up @@ -1021,6 +1021,13 @@ array_ufunc(PyArrayObject *self, PyObject *args, PyObject *kwds)
}


static PyObject *
array_function(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
NPY_FORWARD_NDARRAY_METHOD("_array_function");
}


static PyObject *
array_copy(PyArrayObject *self, PyObject *args, PyObject *kwds)
{
Expand Down Expand Up @@ -2472,6 +2479,9 @@ NPY_NO_EXPORT PyMethodDef array_methods[] = {
{"__array_ufunc__",
(PyCFunction)array_ufunc,
METH_VARARGS | METH_KEYWORDS, NULL},
{"__array_function__",
(PyCFunction)array_function,
METH_VARARGS | METH_KEYWORDS, NULL},

#ifndef NPY_PY3K
{"__unicode__",
Expand Down
Loading