Skip to content

Specify/infer arguments for testing uninspectable signatures #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions array_api_tests/dtype_helpers.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import re
from collections import defaultdict
from collections.abc import Mapping
from functools import lru_cache
from inspect import signature
from typing import Any, Dict, NamedTuple, Sequence, Tuple, Union
from typing import Any, DefaultDict, NamedTuple, Sequence, Tuple, Union
from warnings import warn

from . import _array_module as xp
Expand Down Expand Up @@ -323,16 +323,14 @@ def result_type(*dtypes: DataType):
"numeric": numeric_dtypes,
"integer or boolean": bool_and_all_int_dtypes,
}
func_in_dtypes: Dict[str, Tuple[DataType, ...]] = {}
func_in_dtypes: DefaultDict[str, Tuple[DataType, ...]] = defaultdict(lambda: all_dtypes)
for name, func in name_to_func.items():
if m := r_in_dtypes.search(func.__doc__):
dtype_category = m.group(1)
if dtype_category == "numeric" and r_int_note.search(func.__doc__):
dtype_category = "floating-point"
dtypes = category_to_dtypes[dtype_category]
func_in_dtypes[name] = dtypes
elif any("x" in name for name in signature(func).parameters.keys()):
func_in_dtypes[name] = all_dtypes
# See https://github.com/data-apis/array-api/pull/413
func_in_dtypes["expm1"] = float_dtypes

Expand Down
245 changes: 151 additions & 94 deletions array_api_tests/test_signatures.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,18 @@ def squeeze(x, /, axis):
...

