Skip to content

Add function Eye to the Array API #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions _unittests/onnx-numpy-skips.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
# API failures
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
array_api_tests/test_creation_functions.py::test_asarray_scalars
array_api_tests/test_creation_functions.py::test_arange
# uses __setitem__
array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
2 changes: 1 addition & 1 deletion _unittests/test_array_api.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1
72 changes: 72 additions & 0 deletions _unittests/ut_array_api/test_hypothesis_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def sh(x):

class TestHypothesisArraysApis(ExtTestCase):
MAX_ARRAY_SIZE = 10000
SQRT_MAX_ARRAY_SIZE = int(10000**0.5)
VERSION = "2021.12"

@classmethod
Expand Down Expand Up @@ -138,9 +139,80 @@ def fctonx(x, kw):
fctonx()
self.assertEqual(len(args_onxp), len(args_np))

def test_square_sizes_strategies(self):
dtypes = dict(
integer_dtypes=self.xps.integer_dtypes(),
uinteger_dtypes=self.xps.unsigned_integer_dtypes(),
floating_dtypes=self.xps.floating_dtypes(),
numeric_dtypes=self.xps.numeric_dtypes(),
boolean_dtypes=self.xps.boolean_dtypes(),
scalar_dtypes=self.xps.scalar_dtypes(),
)

dtypes_onnx = dict(
integer_dtypes=self.onxps.integer_dtypes(),
uinteger_dtypes=self.onxps.unsigned_integer_dtypes(),
floating_dtypes=self.onxps.floating_dtypes(),
numeric_dtypes=self.onxps.numeric_dtypes(),
boolean_dtypes=self.onxps.boolean_dtypes(),
scalar_dtypes=self.onxps.scalar_dtypes(),
)

for k, vnp in dtypes.items():
vonxp = dtypes_onnx[k]
anp = self.xps.arrays(dtype=vnp, shape=shapes(self.xps))
aonxp = self.onxps.arrays(dtype=vonxp, shape=shapes(self.onxps))
self.assertNotEmpty(anp)
self.assertNotEmpty(aonxp)

args_np = []

kws = array_api_kwargs(k=strategies.integers(), dtype=self.xps.numeric_dtypes())
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
ncs = strategies.none() | sqrt_sizes

@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
def fctnp(n_rows, n_cols, kw):
base = np.asarray(0)
e = np.eye(n_rows, n_cols)
self.assertNotEmpty(e.dtype)
self.assertIsInstance(e, base.__class__)
e = np.eye(n_rows, n_cols, **kw)
self.assertNotEmpty(e.dtype)
self.assertIsInstance(e, base.__class__)
args_np.append((n_rows, n_cols, kw))

fctnp()
self.assertEqual(len(args_np), 100)

args_onxp = []

kws = array_api_kwargs(
k=strategies.integers(), dtype=self.onxps.numeric_dtypes()
)
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
ncs = strategies.none() | sqrt_sizes

@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
def fctonx(n_rows, n_cols, kw):
base = onxp.asarray(0)
e = onxp.eye(n_rows, n_cols)
self.assertIsInstance(e, base.__class__)
self.assertNotEmpty(e.dtype)
e = onxp.eye(n_rows, n_cols, **kw)
self.assertNotEmpty(e.dtype)
self.assertIsInstance(e, base.__class__)
args_onxp.append((n_rows, n_cols, kw))

fctonx()
self.assertEqual(len(args_onxp), len(args_np))


if __name__ == "__main__":
# cl = TestHypothesisArraysApis()
# cl.setUpClass()
# cl.test_scalar_strategies()
# import logging

# logging.basicConfig(level=logging.DEBUG)
unittest.main(verbosity=2)
20 changes: 19 additions & 1 deletion _unittests/ut_array_api/test_onnx_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,28 @@ def test_as_array(self):
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
self.assertEqual(r.numpy(), 9223372036854775809)

def test_eye(self):
nr, nc = xp.asarray(4), xp.asarray(4)
expected = np.eye(nr.numpy(), nc.numpy())
got = xp.eye(nr, nc)
self.assertEqualArray(expected, got.numpy())

def test_eye_nosquare(self):
nr, nc = xp.asarray(4), xp.asarray(5)
expected = np.eye(nr.numpy(), nc.numpy())
got = xp.eye(nr, nc)
self.assertEqualArray(expected, got.numpy())

