Skip to content

Commit 6bea970

Browse files
authored
Add function Eye to the Array API (#29)
* Add function Eye to the Array API * remove eye * improve * fix overflow
1 parent 35cb298 commit 6bea970

File tree

10 files changed

+157
-10
lines changed

10 files changed

+157
-10
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
# API failures
22
# see https://github.com/data-apis/array-api-tests/blob/master/numpy-skips.txt
3-
array_api_tests/test_creation_functions.py::test_asarray_scalars
4-
array_api_tests/test_creation_functions.py::test_arange
3+
# uses __setitem__
54
array_api_tests/test_creation_functions.py::test_asarray_arrays
65
array_api_tests/test_creation_functions.py::test_empty
76
array_api_tests/test_creation_functions.py::test_empty_like
8-
array_api_tests/test_creation_functions.py::test_eye
97
array_api_tests/test_creation_functions.py::test_linspace
108
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_zeros_like || exit 1
2+
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py::test_eye || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest -v -rxXfE ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def sh(x):
3939

4040
class TestHypothesisArraysApis(ExtTestCase):
4141
MAX_ARRAY_SIZE = 10000
42+
SQRT_MAX_ARRAY_SIZE = int(10000**0.5)
4243
VERSION = "2021.12"
4344

4445
@classmethod
@@ -138,9 +139,80 @@ def fctonx(x, kw):
138139
fctonx()
139140
self.assertEqual(len(args_onxp), len(args_np))
140141

142+
def test_square_sizes_strategies(self):
143+
dtypes = dict(
144+
integer_dtypes=self.xps.integer_dtypes(),
145+
uinteger_dtypes=self.xps.unsigned_integer_dtypes(),
146+
floating_dtypes=self.xps.floating_dtypes(),
147+
numeric_dtypes=self.xps.numeric_dtypes(),
148+
boolean_dtypes=self.xps.boolean_dtypes(),
149+
scalar_dtypes=self.xps.scalar_dtypes(),
150+
)
151+
152+
dtypes_onnx = dict(
153+
integer_dtypes=self.onxps.integer_dtypes(),
154+
uinteger_dtypes=self.onxps.unsigned_integer_dtypes(),
155+
floating_dtypes=self.onxps.floating_dtypes(),
156+
numeric_dtypes=self.onxps.numeric_dtypes(),
157+
boolean_dtypes=self.onxps.boolean_dtypes(),
158+
scalar_dtypes=self.onxps.scalar_dtypes(),
159+
)
160+
161+
for k, vnp in dtypes.items():
162+
vonxp = dtypes_onnx[k]
163+
anp = self.xps.arrays(dtype=vnp, shape=shapes(self.xps))
164+
aonxp = self.onxps.arrays(dtype=vonxp, shape=shapes(self.onxps))
165+
self.assertNotEmpty(anp)
166+
self.assertNotEmpty(aonxp)
167+
168+
args_np = []
169+
170+
kws = array_api_kwargs(k=strategies.integers(), dtype=self.xps.numeric_dtypes())
171+
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
172+
ncs = strategies.none() | sqrt_sizes
173+
174+
@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
175+
def fctnp(n_rows, n_cols, kw):
176+
base = np.asarray(0)
177+
e = np.eye(n_rows, n_cols)
178+
self.assertNotEmpty(e.dtype)
179+
self.assertIsInstance(e, base.__class__)
180+
e = np.eye(n_rows, n_cols, **kw)
181+
self.assertNotEmpty(e.dtype)
182+
self.assertIsInstance(e, base.__class__)
183+
args_np.append((n_rows, n_cols, kw))
184+
185+
fctnp()
186+
self.assertEqual(len(args_np), 100)
187+
188+
args_onxp = []
189+
190+
kws = array_api_kwargs(
191+
k=strategies.integers(), dtype=self.onxps.numeric_dtypes()
192+
)
193+
sqrt_sizes = strategies.integers(0, self.SQRT_MAX_ARRAY_SIZE)
194+
ncs = strategies.none() | sqrt_sizes
195+
196+
@given(n_rows=sqrt_sizes, n_cols=ncs, kw=kws)
197+
def fctonx(n_rows, n_cols, kw):
198+
base = onxp.asarray(0)
199+
e = onxp.eye(n_rows, n_cols)
200+
self.assertIsInstance(e, base.__class__)
201+
self.assertNotEmpty(e.dtype)
202+
e = onxp.eye(n_rows, n_cols, **kw)
203+
self.assertNotEmpty(e.dtype)
204+
self.assertIsInstance(e, base.__class__)
205+
args_onxp.append((n_rows, n_cols, kw))
206+
207+
fctonx()
208+
self.assertEqual(len(args_onxp), len(args_np))
209+
141210

142211
if __name__ == "__main__":
143212
# cl = TestHypothesisArraysApis()
144213
# cl.setUpClass()
145214
# cl.test_scalar_strategies()
215+
# import logging
216+
217+
# logging.basicConfig(level=logging.DEBUG)
146218
unittest.main(verbosity=2)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,10 +142,28 @@ def test_as_array(self):
142142
self.assertEqual(r.dtype, DType(TensorProto.UINT64))
143143
self.assertEqual(r.numpy(), 9223372036854775809)
144144

145+
def test_eye(self):
146+
nr, nc = xp.asarray(4), xp.asarray(4)
147+
expected = np.eye(nr.numpy(), nc.numpy())
148+
got = xp.eye(nr, nc)
149+
self.assertEqualArray(expected, got.numpy())
150+
151+
def test_eye_nosquare(self):
152+
nr, nc = xp.asarray(4), xp.asarray(5)
153+
expected = np.eye(nr.numpy(), nc.numpy())
154+
got = xp.eye(nr, nc)
155+
self.assertEqualArray(expected, got.numpy())
156+
157+
def test_eye_k(self):
158+
nr = xp.asarray(4)
159+
expected = np.eye(nr.numpy(), k=1)
160+
got = xp.eye(nr, k=1)
161+
self.assertEqualArray(expected, got.numpy())
162+
145163

146164
if __name__ == "__main__":
147165
# import logging
148166

149167
# logging.basicConfig(level=logging.DEBUG)
150-
# TestOnnxNumpy().test_as_array()
168+
TestOnnxNumpy().test_eye()
151169
unittest.main(verbosity=2)

onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
"astype",
1818
"empty",
1919
"equal",
20+
"eye",
2021
"full",
2122
"full_like",
2223
"isdtype",

onnx_array_api/array_api/_onnx_common.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from typing import Any, Optional
22
import warnings
33
import numpy as np
4+
from onnx import TensorProto
45

56
with warnings.catch_warnings():
67
warnings.simplefilter("ignore")
@@ -19,6 +20,8 @@
1920
from ..npx.npx_functions import (
2021
abs as generic_abs,
2122
arange as generic_arange,
23+
copy as copy_inline,
24+
eye as generic_eye,
2225
full as generic_full,
2326
full_like as generic_full_like,
2427
ones as generic_ones,
@@ -185,6 +188,24 @@ def full(
185188
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
186189

187190

191+
def eye(
192+
TEagerTensor: type,
193+
n_rows: TensorType[ElemType.int64, "I"],
194+
n_cols: OptTensorType[ElemType.int64, "I"] = None,
195+
/,
196+
*,
197+
k: ParType[int] = 0,
198+
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
199+
):
200+
if isinstance(n_rows, int):
201+
n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64))
202+
if n_cols is None:
203+
n_cols = n_rows
204+
elif isinstance(n_cols, int):
205+
n_cols = TEagerTensor(np.array(n_cols, dtype=np.int64))
206+
return generic_eye(n_rows, n_cols, k=k, dtype=dtype)
207+
208+
188209
def full_like(
189210
TEagerTensor: type,
190211
x: TensorType[ElemType.allowed, "T"],

onnx_array_api/npx/npx_functions.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,30 @@ def expit(
473473
return var(x, op="Sigmoid")
474474

475475

476+
@npxapi_inline
477+
def eye(
478+
n_rows: TensorType[ElemType.int64, "I"],
479+
n_cols: TensorType[ElemType.int64, "I"],
480+
/,
481+
*,
482+
k: ParType[int] = 0,
483+
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
484+
):
485+
"See :func:`numpy.eye`."
486+
shape = cst(np.array([-1], dtype=np.int64))
487+
shape = var(
488+
var(n_rows, shape, op="Reshape"),
489+
var(n_cols, shape, op="Reshape"),
490+
axis=0,
491+
op="Concat",
492+
)
493+
zero = zeros(shape, dtype=dtype)
494+
res = var(zero, k=k, op="EyeLike")
495+
if dtype is not None:
496+
return var(res, to=dtype.code, op="Cast")
497+
return res
498+
499+
476500
@npxapi_inline
477501
def full(
478502
shape: TensorType[ElemType.int64, "I", (None,)],

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,11 @@ def make_node(
230230
new_kwargs[k] = v.value
231231
elif isinstance(v, DType):
232232
new_kwargs[k] = v.code
233+
elif isinstance(v, int):
234+
try:
235+
new_kwargs[k] = int(np.array(v, dtype=np.int64))
236+
except OverflowError:
237+
new_kwargs[k] = int(np.iinfo(np.int64).max)
233238
else:
234239
new_kwargs[k] = v
235240

@@ -246,6 +251,11 @@ def make_node(
246251
f"Unable to create node {op!r}, with inputs={inputs}, "
247252
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
248253
) from e
254+
except ValueError as e:
255+
raise ValueError(
256+
f"Unable to create node {op!r}, with inputs={inputs}, "
257+
f"outputs={outputs}, domain={domain!r}, new_kwargs={new_kwargs}."
258+
) from e
249259
for p in protos:
250260
node.attribute.append(p)
251261
if attribute_protos is not None:

onnx_array_api/npx/npx_jit_eager.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -510,11 +510,18 @@ def jit_call(self, *values, **kwargs):
510510
from ..plotting.text_plot import onnx_simple_text_plot
511511

512512
text = onnx_simple_text_plot(self.onxs[key])
513+
514+
def catch_len(x):
515+
try:
516+
return len(x)
517+
except TypeError:
518+
return 0
519+
513520
raise RuntimeError(
514521
f"Unable to run function for key={key!r}, "
515522
f"types={[type(x) for x in values]}, "
516523
f"dtypes={[getattr(x, 'dtype', type(x)) for x in values]}, "
517-
f"shapes={[getattr(x, 'shape', len(x)) for x in values]}, "
524+
f"shapes={[getattr(x, 'shape', catch_len(x)) for x in values]}, "
518525
f"kwargs={kwargs}, "
519526
f"self.input_to_kwargs_={self.input_to_kwargs_}, "
520527
f"f={self.f} from module {self.f.__module__!r} "

onnx_array_api/reference/evaluator.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,6 @@
77
from .ops.op_cast_like import CastLike_15, CastLike_19
88
from .ops.op_constant_of_shape import ConstantOfShape
99

10-
import onnx
11-
12-
print(onnx.__file__)
13-
1410

1511
logger = getLogger("onnx-array-api-eval")
1612

0 commit comments

Comments
 (0)