Skip to content

Implementing the remaining linalg tests w/ additional fixes #96

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Feb 9, 2022
19 changes: 15 additions & 4 deletions array_api_tests/hypothesis_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,10 +158,21 @@ def matrix_shapes(draw, stack_shapes=shapes()):

square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])

finite_matrices = xps.arrays(dtype=xps.floating_dtypes(),
shape=matrix_shapes(),
elements=dict(allow_nan=False,
allow_infinity=False))
@composite
def finite_matrices(draw, shape=matrix_shapes()):
return draw(xps.arrays(dtype=xps.floating_dtypes(),
shape=shape,
elements=dict(allow_nan=False,
allow_infinity=False)))

rtol_shared_matrix_shapes = shared(matrix_shapes())
# Should we set a max_value here?
_rtol_float_kw = dict(allow_nan=False, allow_infinity=False, min_value=0)
rtols = one_of(floats(**_rtol_float_kw),
xps.arrays(dtype=xps.floating_dtypes(),
shape=rtol_shared_matrix_shapes.map(lambda shape: shape[:-2]),
elements=_rtol_float_kw))


def mutually_broadcastable_shapes(
num_shapes: int,
Expand Down
116 changes: 59 additions & 57 deletions array_api_tests/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import pytest
from hypothesis import assume, given
from hypothesis.strategies import (booleans, composite, none, tuples, integers,
shared, sampled_from, data, just)
shared, sampled_from, one_of, data, just)
from ndindex import iter_indices

from .array_helpers import assert_exactly_equal, asarray
Expand All @@ -26,7 +26,8 @@
invertible_matrices, two_mutual_arrays,
mutually_promotable_dtypes, one_d_shapes,
two_mutually_broadcastable_shapes,
SQRT_MAX_ARRAY_SIZE, finite_matrices)
SQRT_MAX_ARRAY_SIZE, finite_matrices,
rtol_shared_matrix_shapes, rtols)
from . import dtype_helpers as dh
from . import pytest_helpers as ph
from . import shape_helpers as sh
Expand All @@ -37,18 +38,17 @@

pytestmark = pytest.mark.ci



# Standin strategy for not yet implemented tests
todo = none()

