Skip to content

Commit 2fc79f6

Browse files
authored
Add full_like for the array API (#26)
* Add full_like for the array API * improvment * fix full_like
1 parent d248c16 commit 2fc79f6

File tree

12 files changed

+127
-20
lines changed

12 files changed

+127
-20
lines changed

_unittests/onnx-numpy-skips.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
66
array_api_tests/test_creation_functions.py::test_empty
77
array_api_tests/test_creation_functions.py::test_empty_like
88
array_api_tests/test_creation_functions.py::test_eye
9-
array_api_tests/test_creation_functions.py::test_full_like
109
array_api_tests/test_creation_functions.py::test_linspace
1110
array_api_tests/test_creation_functions.py::test_meshgrid
1211
array_api_tests/test_creation_functions.py::test_zeros_like

_unittests/ut_array_api/test_hypothesis_array_api.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def fctonx(x, kw):
140140

141141

142142
if __name__ == "__main__":
143-
cl = TestHypothesisArraysApis()
144-
cl.setUpClass()
145-
cl.test_scalar_strategies()
143+
# cl = TestHypothesisArraysApis()
144+
# cl.setUpClass()
145+
# cl.test_scalar_strategies()
146146
unittest.main(verbosity=2)

_unittests/ut_array_api/test_onnx_numpy.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,25 @@ def test_ones_like_uint16(self):
112112
expected = np.array(1, dtype=np.uint16)
113113
self.assertEqualArray(expected, z.numpy())
114114

115+
def test_full_like(self):
116+
c = EagerTensor(np.array(False))
117+
expected = np.full_like(c.numpy(), fill_value=False)
118+
mat = xp.full_like(c, fill_value=False)
119+
matnp = mat.numpy()
120+
self.assertEqual(matnp.shape, tuple())
121+
self.assertEqualArray(expected, matnp)
122+
123+
def test_full_like_mx(self):
124+
c = EagerTensor(np.array([], dtype=np.uint8))
125+
expected = np.full_like(c.numpy(), fill_value=0)
126+
mat = xp.full_like(c, fill_value=0)
127+
matnp = mat.numpy()
128+
self.assertEqualArray(expected, matnp)
129+
115130

116131
if __name__ == "__main__":
117-
# TestOnnxNumpy().test_ones_like()
132+
# import logging
133+
134+
# logging.basicConfig(level=logging.DEBUG)
135+
# TestOnnxNumpy().test_full_like_mx()
118136
unittest.main(verbosity=2)

azure-pipelines.yml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -246,9 +246,10 @@ jobs:
246246
architecture: 'x64'
247247
- script: gcc --version
248248
displayName: 'gcc version'
249-
- script: |
250-
brew update
251-
displayName: 'brew update'
249+
#- script: brew upgrade
250+
# displayName: 'brew upgrade'
251+
#- script: brew update
252+
# displayName: 'brew update'
252253
- script: export
253254
displayName: 'export'
254255
- script: gcc --version

onnx_array_api/array_api/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"empty",
1919
"equal",
2020
"full",
21+
"full_like",
2122
"isdtype",
2223
"isfinite",
2324
"isinf",

onnx_array_api/array_api/_onnx_common.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
abs as generic_abs,
2121
arange as generic_arange,
2222
full as generic_full,
23+
full_like as generic_full_like,
2324
ones as generic_ones,
2425
zeros as generic_zeros,
2526
)
@@ -177,6 +178,23 @@ def full(
177178
return generic_full(shape, fill_value=value, dtype=dtype, order=order)
178179

179180

181+
def full_like(
182+
TEagerTensor: type,
183+
x: TensorType[ElemType.allowed, "T"],
184+
/,
185+
fill_value: ParType[Scalar] = None,
186+
*,
187+
dtype: OptParType[DType] = None,
188+
order: OptParType[str] = "C",
189+
) -> EagerTensor[TensorType[ElemType.allowed, "TR"]]:
190+
if dtype is None:
191+
if isinstance(fill_value, TEagerTensor):
192+
dtype = fill_value.dtype
193+
elif isinstance(x, TEagerTensor):
194+
dtype = x.dtype
195+
return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order)
196+
197+
180198
def ones(
181199
TEagerTensor: type,
182200
shape: EagerTensor[TensorType[ElemType.int64, "I", (None,)]],

onnx_array_api/npx/npx_functions.py

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,9 +275,9 @@ def astype(
275275
if dtype is int:
276276
to = DType(TensorProto.INT64)
277277
elif dtype is float:
278-
to = DType(TensorProto.FLOAT64)
278+
to = DType(TensorProto.DOUBLE)
279279
elif dtype is bool:
280-
to = DType(TensorProto.FLOAT64)
280+
to = DType(TensorProto.BOOL)
281281
elif dtype is str:
282282
to = DType(TensorProto.STRING)
283283
else:
@@ -511,6 +511,49 @@ def full(
511511
return var(shape, value=value, op="ConstantOfShape")
512512

513513

514+
@npxapi_inline
515+
def full_like(
516+
x: TensorType[ElemType.allowed, "T"],
517+
/,
518+
*,
519+
fill_value: ParType[Scalar] = None,
520+
dtype: OptParType[DType] = None,
521+
order: OptParType[str] = "C",
522+
) -> TensorType[ElemType.numerics, "T"]:
523+
"""
524+
Implements :func:`numpy.zeros`.
525+
"""
526+
if order != "C":
527+
raise RuntimeError(f"order={order!r} != 'C' not supported.")
528+
if fill_value is None:
529+
raise TypeError("fill_value cannot be None.")
530+
if dtype is None:
531+
if isinstance(fill_value, bool):
532+
dtype = DType(TensorProto.BOOL)
533+
elif isinstance(fill_value, int):
534+
dtype = DType(TensorProto.INT64)
535+
elif isinstance(fill_value, float):
536+
dtype = DType(TensorProto.DOUBLE)
537+
else:
538+
raise TypeError(
539+
f"Unexpected type {type(fill_value)} for fill_value={fill_value!r} "
540+
f"and dtype={dtype!r}."
541+
)
542+
if isinstance(fill_value, (float, int, bool)):
543+
value = make_tensor(
544+
name="cst", data_type=dtype.code, dims=[1], vals=[fill_value]
545+
)
546+
else:
547+
raise NotImplementedError(
548+
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
549+
)
550+
551+
v = var(x.shape, value=value, op="ConstantOfShape")
552+
if dtype is None:
553+
return var(v, x, op="CastLike")
554+
return v
555+
556+
514557
@npxapi_inline
515558
def floor(
516559
x: TensorType[ElemType.numerics, "T"], /

onnx_array_api/npx/npx_jit_eager.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def info(
5858
kwargs: Optional[Dict[str, Any]] = None,
5959
key: Optional[Tuple[Any, ...]] = None,
6060
onx: Optional[ModelProto] = None,
61+
output: Optional[Any] = None,
6162
):
6263
"""
6364
Logs a status.
@@ -93,6 +94,8 @@ def info(
9394
"" if args is None else str(args),
9495
"" if kwargs is None else str(kwargs),
9596
)
97+
if output is not None:
98+
logger.debug("==== [%s]", output)
9699

97100
def status(self, me: str) -> str:
98101
"""
@@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs):
517520
f"f={self.f} from module {self.f.__module__!r} "
518521
f"onnx=\n---\n{text}\n---\n{self.onxs[key]}"
519522
) from e
520-
self.info("-", "jit_call")
523+
self.info("-", "jit_call", output=res)
521524
return res
522525

523526

@@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs):
737740
try:
738741
res = self.f(*values, **kwargs)
739742
except (AttributeError, TypeError) as e:
740-
inp1 = ", ".join(map(str, map(type, args)))
741-
inp2 = ", ".join(map(str, map(type, values)))
743+
inp1 = ", ".join(map(str, map(lambda a: type(a).__name__, args)))
744+
inp2 = ", ".join(map(str, map(lambda a: type(a).__name__, values)))
742745
raise TypeError(
743-
f"Unexpected types, input types are {inp1} "
744-
f"and {inp2}, kwargs={kwargs}."
746+
f"Unexpected types, input types are args=[{inp1}], "
747+
f"values=[{inp2}], kwargs={kwargs}. "
748+
f"(values = self._preprocess_constants(args)) "
749+
f"args={args}, values={values}"
745750
) from e
746751

747752
if isinstance(res, EagerTensor) or (

onnx_array_api/npx/npx_numpy_tensors.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from onnx import ModelProto, TensorProto
55
from ..reference import ExtendedReferenceEvaluator
66
from .._helpers import np_dtype_to_tensor_dtype
7-
from .npx_numpy_tensors_ops import ConstantOfShape
87
from .npx_tensors import EagerTensor, JitTensor
98
from .npx_types import DType, TensorType
109

@@ -36,7 +35,7 @@ def __init__(
3635
onx: ModelProto,
3736
f: Callable,
3837
):
39-
self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape])
38+
self.ref = ExtendedReferenceEvaluator(onx)
4039
self.input_names = input_names
4140
self.tensor_class = tensor_class
4241
self._f = f

