Skip to content

Commit c6a3718

Browse files
authored
Fixes asarray for the Array API (#25)
* Fixes asarray for the Array API * move
1 parent 61eec9d commit c6a3718

File tree

7 files changed

+77
-17
lines changed

7 files changed

+77
-17
lines changed

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import warnings
33
from os import getenv
44
from functools import reduce
5+
import numpy as np
56
from operator import mul
67
from hypothesis import given
78
from onnx_array_api.ext_test_case import ExtTestCase
@@ -89,24 +90,49 @@ def test_scalar_strategies(self):
8990

9091
args_np = []
9192

93+
xx = self.xps.arrays(dtype=dtypes["integer_dtypes"], shape=shapes(self.xps))
94+
kws = array_api_kwargs(dtype=strategies.none() | self.xps.scalar_dtypes())
95+
9296
@given(
93-
x=self.xps.arrays(dtype=dtypes["integer_dtypes"], shape=shapes(self.xps)),
94-
kw=array_api_kwargs(dtype=strategies.none() | self.xps.scalar_dtypes()),
97+
x=xx,
98+
kw=kws,
9599
)
96-
def fct(x, kw):
100+
def fctnp(x, kw):
101+
asa1 = np.asarray(x)
102+
asa2 = np.asarray(x, **kw)
103+
self.assertEqual(asa1.shape, asa2.shape)
97104
args_np.append((x, kw))
98105

99-
fct()
106+
fctnp()
100107
self.assertEqual(len(args_np), 100)
101108

102109
args_onxp = []
103110

104111
xshape = shapes(self.onxps)
105112
xx = self.onxps.arrays(dtype=dtypes_onnx["integer_dtypes"], shape=xshape)
106-
kw = array_api_kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes())
113+
kws = array_api_kwargs(dtype=strategies.none() | self.onxps.scalar_dtypes())
107114

108-
@given(x=xx, kw=kw)
115+
@given(x=xx, kw=kws)
109116
def fctonx(x, kw):
117+
asa = np.asarray(x.numpy())
118+
try:
119+
asp = onxp.asarray(x)
120+
except Exception as e:
121+
raise AssertionError(f"asarray fails with x={x!r}, asp={asa!r}.") from e
122+
try:
123+
self.assertEqualArray(asa, asp.numpy())
124+
except AssertionError as e:
125+
raise AssertionError(
126+
f"x={x!r} kw={kw!r} asa={asa!r}, asp={asp!r}"
127+
) from e
128+
if kw:
129+
try:
130+
asp2 = onxp.asarray(x, **kw)
131+
except Exception as e:
132+
raise AssertionError(
133+
f"asarray fails with x={x!r}, kw={kw!r}, asp={asa!r}."
134+
) from e
135+
self.assertEqual(asp.shape, asp2.shape)
110136
args_onxp.append((x, kw))
111137

112138
fctonx()

