
Test special cases in statistical functions #119


Merged
merged 10 commits on May 16, 2022
2 changes: 1 addition & 1 deletion array_api_tests/pytest_helpers.py
@@ -198,7 +198,7 @@ def assert_0d_equals(
func_name: str, x_repr: str, x_val: Array, out_repr: str, out_val: Array, **kw
):
msg = (
f"{out_repr}={out_val}, should be {x_repr}={x_val} "
f"{out_repr}={out_val}, but should be {x_repr}={x_val} "
f"[{func_name}({fmt_kw(kw)})]"
)
if dh.is_float_dtype(out_val.dtype) and xp.isnan(out_val):
217 changes: 123 additions & 94 deletions array_api_tests/test_special_cases.py
@@ -1,3 +1,13 @@
"""
Tests for special cases.

Most test cases for special casing are built at runtime via the parametrized
tests test_unary/test_binary/test_iop. Most of this file consists of utility
classes and functions, all brought together to create the test cases (pytest
params), to finally be run through generalised test logic.

TODO: test integer arrays for relevant special cases
"""
# We use __future__ for forward reference type hints - this will work even for py3.8.0
# See https://stackoverflow.com/a/33533514/5193926
from __future__ import annotations
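# Illustrative sketch (not part of this diff) of the flow the module docstring
# describes: each special-case bullet parsed from a stub's docstring ends up as
# one pytest param, e.g.
#
#     pytest.param("sqrt", xp.sqrt, case, id="sqrt(x_i < 0) -> NaN")
#
# which the parametrized test_unary/test_binary/test_iop functions then consume,
# generating arrays that satisfy case.cond and checking case.check_result.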
@@ -32,13 +42,6 @@

pytestmark = pytest.mark.ci

# The special case test casess are built on runtime via the parametrized
# test_unary and test_binary functions. Most of this file consists of utility
# classes and functions, all bought together to create the test cases (pytest
# params), to finally be run through the general test logic of either test_unary
# or test_binary.


UnaryCheck = Callable[[float], bool]
BinaryCheck = Callable[[float, float], bool]

@@ -170,24 +173,6 @@ def parse_value(value_str: str) -> float:
r_approx_value = re.compile(
rf"an implementation-dependent approximation to {r_code.pattern}"
)


def parse_inline_code(inline_code: str) -> float:
"""
Parses a Sphinx code string to return a float, e.g.

>>> parse_value('``0``')
0.
>>> parse_value('``NaN``')
float('nan')

"""
if m := r_code.match(inline_code):
return parse_value(m.group(1))
else:
raise ParseError(inline_code)


r_not = re.compile("not (.+)")
r_equal_to = re.compile(f"equal to {r_code.pattern}")
r_array_element = re.compile(r"``([+-]?)x([12])_i``")
@@ -526,6 +511,10 @@ def __repr__(self) -> str:
return f"{self.__class__.__name__}(<{self}>)"


r_case_block = re.compile(r"\*\*Special [Cc]ases\*\*\n+((?:(.*\n)+))\n+\s*Parameters")
r_case = re.compile(r"\s+-\s*(.*)\.")
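# Minimal illustration (hypothetical, not part of this diff) of how the two
# regexes above compose: r_case_block captures everything between the
# "**Special cases**" header and the "Parameters" section, and r_case then
# pulls out each bullet.
#
#     _sample = (
#         "    **Special cases**\n"
#         "\n"
#         "    - If ``x_i`` is ``NaN``, the result is ``NaN``.\n"
#         "\n"
#         "    Parameters\n"
#     )
#     _block = r_case_block.search(_sample).group(1)
#     [m.group(1) for m in r_case.finditer(_block)]
#     # -> ['If ``x_i`` is ``NaN``, the result is ``NaN``']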


class UnaryCond(Protocol):
def __call__(self, i: float) -> bool:
...
@@ -546,12 +535,34 @@ class UnaryCase(Case):


r_unary_case = re.compile("If ``x_i`` is (.+), the result is (.+)")
r_already_int_case = re.compile(
"If ``x_i`` is already integer-valued, the result is ``x_i``"
)
r_even_round_halves_case = re.compile(
"If two integers are equally close to ``x_i``, "
"the result is the even integer closest to ``x_i``"
)


def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
"""
Returns a strategy that generates float-casted integers within the bounds of dtype.
"""
for k in kw.keys():
# sanity check
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
m, M = dh.dtype_ranges[dtype]
if "min_value" in kw.keys():
m = kw["min_value"]
if "exclude_min" in kw.keys():
m += 1
if "max_value" in kw.keys():
M = kw["max_value"]
if "exclude_max" in kw.keys():
M -= 1
return st.integers(math.ceil(m), math.floor(M)).map(float)
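# Hypothetical usage (illustrative, not part of this diff): the kwargs further
# restrict the dtype's own range from dh.dtype_ranges, e.g.
#
#     integers_from_dtype(xp.float32, min_value=0, exclude_min=True)
#
# generates 1.0, 2.0, ... up to the upper bound recorded for float32.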


def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
"""
Returns a strategy that generates floats that end with .5 and are within the
@@ -568,6 +579,13 @@ def trailing_halves_from_dtype(dtype: DataType) -> st.SearchStrategy[float]:
)


already_int_case = UnaryCase(
cond_expr="x_i.is_integer()",
cond=lambda i: i.is_integer(),
cond_from_dtype=integers_from_dtype,
result_expr="x_i",
check_result=lambda i, result: i == result,
)
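# already_int_case corresponds to bullets matched by r_already_int_case above
# (e.g. in the rounding functions); drawing inputs via integers_from_dtype
# guarantees the "already integer-valued" condition holds.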
even_round_halves_case = UnaryCase(
cond_expr="modf(i)[0] == 0.5",
cond=lambda i: math.modf(i)[0] == 0.5,
@@ -586,7 +604,7 @@ def check_result(i: float, result: float) -> bool:
return check_result


def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
def parse_unary_case_block(case_block: str) -> List[UnaryCase]:
"""
Parses a Sphinx-formatted docstring of a unary function to return a list of
codified unary cases, e.g.
@@ -616,7 +634,8 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
... an array containing the square root of each element in ``x``
... '''
...
>>> unary_cases = parse_unary_docstring(sqrt.__doc__)
>>> case_block = r_case_block.search(sqrt.__doc__).group(1)
>>> unary_cases = parse_unary_case_block(case_block)
>>> for case in unary_cases:
... print(repr(case))
UnaryCase(<x_i < 0 -> NaN>)
@@ -631,19 +650,14 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
True

"""

match = r_special_cases.search(docstring)
if match is None:
return []
lines = match.group(1).split("\n")[:-1]
cases = []
for line in lines:
if m := r_case.match(line):
case = m.group(1)
else:
warn(f"line not machine-readable: '{line}'")
continue
if m := r_unary_case.search(case):
for case_m in r_case.finditer(case_block):
case_str = case_m.group(1)
if m := r_already_int_case.search(case_str):
cases.append(already_int_case)
elif m := r_even_round_halves_case.search(case_str):
cases.append(even_round_halves_case)
elif m := r_unary_case.search(case_str):
try:
cond, cond_expr_template, cond_from_dtype = parse_cond(m.group(1))
_check_result, result_expr = parse_result(m.group(2))
@@ -662,11 +676,9 @@ def parse_unary_docstring(docstring: str) -> List[UnaryCase]:
check_result=check_result,
)
cases.append(case)
elif m := r_even_round_halves_case.search(case):
cases.append(even_round_halves_case)
else:
if not r_remaining_case.search(case):
warn(f"case not machine-readable: '{case}'")
if not r_remaining_case.search(case_str):
warn(f"case not machine-readable: '{case_str}'")
return cases


@@ -690,12 +702,6 @@ class BinaryCase(Case):
check_result: BinaryResultCheck


r_special_cases = re.compile(
r"\*\*Special [Cc]ases\*\*(?:\n.*)+"
r"For floating-point operands,\n+"
r"((?:\s*-\s*.*\n)+)"
)
r_case = re.compile(r"\s+-\s*(.*)\.\n?")
r_binary_case = re.compile("If (.+), the result (.+)")
r_remaining_case = re.compile("In the remaining cases.+")
r_cond_sep = re.compile(r"(?<!``x1_i``),? and |(?<!i\.e\.), ")
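# Illustrative note (examples not from this diff): r_cond_sep splits compound
# conditions such as "``x1_i`` is ``+infinity`` and ``x2_i`` is ``NaN``" into
# their sub-conditions, while the negative lookbehinds keep phrases like
# "``x1_i`` and ``x2_i`` have the same sign" and "i.e., ..." in one piece.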
@@ -843,25 +849,6 @@ def check_result(i1: float, i2: float, result: float) -> bool:
return check_result


def integers_from_dtype(dtype: DataType, **kw) -> st.SearchStrategy[float]:
"""
Returns a strategy that generates float-casted integers within the bounds of dtype.
"""
for k in kw.keys():
# sanity check
assert k in ["min_value", "max_value", "exclude_min", "exclude_max"]
m, M = dh.dtype_ranges[dtype]
if "min_value" in kw.keys():
m = kw["min_value"]
if "exclude_min" in kw.keys():
m += 1
if "max_value" in kw.keys():
M = kw["max_value"]
if "exclude_max" in kw.keys():
M -= 1
return st.integers(math.ceil(m), math.floor(M)).map(float)


def parse_binary_case(case_str: str) -> BinaryCase:
"""
Parses a Sphinx-formatted binary case string to return codified binary cases, e.g.
@@ -880,8 +867,7 @@ def parse_binary_case(case_str: str) -> BinaryCase:

"""
case_m = r_binary_case.match(case_str)
if case_m is None:
raise ParseError(case_str)
assert case_m is not None # sanity check
cond_strs = r_cond_sep.split(case_m.group(1))

partial_conds = []
@@ -1078,7 +1064,7 @@ def cond(i1: float, i2: float) -> bool:
r_redundant_case = re.compile("result.+determined by the rule already stated above")


def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
def parse_binary_case_block(case_block: str) -> List[BinaryCase]:
"""
Parses a Sphinx-formatted docstring of a binary function to return a list of
codified binary cases, e.g.
@@ -1108,29 +1094,21 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
... an array containing the results
... '''
...
>>> binary_cases = parse_binary_docstring(logaddexp.__doc__)
>>> case_block = r_case_block.search(logaddexp.__doc__).group(1)
>>> binary_cases = parse_binary_case_block(case_block)
>>> for case in binary_cases:
... print(repr(case))
BinaryCase(<x1_i == NaN or x2_i == NaN -> NaN>)
BinaryCase(<x1_i == +infinity and not x2_i == NaN -> +infinity>)
BinaryCase(<not x1_i == NaN and x2_i == +infinity -> +infinity>)

"""

match = r_special_cases.search(docstring)
if match is None:
return []
lines = match.group(1).split("\n")[:-1]
cases = []
for line in lines:
if m := r_case.match(line):
case_str = m.group(1)
else:
warn(f"line not machine-readable: '{line}'")
continue
for case_m in r_case.finditer(case_block):
case_str = case_m.group(1)
if r_redundant_case.search(case_str):
continue
if m := r_binary_case.match(case_str):
if r_binary_case.match(case_str):
try:
case = parse_binary_case(case_str)
cases.append(case)
@@ -1150,6 +1128,10 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
if stub.__doc__ is None:
warn(f"{stub.__name__}() stub has no docstring")
continue
if m := r_case_block.search(stub.__doc__):
case_block = m.group(1)
else:
continue
marks = []
try:
func = getattr(xp, stub.__name__)
@@ -1164,40 +1146,44 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
warn(f"{func=} has no parameters")
continue
if param_names[0] == "x":
if cases := parse_unary_docstring(stub.__doc__):
func_name_to_func = {stub.__name__: func}
if cases := parse_unary_case_block(case_block):
name_to_func = {stub.__name__: func}
if stub.__name__ in func_to_op.keys():
op_name = func_to_op[stub.__name__]
op = getattr(operator, op_name)
func_name_to_func[op_name] = op
for func_name, func in func_name_to_func.items():
name_to_func[op_name] = op
for func_name, func in name_to_func.items():
for case in cases:
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
p = pytest.param(func_name, func, case, id=id_)
unary_params.append(p)
else:
warn(f"Special cases found for {stub.__name__} but none were parsed")
continue
if len(sig.parameters) == 1:
warn(f"{func=} has one parameter '{param_names[0]}' which is not named 'x'")
continue
if param_names[0] == "x1" and param_names[1] == "x2":
if cases := parse_binary_docstring(stub.__doc__):
func_name_to_func = {stub.__name__: func}
if cases := parse_binary_case_block(case_block):
name_to_func = {stub.__name__: func}
if stub.__name__ in func_to_op.keys():
op_name = func_to_op[stub.__name__]
op = getattr(operator, op_name)
func_name_to_func[op_name] = op
# We collect inplaceoperator test cases seperately
name_to_func[op_name] = op
# We collect inplace operator test cases separately
iop_name = "__i" + op_name[2:]
iop = getattr(operator, iop_name)
for case in cases:
id_ = f"{iop_name}({case.cond_expr}) -> {case.result_expr}"
p = pytest.param(iop_name, iop, case, id=id_)
iop_params.append(p)
for func_name, func in func_name_to_func.items():
for func_name, func in name_to_func.items():
for case in cases:
id_ = f"{func_name}({case.cond_expr}) -> {case.result_expr}"
p = pytest.param(func_name, func, case, id=id_)
binary_params.append(p)
else:
warn(f"Special cases found for {stub.__name__} but none were parsed")
continue
else:
warn(
@@ -1206,7 +1192,7 @@ def parse_binary_docstring(docstring: str) -> List[BinaryCase]:
)


# test_unary and test_binary naively generate arrays, i.e. arrays that might not
# test_{unary/binary/iop} naively generate arrays, i.e. arrays that might not
# meet the condition that is being tested. We then forcibly make the array meet
# the condition by picking a random index to insert an acceptable element.
#
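# (Illustrative continuation, not part of this diff: e.g. for a unary case
# requiring x_i < 0, the generated array may contain no negative element at
# all, so one element at a randomly drawn index is overwritten with a value
# drawn from the case's cond_from_dtype strategy before calling the function.)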
@@ -1343,3 +1329,46 @@ def test_iop(iop_name, iop, case, oneway_dtypes, oneway_shapes, data):
)
break
assume(good_example)


@pytest.mark.parametrize(
"func_name, expected",
[
("mean", float("nan")),
("prod", 1),
("std", float("nan")),
("sum", 0),
("var", float("nan")),
],
ids=["mean", "prod", "std", "sum", "var"],
)
def test_empty_arrays(func_name, expected): # TODO: parse docstrings to get expected
func = getattr(xp, func_name)
out = func(xp.asarray([], dtype=dh.default_float))
ph.assert_shape(func_name, out.shape, ()) # sanity check
msg = f"{out=!r}, but should be {expected}"
if math.isnan(expected):
assert xp.isnan(out), msg
else:
assert out == expected, msg
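# The expected values above follow the usual empty-reduction identities: sum of
# no elements is the additive identity 0, prod is the multiplicative identity
# 1, while mean/std/var of an empty array are undefined and hence NaN.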


@pytest.mark.parametrize(
"func_name", [f.__name__ for f in category_to_funcs["statistical"]]
)
@given(
x=xps.arrays(dtype=xps.floating_dtypes(), shape=hh.shapes(min_side=1)),
data=st.data(),
)
def test_nan_propagation(func_name, x, data):
func = getattr(xp, func_name)
set_idx = data.draw(
xps.indices(x.shape, max_dims=0, allow_ellipsis=False), label="set idx"
)
x[set_idx] = float("nan")
note(f"{x=}")

out = func(x)

ph.assert_shape(func_name, out.shape, ()) # sanity check
assert xp.isnan(out), f"{out=!r}, but should be NaN"