"""
from collections import defaultdict
from copy import copy
from inspect import Parameter, Signature, signature
from types import FunctionType
from typing import Any, Callable, Dict, List, Literal, get_args
from typing import Any, Callable, Dict, Literal, get_args
from warnings import warn

import pytest
from hypothesis import given, note, settings
from hypothesis import strategies as st
from hypothesis.strategies import DataObject

from . import dtype_helpers as dh
from . import hypothesis_helpers as hh
from . import xps
from ._array_module import _UndefinedStub
from ._array_module import mod as xp
from .stubs import array_methods, category_to_funcs, extension_to_funcs
from .typing import Array, DataType
from .stubs import array_methods, category_to_funcs, extension_to_funcs, name_to_func

pytestmark = pytest.mark.ci

Expand Down Expand Up @@ -93,24 +89,15 @@ def _test_inspectable_func(sig: Signature, stub_sig: Signature):
stub_param.name in sig.parameters.keys()
), f"Argument '{stub_param.name}' missing from signature"
param = next(p for p in params if p.name == stub_param.name)
f_stub_kind = kind_to_str[stub_param.kind]
assert param.kind in [stub_param.kind, Parameter.POSITIONAL_OR_KEYWORD,], (
f"{param.name} is a {kind_to_str[param.kind]}, "
f"but should be a {f_stub_kind} "
f"(or at least a {kind_to_str[ParameterKind.POSITIONAL_OR_KEYWORD]})"
)


def get_dtypes_strategy(func_name: str) -> st.SearchStrategy[DataType]:
if func_name in dh.func_in_dtypes.keys():
dtypes = dh.func_in_dtypes[func_name]
if hh.FILTER_UNDEFINED_DTYPES:
dtypes = [d for d in dtypes if not isinstance(d, _UndefinedStub)]
return st.sampled_from(dtypes)
else:
return xps.scalar_dtypes()


def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
def make_pretty_func(func_name: str, *args: Any, **kwargs: Any) -> str:
f_sig = f"{func_name}("
f_sig += ", ".join(str(a) for a in args)
if len(kwargs) != 0:
Expand All @@ -121,96 +108,165 @@ def make_pretty_func(func_name: str, *args: Any, **kwargs: Any):
return f_sig


matrixy_funcs: List[FunctionType] = [
*category_to_funcs["linear_algebra"],
*extension_to_funcs["linalg"],
# We test uninspectable signatures by passing valid, manually-defined arguments
# to the signature's function/method.
#
# Arguments which require use of the array module are specified as string
# expressions to be eval()'d on runtime. This is as opposed to just using the
# array module whilst setting up the tests, which is prone to halt the entire
# test suite if an array module doesn't support a given expression.
func_to_specified_args = defaultdict(
dict,
{
"permute_dims": {"axes": 0},
"reshape": {"shape": (1, 5)},
"broadcast_to": {"shape": (1, 5)},
"asarray": {"obj": [0, 1, 2, 3, 4]},
"full_like": {"fill_value": 42},
"matrix_power": {"n": 2},
},
)
func_to_specified_arg_exprs = defaultdict(
dict,
{
"stack": {"arrays": "[xp.ones((5,)), xp.ones((5,))]"},
"iinfo": {"type": "xp.int64"},
"finfo": {"type": "xp.float64"},
"cholesky": {"x": "xp.asarray([[1, 0], [0, 1]], dtype=xp.float64)"},
"inv": {"x": "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)"},
"solve": {
a: "xp.asarray([[1, 2], [3, 4]], dtype=xp.float64)" for a in ["x1", "x2"]
},
},
)
# We default most array arguments heuristically. As functions/methods work only
# with arrays of certain dtypes and shapes, we specify only supported arrays
# respective to the function.
casty_names = ["__bool__", "__int__", "__float__", "__complex__", "__index__"]
matrixy_names = [
f.__name__
for f in category_to_funcs["linear_algebra"] + extension_to_funcs["linalg"]
]
matrixy_names: List[str] = [f.__name__ for f in matrixy_funcs]
matrixy_names += ["__matmul__", "triu", "tril"]
for func_name, func in name_to_func.items():
stub_sig = signature(func)
array_argnames = set(stub_sig.parameters.keys()) & {"x", "x1", "x2", "other"}
if func in array_methods:
array_argnames.add("self")
array_argnames -= set(func_to_specified_arg_exprs[func_name].keys())
if len(array_argnames) > 0:
in_dtypes = dh.func_in_dtypes[func_name]
for dtype_name in ["float64", "bool", "int64", "complex128"]:
# We try float64 first because uninspectable numerical functions
# tend to support float inputs first-and-foremost (i.e. PyTorch)
try:
dtype = getattr(xp, dtype_name)
except AttributeError:
pass
else:
if dtype in in_dtypes:
if func_name in casty_names:
shape = ()
elif func_name in matrixy_names:
shape = (3, 3)
else:
shape = (5,)
fallback_array_expr = f"xp.ones({shape}, dtype=xp.{dtype_name})"
break
else:
warn(
f"{dh.func_in_dtypes['{func_name}']}={in_dtypes} seemingly does "
"not contain any assumed dtypes, so skipping specifying fallback array."
)
continue
for argname in array_argnames:
func_to_specified_arg_exprs[func_name][argname] = fallback_array_expr


def _test_uninspectable_func(func_name: str, func: Callable, stub_sig: Signature):
params = list(stub_sig.parameters.values())

@given(data=st.data())
@settings(max_examples=1)
def _test_uninspectable_func(
func_name: str, func: Callable, stub_sig: Signature, array: Array, data: DataObject
):
skip_msg = (
f"Signature for {func_name}() is not inspectable "
"and is too troublesome to test for otherwise"
if len(params) == 0:
func()
return

uninspectable_msg = (
f"Note {func_name}() is not inspectable so arguments are passed "
"manually to test the signature."
)
if func_name in [
# 0d shapes
"__bool__",
"__int__",
"__index__",
"__float__",
# x2 elements must be >=0
"pow",
"bitwise_left_shift",
"bitwise_right_shift",
# axis default invalid with 0d shapes
"sort",
# shape requirements
*matrixy_names,
]:
pytest.skip(skip_msg)

param_to_value: Dict[Parameter, Any] = {}
for param in stub_sig.parameters.values():
if param.kind in [Parameter.POSITIONAL_OR_KEYWORD, *VAR_KINDS]:

argname_to_arg = copy(func_to_specified_args[func_name])
argname_to_expr = func_to_specified_arg_exprs[func_name]
for argname, expr in argname_to_expr.items():
assert argname not in argname_to_arg.keys() # sanity check
try:
argname_to_arg[argname] = eval(expr, {"xp": xp})
except Exception as e:
pytest.skip(
skip_msg + f" (because '{param.name}' is a {kind_to_str[param.kind]})"
)
elif param.default != Parameter.empty:
value = param.default
elif param.name in ["x", "x1"]:
dtypes = get_dtypes_strategy(func_name)
value = data.draw(
xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label=param.name
f"Exception occured when evaluating {argname}={expr}: {e}\n"
f"{uninspectable_msg}"
)
elif param.name in ["x2", "other"]:
if param.name == "x2":
assert "x1" in [p.name for p in param_to_value.keys()] # sanity check
orig = next(v for p, v in param_to_value.items() if p.name == "x1")

posargs = []
posorkw_args = {}
kwargs = {}
no_arg_msg = (
"We have no argument specified for '{}'. Please ensure you're using "
"the latest version of array-api-tests, then open an issue if one "
f"doesn't already exist. {uninspectable_msg}"
)
for param in params:
if param.kind == Parameter.POSITIONAL_ONLY:
try:
posargs.append(argname_to_arg[param.name])
except KeyError:
pytest.skip(no_arg_msg.format(param.name))
elif param.kind == Parameter.POSITIONAL_OR_KEYWORD:
if param.default == Parameter.empty:
try:
posorkw_args[param.name] = argname_to_arg[param.name]
except KeyError:
pytest.skip(no_arg_msg.format(param.name))
else:
assert array is not None # sanity check
orig = array
value = data.draw(
xps.arrays(dtype=orig.dtype, shape=orig.shape), label=param.name
)
assert argname_to_arg[param.name]
posorkw_args[param.name] = param.default
elif param.kind == Parameter.KEYWORD_ONLY:
assert param.default != Parameter.empty # sanity check
kwargs[param.name] = param.default
else:
pytest.skip(
skip_msg + f" (because no default was found for argument {param.name})"
)
param_to_value[param] = value

args: List[Any] = [
v for p, v in param_to_value.items() if p.kind == Parameter.POSITIONAL_ONLY
]
kwargs: Dict[str, Any] = {
p.name: v for p, v in param_to_value.items() if p.kind == Parameter.KEYWORD_ONLY
}
f_func = make_pretty_func(func_name, *args, **kwargs)
note(f"trying {f_func}")
func(*args, **kwargs)
assert param.kind in VAR_KINDS # sanity check
pytest.skip(no_arg_msg.format(param.name))
if len(posorkw_args) == 0:
func(*posargs, **kwargs)
else:
posorkw_name_to_arg_pairs = list(posorkw_args.items())
for i in range(len(posorkw_name_to_arg_pairs), -1, -1):
extra_posargs = [arg for _, arg in posorkw_name_to_arg_pairs[:i]]
extra_kwargs = dict(posorkw_name_to_arg_pairs[i:])
func(*posargs, *extra_posargs, **kwargs, **extra_kwargs)


def _test_func_signature(func: Callable, stub: FunctionType, array=None):
def _test_func_signature(func: Callable, stub: FunctionType, is_method=False):
stub_sig = signature(stub)
# If testing against array, ignore 'self' arg in stub as it won't be present
# in func (which should be a method).
if array is not None:
if is_method:
stub_params = list(stub_sig.parameters.values())
del stub_params[0]
if stub_params[0].name == "self":
del stub_params[0]
stub_sig = Signature(
parameters=stub_params, return_annotation=stub_sig.return_annotation
)

try:
sig = signature(func)
_test_inspectable_func(sig, stub_sig)
except ValueError:
_test_uninspectable_func(stub.__name__, func, stub_sig, array)
try:
_test_uninspectable_func(stub.__name__, func, stub_sig)
except Exception as e:
raise e from None # suppress parent exception for cleaner pytest output
else:
_test_inspectable_func(sig, stub_sig)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -244,11 +300,12 @@ def test_extension_func_signature(extension: str, stub: FunctionType):


@pytest.mark.parametrize("stub", array_methods, ids=lambda f: f.__name__)
@given(st.data())
@settings(max_examples=1)
def test_array_method_signature(stub: FunctionType, data: DataObject):
dtypes = get_dtypes_strategy(stub.__name__)
x = data.draw(xps.arrays(dtype=dtypes, shape=hh.shapes(min_side=1)), label="x")
def test_array_method_signature(stub: FunctionType):
x_expr = func_to_specified_arg_exprs[stub.__name__]["self"]
try:
x = eval(x_expr, {"xp": xp})
except Exception as e:
pytest.skip(f"Exception occured when evaluating x={x_expr}: {e}")
assert hasattr(x, stub.__name__), f"{stub.__name__} not found in array object {x!r}"
method = getattr(x, stub.__name__)
_test_func_signature(method, stub, array=x)
_test_func_signature(method, stub, is_method=True)