onnx_array_api/array_api/_onnx_common.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
from typing import Any, Optional
2+
import warnings
23
import numpy as np
4+
5+
with warnings.catch_warnings():
6+
warnings.simplefilter("ignore")
7+
from numpy.array_api._array_object import Array
38
from ..npx.npx_types import (
49
DType,
510
ElemType,
@@ -77,6 +82,10 @@ def asarray(
7782
v = TEagerTensor(np.array(a, dtype=np.str_))
7883
elif isinstance(a, list):
7984
v = TEagerTensor(np.array(a))
85+
elif isinstance(a, np.ndarray):
86+
v = TEagerTensor(a)
87+
elif isinstance(a, Array):
88+
v = TEagerTensor(np.asarray(a))
8089
else:
8190
raise RuntimeError(f"Unexpected type {type(a)} for the first input.")
8291
if dtype is not None:

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import warnings
12
from typing import Any, Callable, List, Optional, Tuple
23
import numpy as np
34
from onnx import ModelProto, TensorProto
@@ -221,13 +222,18 @@ def __bool__(self):
221222
if self.shape == (0,):
222223
return False
223224
if len(self.shape) != 0:
224-
raise ValueError(
225-
f"Conversion to bool only works for scalar, not for {self!r}."
225+
warnings.warn(
226+
f"Conversion to bool only works for scalar, not for {self!r}, "
227+
f"bool(...)={bool(self._tensor)}."
226228
)
229+
try:
230+
return bool(self._tensor)
231+
except ValueError as e:
232+
raise ValueError(f"Unable to convert {self} to bool.") from e
227233
return bool(self._tensor)
228234

229235
def __int__(self):
230-
"Implicit conversion to bool."
236+
"Implicit conversion to int."
231237
if len(self.shape) != 0:
232238
raise ValueError(
233239
f"Conversion to bool only works for scalar, not for {self!r}."
@@ -249,7 +255,7 @@ def __int__(self):
249255
return int(self._tensor)
250256

251257
def __float__(self):
252-
"Implicit conversion to bool."
258+
"Implicit conversion to float."
253259
if len(self.shape) != 0:
254260
raise ValueError(
255261
f"Conversion to bool only works for scalar, not for {self!r}."
@@ -261,11 +267,24 @@ def __float__(self):
261267
DType(TensorProto.BFLOAT16),
262268
}:
263269
raise TypeError(
264-
f"Conversion to int only works for float scalar, "
270+
f"Conversion to float only works for float scalar, "
265271
f"not for dtype={self.dtype}."
266272
)
267273
return float(self._tensor)
268274

275+
def __iter__(self):
276+
"""
277+
The :epkg:`Array API` does not define this function (2022/12).
278+
This method raises an exception with a better error message.
279+
"""
280+
warnings.warn(
281+
f"Iterators are not implemented in the generic case. "
282+
f"Every function using them cannot be converted into ONNX "
283+
f"(tensors - {type(self)})."
284+
)
285+
for row in self._tensor:
286+
yield self.__class__(row)
287+
269288

270289
class JitNumpyTensor(NumpyTensor, JitTensor):
271290
"""

onnx_array_api/npx/npx_tensors.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,9 @@ def __iter__(self):
3535
This method raises an exception with a better error message.
3636
"""
3737
raise ArrayApiError(
38-
"Iterators are not implemented in the generic case. "
39-
"Every function using them cannot be converted into ONNX."
38+
f"Iterators are not implemented in the generic case. "
39+
f"Every function using them cannot be converted into ONNX "
40+
f"(tensors - {type(self)})."
4041
)
4142

4243
@staticmethod

onnx_array_api/npx/npx_types.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,16 @@ def __eq__(self, dt: "DType") -> bool:
5959
return False
6060
if dt.__class__ is DType:
6161
return self.code_ == dt.code_
62-
if isinstance(dt, (int, bool, str)):
62+
if isinstance(dt, (int, bool, str, float)):
6363
return False
64+
if dt is int:
65+
return self.code_ == TensorProto.INT64
6466
if dt is str:
6567
return self.code_ == TensorProto.STRING
6668
if dt is bool:
6769
return self.code_ == TensorProto.BOOL
70+
if dt is float:
71+
return self.code_ == TensorProto.FLOAT64
6872
if isinstance(dt, list):
6973
return False
7074
if dt in ElemType.numpy_map:

onnx_array_api/npx/npx_var.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -607,8 +607,9 @@ def __iter__(self):
607607
This method raises an exception with a better error message.
608608
"""
609609
raise ArrayApiError(
610-
"Iterators are not implemented in the generic case. "
611-
"Every function using them cannot be converted into ONNX."
610+
f"Iterators are not implemented in the generic case. "
611+
f"Every function using them cannot be converted into ONNX "
612+
f"(Var - {type(self)})."
612613
)
613614

614615
def _binary_op(self, ov: "Var", op_name: str, **kwargs) -> "Var":

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ black
33
coverage
44
flake8
55
furo
6-
hypothesis<6.80.0
6+
hypothesis
77
isort
88
joblib
99
lightgbm

0 commit comments

Comments
 (0)