def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1),
def _test_stacks(f, *args, res=None, dims=2, true_val=None,
matrix_axes=(-2, -1),
assert_equal=assert_exactly_equal, **kw):
"""
Test that f(*args, **kw) maps across stacks of matrices

dims is the number of dimensions f(*args) should have for a single n x m
matrix stack.
dims is the number of dimensions f(*args, *kw) should have for a single n
x m matrix stack.

matrix_axes are the axes along which matrices (or vectors) are stacked in
the input.
Expand All @@ -65,9 +65,13 @@ def _test_stacks(f, *args, res=None, dims=2, true_val=None, matrix_axes=(-2, -1)

shapes = [x.shape for x in args]

# Assume the result is stacked along the last 'dims' axes of matrix_axes.
# This holds for all the functions tested in this file
res_axes = matrix_axes[::-1][:dims]

for (x_idxes, (res_idx,)) in zip(
iter_indices(*shapes, skip_axes=matrix_axes),
iter_indices(res.shape, skip_axes=tuple(range(-dims, 0)))):
iter_indices(res.shape, skip_axes=res_axes)):
x_idxes = [x_idx.raw for x_idx in x_idxes]
res_idx = res_idx.raw

Expand Down Expand Up @@ -159,26 +163,18 @@ def test_cross(x1_x2_kw):
assert res.dtype == dh.result_type(x1.dtype, x2.dtype), "cross() did not return the correct dtype"
assert res.shape == shape, "cross() did not return the correct shape"

# cross is too different from other functions to use _test_stacks, and it
# is the only function that works the way it does, so it's not really
# worth generalizing _test_stacks to handle it.
a = axis if axis >= 0 else axis + len(shape)
for _idx in sh.ndindex(shape[:a] + shape[a+1:]):
idx = _idx[:a] + (slice(None),) + _idx[a:]
assert len(idx) == len(shape), "Invalid index. This indicates a bug in the test suite."
res_stack = res[idx]
x1_stack = x1[idx]
x2_stack = x2[idx]
assert x1_stack.shape == x2_stack.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
decomp_res_stack = linalg.cross(x1_stack, x2_stack)
assert_exactly_equal(res_stack, decomp_res_stack)

exact_cross = asarray([
x1_stack[1]*x2_stack[2] - x1_stack[2]*x2_stack[1],
x1_stack[2]*x2_stack[0] - x1_stack[0]*x2_stack[2],
x1_stack[0]*x2_stack[1] - x1_stack[1]*x2_stack[0],
], dtype=res.dtype)
assert_exactly_equal(res_stack, exact_cross)
def exact_cross(a, b):
assert a.shape == b.shape == (3,), "Invalid cross() stack shapes. This indicates a bug in the test suite."
return asarray([
a[1]*b[2] - a[2]*b[1],
a[2]*b[0] - a[0]*b[2],
a[0]*b[1] - a[1]*b[0],
], dtype=res.dtype)

# We don't want to pass in **kw here because that would pass axis to
# cross() on a single stack, but the axis is not meaningful on unstacked
# vectors.
_test_stacks(linalg.cross, x1, x2, dims=1, matrix_axes=(axis,), res=res, true_val=exact_cross)

@pytest.mark.xp_extension('linalg')
@given(
Expand Down Expand Up @@ -313,14 +309,30 @@ def test_matmul(x1, x2):
assert res.shape == stack_shape + (x1.shape[-2], x2.shape[-1])
_test_stacks(_array_module.matmul, x1, x2, res=res)

matrix_norm_shapes = shared(matrix_shapes())

@pytest.mark.xp_extension('linalg')
@given(
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
kw=kwargs(axis=todo, keepdims=todo, ord=todo)
x=finite_matrices(),
kw=kwargs(keepdims=booleans(),
ord=sampled_from([-float('inf'), -2, -2, 1, 2, float('inf'), 'fro', 'nuc']))
)
def test_matrix_norm(x, kw):
# res = linalg.matrix_norm(x, **kw)
pass
res = linalg.matrix_norm(x, **kw)

keepdims = kw.get('keepdims', False)
# TODO: Check that the ord values give the correct norms.
# ord = kw.get('ord', 'fro')

if keepdims:
expected_shape = x.shape[:-2] + (1, 1)
else:
expected_shape = x.shape[:-2]
assert res.shape == expected_shape, f"matrix_norm({keepdims=}) did not return the correct shape"
assert res.dtype == x.dtype, "matrix_norm() did not return the correct dtype"

_test_stacks(linalg.matrix_norm, x, **kw, dims=2 if keepdims else 0,
res=res)

matrix_power_n = shared(integers(-1000, 1000), key='matrix_power n')
@pytest.mark.xp_extension('linalg')
Expand All @@ -347,12 +359,11 @@ def test_matrix_power(x, n):

@pytest.mark.xp_extension('linalg')
@given(
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
kw=kwargs(rtol=todo)
x=finite_matrices(shape=rtol_shared_matrix_shapes),
kw=kwargs(rtol=rtols)
)
def test_matrix_rank(x, kw):
# res = linalg.matrix_rank(x, **kw)
pass
linalg.matrix_rank(x, **kw)

@given(
x=xps.arrays(dtype=dtypes, shape=matrix_shapes()),
Expand Down Expand Up @@ -397,12 +408,11 @@ def test_outer(x1, x2):

@pytest.mark.xp_extension('linalg')
@given(
x=xps.arrays(dtype=xps.floating_dtypes(), shape=shapes()),
kw=kwargs(rtol=todo)
x=finite_matrices(shape=rtol_shared_matrix_shapes),
kw=kwargs(rtol=rtols)
)
def test_pinv(x, kw):
# res = linalg.pinv(x, **kw)
pass
linalg.pinv(x, **kw)

@pytest.mark.xp_extension('linalg')
@given(
Expand Down Expand Up @@ -482,7 +492,7 @@ def solve_args():
Strategy for the x1 and x2 arguments to test_solve()

solve() takes x1, x2, where x1 is any stack of square invertible matrices
of shape (..., M, M), and x2 is either shape (..., M) or (..., M, K),
of shape (..., M, M), and x2 is either shape (M,) or (..., M, K),
where the ... parts of x1 and x2 are broadcast compatible.
"""
stack_shapes = shared(two_mutually_broadcastable_shapes)
Expand All @@ -492,30 +502,22 @@ def solve_args():
pair[0])))

@composite
def x2_shapes(draw):
end = draw(xps.array_shapes(min_dims=0, max_dims=1, min_side=0,
max_side=SQRT_MAX_ARRAY_SIZE))
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + end
def _x2_shapes(draw):
end = draw(integers(0, SQRT_MAX_ARRAY_SIZE))
return draw(stack_shapes)[1] + draw(x1).shape[-1:] + (end,)

x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes())
x2_shapes = one_of(x1.map(lambda x: (x.shape[-1],)), _x2_shapes())
x2 = xps.arrays(dtype=xps.floating_dtypes(), shape=x2_shapes)
return x1, x2

@pytest.mark.xp_extension('linalg')
@given(*solve_args())
def test_solve(x1, x2):
# TODO: solve() is currently ambiguous, in that some inputs can be
# interpreted in two different ways. For example, if x1 is shape (2, 2, 2)
# and x2 is shape (2, 2), should this be interpreted as x2 is (2,) stack
# of a (2,) vector, i.e., the result would be (2, 2, 2, 1) after
# broadcasting, or as a single stack of a 2x2 matrix, i.e., resulting in
# (2, 2, 2, 2).

# res = linalg.solve(x1, x2)
pass
linalg.solve(x1, x2)

@pytest.mark.xp_extension('linalg')
@given(
x=finite_matrices,
x=finite_matrices(),
kw=kwargs(full_matrices=booleans())
)
def test_svd(x, kw):
Expand Down Expand Up @@ -551,7 +553,7 @@ def test_svd(x, kw):

@pytest.mark.xp_extension('linalg')
@given(
x=finite_matrices,
x=finite_matrices(),
)
def test_svdvals(x):
res = linalg.svdvals(x)
Expand Down