onnx_array_api/npx/npx_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool:
6868
if dt is bool:
6969
return self.code_ == TensorProto.BOOL
7070
if dt is float:
71-
return self.code_ == TensorProto.FLOAT64
71+
return self.code_ == TensorProto.DOUBLE
7272
if isinstance(dt, list):
7373
return False
7474
if dt in ElemType.numpy_map:

onnx_array_api/reference/evaluator.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
1+
from logging import getLogger
12
from typing import Any, Dict, List, Optional, Union
23
from onnx import FunctionProto, ModelProto
34
from onnx.defs import get_schema
45
from onnx.reference import ReferenceEvaluator
56
from onnx.reference.op_run import OpRun
67
from .ops.op_cast_like import CastLike_15, CastLike_19
8+
from .ops.op_constant_of_shape import ConstantOfShape
9+
10+
import onnx
11+
12+
print(onnx.__file__)
13+
14+
15+
logger = getLogger("onnx-array-api-eval")
716

817

918
class ExtendedReferenceEvaluator(ReferenceEvaluator):
@@ -24,6 +33,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
2433
default_ops = [
2534
CastLike_15,
2635
CastLike_19,
36+
ConstantOfShape,
2737
]
2838

2939
@staticmethod
@@ -88,3 +98,10 @@ def __init__(
8898
new_ops=new_ops,
8999
**kwargs,
90100
)
101+
102+
def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
103+
if level < self.verbose:
104+
new_args = [self._log_arg(a) for a in args]
105+
print(pattern % tuple(new_args))
106+
else:
107+
logger.debug(pattern, *args)

onnx_array_api/npx/npx_numpy_tensors_ops.py renamed to onnx_array_api/reference/ops/op_constant_of_shape.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
import numpy as np
2-
32
from onnx.reference.op_run import OpRun
43

54

65
class ConstantOfShape(OpRun):
76
@staticmethod
87
def _process(value):
9-
cst = value[0] if isinstance(value, np.ndarray) else value
8+
cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value
9+
if isinstance(value, np.ndarray):
10+
if len(value.shape) == 0:
11+
cst = value
12+
elif value.size > 0:
13+
cst = value.ravel()[0]
14+
else:
15+
raise ValueError(f"Unexpected fill_value={value!r}")
1016
if isinstance(cst, bool):
1117
cst = np.bool_(cst)
1218
elif isinstance(cst, int):

0 commit comments

Comments
 (0)