def test_eye_k(self):
nr = xp.asarray(4)
expected = np.eye(nr.numpy(), k=1)
got = xp.eye(nr, k=1)
self.assertEqualArray(expected, got.numpy())


if __name__ == "__main__":
# import logging

# logging.basicConfig(level=logging.DEBUG)
# TestOnnxNumpy().test_as_array()
TestOnnxNumpy().test_eye()
unittest.main(verbosity=2)
1 change: 1 addition & 0 deletions onnx_array_api/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
"astype",
"empty",
"equal",
"eye",
"full",
"full_like",
"isdtype",
Expand Down
21 changes: 21 additions & 0 deletions onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Optional
import warnings
import numpy as np
from onnx import TensorProto

with warnings.catch_warnings():
warnings.simplefilter("ignore")
Expand All @@ -19,6 +20,8 @@
from ..npx.npx_functions import (
abs as generic_abs,
arange as generic_arange,
copy as copy_inline,
eye as generic_eye,
full as generic_full,
full_like as generic_full_like,
ones as generic_ones,
Expand Down Expand Up @@ -185,6 +188,24 @@ def full(
return generic_full(shape, fill_value=value, dtype=dtype, order=order)


def eye(
TEagerTensor: type,
n_rows: TensorType[ElemType.int64, "I"],
n_cols: OptTensorType[ElemType.int64, "I"] = None,
/,
*,
k: ParType[int] = 0,
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
):
if isinstance(n_rows, int):
n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64))
if n_cols is None:
n_cols = n_rows
elif isinstance(n_cols, int):
n_cols = TEagerTensor(np.array(n_cols, dtype=np.int64))
return generic_eye(n_rows, n_cols, k=k, dtype=dtype)


def full_like(
TEagerTensor: type,
x: TensorType[ElemType.allowed, "T"],
Expand Down
24 changes: 24 additions & 0 deletions onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,30 @@ def expit(
return var(x, op="Sigmoid")


@npxapi_inline
def eye(
n_rows: TensorType[ElemType.int64, "I"],
n_cols: TensorType[ElemType.int64, "I"],
/,
*,
k: ParType[int] = 0,
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
):
"See :func:`numpy.eye`."
shape = cst(np.array([-1], dtype=np.int64))
shape = var(
var(n_rows, shape, op="Reshape"),
var(n_cols, shape, op="Reshape"),
axis=0,
op="Concat",
)
zero = zeros(shape, dtype=dtype)
res = var(zero, k=k, op="EyeLike")
if dtype is not None:
return var(res, to=dtype.code, op="Cast")
return res


@npxapi_inline
def full(
shape: TensorType[ElemType.int64, "I", (None,)],
Expand Down
10 changes: 10 additions & 0 deletions onnx_array_api/npx/npx_graph_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,11 @@ def make_node(
new_kwargs[k] = v.value
elif isinstance(v, DType):
new_kwargs[k] = v.code
elif isinstance(v, int):
try:
new_kwargs[k] = int(np.array(v, dtype=np.int64))
except OverflowError:
new_kwargs[k] = int(np.iinfo(np.int64).max)
else:
new_kwargs[k] = v

Expand All @@ -246,6 +251,11 @@ def make_node(
f"Unable to create node {op!r}, with inputs={inputs}, "
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
) from e
except ValueError as e:
raise ValueError(
f"Unable to create node {op!r}, with inputs={inputs}, "
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
) from e
for p in protos:
node.attribute.append(p)
if attribute_protos is not None:
Expand Down
9 changes: 8 additions & 1 deletion onnx_array_api/npx/npx_jit_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs):
from ..plotting.text_plot import onnx_simple_text_plot

text = onnx_simple_text_plot(self.onxs[key])

def catch_len(x):
try:
return len(x)
except TypeError:
return 0

raise RuntimeError(
f"Unable to run function for key={key!r}, "
f"types={[type(x) for x in values]}, "
f"dtypes={[getattr(x, 'dtype', type(x)) for x in values]}, "
f"shapes={[getattr(x, 'shape', len(x)) for x in values]}, "
f"shapes={[getattr(x, 'shape', catch_len(x)) for x in values]}, "
f"kwargs={kwargs}, "
f"self.input_to_kwargs_={self.input_to_kwargs_}, "
f"f={self.f} from module {self.f.__module__!r} "
Expand Down
4 changes: 0 additions & 4 deletions onnx_array_api/reference/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_constant_of_shape import ConstantOfShape

import onnx

print(onnx.__file__)


logger = getLogger("onnx-array-api-eval")

Expand Down