Skip to content

Commit c82f9f3

Browse files
authored
Supports function full for the Array API (#21)
* Supports function full for the Array API * improvments * fix keys by adding types * fix unit tests * ci
1 parent ce37364 commit c82f9f3

17 files changed

+175
-44
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
55
array_api_tests/test_creation_functions.py::test_empty
66
array_api_tests/test_creation_functions.py::test_empty_like
77
array_api_tests/test_creation_functions.py::test_eye
8-
array_api_tests/test_creation_functions.py::test_full
98
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid

_unittests/test_array_api.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
export ARRAY_API_TESTS_MODULE=onnx_array_api.array_api.onnx_numpy
2-
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_arrays || exit 1
2+
pytest ../array-api-tests/array_api_tests/test_creation_functions.py::test_asarray_scalars || exit 1
33
# pytest ../array-api-tests/array_api_tests/test_creation_functions.py --help
44
pytest ../array-api-tests/array_api_tests/test_creation_functions.py --hypothesis-explain --skips-file=_unittests/onnx-numpy-skips.txt || exit 1

_unittests/ut_array_api/test_array_apis.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ class TestArraysApis(ExtTestCase):
1313
def test_zeros_numpy_1(self):
1414
c = xpn.zeros(1)
1515
d = c.numpy()
16-
self.assertEqualArray(np.array([0], dtype=np.float32), d)
16+
self.assertEqualArray(np.array([0], dtype=np.float64), d)
1717

1818
def test_zeros_ort_1(self):
1919
c = xpo.zeros(1)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,40 @@ def test_zeros(self):
1919
a = xp.absolute(mat)
2020
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
2121

22+
def test_zeros_none(self):
23+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
24+
mat = xp.zeros(c)
25+
matnp = mat.numpy()
26+
self.assertEqual(matnp.shape, (4, 5))
27+
self.assertNotEmpty(matnp[0, 0])
28+
self.assertEqualArray(matnp, np.zeros((4, 5)))
29+
30+
def test_ones_none(self):
31+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
32+
mat = xp.ones(c)
33+
matnp = mat.numpy()
34+
self.assertEqual(matnp.shape, (4, 5))
35+
self.assertNotEmpty(matnp[0, 0])
36+
self.assertEqualArray(matnp, np.ones((4, 5)))
37+
38+
def test_full(self):
39+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
40+
mat = xp.full(c, fill_value=5, dtype=xp.int64)
41+
matnp = mat.numpy()
42+
self.assertEqual(matnp.shape, (4, 5))
43+
self.assertNotEmpty(matnp[0, 0])
44+
a = xp.absolute(mat)
45+
self.assertEqualArray(np.absolute(mat.numpy()), a.numpy())
46+
47+
def test_full_bool(self):
48+
c = EagerTensor(np.array([4, 5], dtype=np.int64))
49+
mat = xp.full(c, fill_value=False)
50+
matnp = mat.numpy()
51+
self.assertEqual(matnp.shape, (4, 5))
52+
self.assertNotEmpty(matnp[0, 0])
53+
self.assertEqualArray(matnp, np.full((4, 5), False))
54+
2255

2356
if __name__ == "__main__":
57+
TestOnnxNumpy().test_zeros_none()
2458
unittest.main(verbosity=2)

_unittests/ut_npx/test_npx.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -710,8 +710,8 @@ def impl(
710710
keys = list(sorted(f.onxs))
711711
self.assertIsInstance(f.onxs[keys[0]], ModelProto)
712712
k = keys[-1]
713-
self.assertEqual(len(k), 3)
714-
self.assertEqual(k[1:], ("axis", 0))
713+
self.assertEqual(len(k), 4)
714+
self.assertEqual(k[1:], ("axis", int, 0))
715715

716716
def test_numpy_topk(self):
717717
f = topk(Input("X"), Input("K"))
@@ -2416,6 +2416,7 @@ def compute_labels(X, centers, use_sqrt=False):
24162416
(DType(TensorProto.DOUBLE), 2),
24172417
(DType(TensorProto.DOUBLE), 2),
24182418
"use_sqrt",
2419+
bool,
24192420
True,
24202421
)
24212422
self.assertEqual(f.available_versions, [key])

azure-pipelines.yml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ jobs:
4848
vmImage: 'ubuntu-latest'
4949
strategy:
5050
matrix:
51-
Python310-Linux:
51+
Python311-Linux:
5252
python.version: '3.11'
5353
maxParallel: 3
5454

@@ -96,7 +96,7 @@ jobs:
9696
strategy:
9797
matrix:
9898
Python310-Linux:
99-
python.version: '3.11'
99+
python.version: '3.10'
100100
maxParallel: 3
101101

102102
steps:
@@ -149,7 +149,7 @@ jobs:
149149
vmImage: 'ubuntu-latest'
150150
strategy:
151151
matrix:
152-
Python310-Linux:
152+
Python311-Linux:
153153
python.version: '3.11'
154154
maxParallel: 3
155155

@@ -202,7 +202,7 @@ jobs:
202202
vmImage: 'windows-latest'
203203
strategy:
204204
matrix:
205-
Python310-Windows:
205+
Python311-Windows:
206206
python.version: '3.11'
207207
maxParallel: 3
208208

@@ -235,7 +235,7 @@ jobs:
235235
vmImage: 'macOS-latest'
236236
strategy:
237237
matrix:
238-
Python310-Mac:
238+
Python311-Mac:
239239
python.version: '3.11'
240240
maxParallel: 3
241241

onnx_array_api/_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def np_dtype_to_tensor_dtype(dtype: Any):
3939
elif dtype is int:
4040
dt = TensorProto.INT64
4141
elif dtype is float:
42-
dt = TensorProto.FLOAT64
42+
dt = TensorProto.DOUBLE
4343
else:
4444
raise KeyError(f"Unable to guess type for dtype={dtype}.")
4545
return dt

onnx_array_api/array_api/_onnx_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def template_asarray(
4444
except OverflowError:
4545
v = TEagerTensor(np.asarray(a, dtype=np.uint64))
4646
elif isinstance(a, float):
47-
v = TEagerTensor(np.array(a, dtype=np.float32))
47+
v = TEagerTensor(np.array(a, dtype=np.float64))
4848
elif isinstance(a, bool):
4949
v = TEagerTensor(np.array(a, dtype=np.bool_))
5050
elif isinstance(a, str):

onnx_array_api/array_api/onnx_numpy.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
"""
44
from typing import Any, Optional
55
import numpy as np
6-
from onnx import TensorProto
76
from ..npx.npx_functions import (
87
all,
98
abs,
@@ -16,10 +15,11 @@
1615
reshape,
1716
take,
1817
)
18+
from ..npx.npx_functions import full as generic_full
1919
from ..npx.npx_functions import ones as generic_ones
2020
from ..npx.npx_functions import zeros as generic_zeros
2121
from ..npx.npx_numpy_tensors import EagerNumpyTensor
22-
from ..npx.npx_types import DType, ElemType, TensorType, OptParType
22+
from ..npx.npx_types import DType, ElemType, TensorType, OptParType, ParType, Scalar
2323
from ._onnx_common import template_asarray
2424
from . import _finalize_array_api
2525

@@ -31,6 +31,7 @@
3131
"astype",
3232
"empty",
3333
"equal",
34+
"full",
3435
"isdtype",
3536
"isfinite",
3637
"isnan",
@@ -58,7 +59,7 @@ def asarray(
5859

5960
def ones(
6061
shape: TensorType[ElemType.int64, "I", (None,)],
61-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
62+
dtype: OptParType[DType] = None,
6263
order: OptParType[str] = "C",
6364
) -> TensorType[ElemType.numerics, "T"]:
6465
if isinstance(shape, tuple):
@@ -76,7 +77,7 @@ def ones(
7677

7778
def empty(
7879
shape: TensorType[ElemType.int64, "I", (None,)],
79-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
80+
dtype: OptParType[DType] = None,
8081
order: OptParType[str] = "C",
8182
) -> TensorType[ElemType.numerics, "T"]:
8283
raise RuntimeError(
@@ -87,7 +88,7 @@ def empty(
8788

8889
def zeros(
8990
shape: TensorType[ElemType.int64, "I", (None,)],
90-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
91+
dtype: OptParType[DType] = None,
9192
order: OptParType[str] = "C",
9293
) -> TensorType[ElemType.numerics, "T"]:
9394
if isinstance(shape, tuple):
@@ -103,6 +104,32 @@ def zeros(
103104
return generic_zeros(shape, dtype=dtype, order=order)
104105

105106

107+
def full(
108+
shape: TensorType[ElemType.int64, "I", (None,)],
109+
fill_value: ParType[Scalar] = None,
110+
dtype: OptParType[DType] = None,
111+
order: OptParType[str] = "C",
112+
) -> TensorType[ElemType.numerics, "T"]:
113+
if fill_value is None:
114+
raise TypeError("fill_value cannot be None")
115+
value = fill_value
116+
if isinstance(shape, tuple):
117+
return generic_full(
118+
EagerNumpyTensor(np.array(shape, dtype=np.int64)),
119+
fill_value=value,
120+
dtype=dtype,
121+
order=order,
122+
)
123+
if isinstance(shape, int):
124+
return generic_full(
125+
EagerNumpyTensor(np.array([shape], dtype=np.int64)),
126+
fill_value=value,
127+
dtype=dtype,
128+
order=order,
129+
)
130+
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
131+
132+
106133
def _finalize():
107134
"""
108135
Adds common attributes to Array API defined in this modules

onnx_array_api/npx/npx_core_api.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ def wrapper(*inputs, **kwargs):
169169
new_inputs.append(i)
170170
elif isinstance(i, (int, float)):
171171
new_inputs.append(
172-
np.array([i], dtype=np.int64 if isinstance(i, int) else np.float32)
172+
np.array([i], dtype=np.int64 if isinstance(i, int) else np.float64)
173173
)
174174
elif isinstance(i, str):
175175
new_inputs.append(Input(i))

onnx_array_api/npx/npx_functions.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
SequenceType,
1616
TensorType,
1717
TupleType,
18+
Scalar,
1819
)
1920
from .npx_var import Var
2021

2122

2223
def _cstv(x):
2324
if isinstance(x, Var):
2425
return x
25-
if isinstance(x, (int, float, np.ndarray)):
26+
if isinstance(x, (int, float, bool, np.ndarray)):
2627
return cst(x)
2728
raise TypeError(f"Unexpected constant type {type(x)}.")
2829

@@ -376,6 +377,42 @@ def expit(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics
376377
return var(x, op="Sigmoid")
377378

378379

380+
@npxapi_inline
381+
def full(
382+
shape: TensorType[ElemType.int64, "I", (None,)],
383+
dtype: OptParType[DType] = None,
384+
fill_value: ParType[Scalar] = None,
385+
order: OptParType[str] = "C",
386+
) -> TensorType[ElemType.numerics, "T"]:
387+
"""
388+
Implements :func:`numpy.full`.
389+
"""
390+
if order != "C":
391+
raise RuntimeError(f"order={order!r} != 'C' not supported.")
392+
if fill_value is None:
393+
raise TypeError("fill_value cannot be None.")
394+
if dtype is None:
395+
if isinstance(fill_value, bool):
396+
dtype = DType(TensorProto.BOOL)
397+
elif isinstance(fill_value, int):
398+
dtype = DType(TensorProto.INT64)
399+
elif isinstance(fill_value, float):
400+
dtype = DType(TensorProto.DOUBLE)
401+
else:
402+
raise TypeError(
403+
f"Unexpected type {type(fill_value)} for fill_value={fill_value!r}."
404+
)
405+
if isinstance(fill_value, (float, int, bool)):
406+
value = make_tensor(
407+
name="cst", data_type=dtype.code, dims=[1], vals=[fill_value]
408+
)
409+
else:
410+
raise NotImplementedError(
411+
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
412+
)
413+
return var(shape, value=value, op="ConstantOfShape")
414+
415+
379416
@npxapi_inline
380417
def floor(x: TensorType[ElemType.numerics, "T"]) -> TensorType[ElemType.numerics, "T"]:
381418
"See :func:`numpy.floor`."
@@ -464,7 +501,7 @@ def matmul(
464501
@npxapi_inline
465502
def ones(
466503
shape: TensorType[ElemType.int64, "I", (None,)],
467-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
504+
dtype: OptParType[DType] = None,
468505
order: OptParType[str] = "C",
469506
) -> TensorType[ElemType.numerics, "T"]:
470507
"""
@@ -473,7 +510,7 @@ def ones(
473510
if order != "C":
474511
raise RuntimeError(f"order={order!r} != 'C' not supported.")
475512
if dtype is None:
476-
dtype = DType(TensorProto.FLOAT)
513+
dtype = DType(TensorProto.DOUBLE)
477514
return var(
478515
shape,
479516
value=make_tensor(name="one", data_type=dtype.code, dims=[1], vals=[1]),
@@ -674,7 +711,7 @@ def where(
674711
@npxapi_inline
675712
def zeros(
676713
shape: TensorType[ElemType.int64, "I", (None,)],
677-
dtype: OptParType[DType] = DType(TensorProto.FLOAT),
714+
dtype: OptParType[DType] = None,
678715
order: OptParType[str] = "C",
679716
) -> TensorType[ElemType.numerics, "T"]:
680717
"""
@@ -683,7 +720,7 @@ def zeros(
683720
if order != "C":
684721
raise RuntimeError(f"order={order!r} != 'C' not supported.")
685722
if dtype is None:
686-
dtype = DType(TensorProto.FLOAT)
723+
dtype = DType(TensorProto.DOUBLE)
687724
return var(
688725
shape,
689726
value=make_tensor(name="zero", data_type=dtype.code, dims=[1], vals=[0]),

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def to_onnx(
784784
node_inputs.append(input_name)
785785
continue
786786

787-
if isinstance(i, (int, float)):
787+
if isinstance(i, (int, float, bool)):
788788
ni = np.array(i)
789789
c = Cst(ni)
790790
input_name = self._unique(var._prefix)

0 commit comments

Comments
 (0)