Skip to content

Add full_like for the array API #26

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion _unittests/onnx-numpy-skips.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ array_api_tests/test_creation_functions.py::test_asarray_arrays
array_api_tests/test_creation_functions.py::test_empty
array_api_tests/test_creation_functions.py::test_empty_like
array_api_tests/test_creation_functions.py::test_eye
array_api_tests/test_creation_functions.py::test_full_like
array_api_tests/test_creation_functions.py::test_linspace
array_api_tests/test_creation_functions.py::test_meshgrid
array_api_tests/test_creation_functions.py::test_zeros_like
6 changes: 3 additions & 3 deletions _unittests/ut_array_api/test_hypothesis_array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def fctonx(x, kw):


if __name__ == "__main__":
cl = TestHypothesisArraysApis()
cl.setUpClass()
cl.test_scalar_strategies()
# cl = TestHypothesisArraysApis()
# cl.setUpClass()
# cl.test_scalar_strategies()
unittest.main(verbosity=2)
20 changes: 19 additions & 1 deletion _unittests/ut_array_api/test_onnx_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,25 @@ def test_ones_like_uint16(self):
expected = np.array(1, dtype=np.uint16)
self.assertEqualArray(expected, z.numpy())

def test_full_like(self):
c = EagerTensor(np.array(False))
expected = np.full_like(c.numpy(), fill_value=False)
mat = xp.full_like(c, fill_value=False)
matnp = mat.numpy()
self.assertEqual(matnp.shape, tuple())
self.assertEqualArray(expected, matnp)

def test_full_like_mx(self):
c = EagerTensor(np.array([], dtype=np.uint8))
expected = np.full_like(c.numpy(), fill_value=0)
mat = xp.full_like(c, fill_value=0)
matnp = mat.numpy()
self.assertEqualArray(expected, matnp)


if __name__ == "__main__":
# TestOnnxNumpy().test_ones_like()
# import logging

# logging.basicConfig(level=logging.DEBUG)
# TestOnnxNumpy().test_full_like_mx()
unittest.main(verbosity=2)
7 changes: 4 additions & 3 deletions azure-pipelines.yml
Original file line number Diff line number Diff line change
Expand Up @@ -246,9 +246,10 @@ jobs:
architecture: 'x64'
- script: gcc --version
displayName: 'gcc version'
- script: |
brew update
displayName: 'brew update'
#- script: brew upgrade
# displayName: 'brew upgrade'
#- script: brew update
# displayName: 'brew update'
- script: export
displayName: 'export'
- script: gcc --version
Expand Down
1 change: 1 addition & 0 deletions onnx_array_api/array_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"empty",
"equal",
"full",
"full_like",
"isdtype",
"isfinite",
"isinf",
Expand Down
18 changes: 18 additions & 0 deletions onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
abs as generic_abs,
arange as generic_arange,
full as generic_full,
full_like as generic_full_like,
ones as generic_ones,
zeros as generic_zeros,
)
Expand Down Expand Up @@ -177,6 +178,23 @@ def full(
return generic_full(shape, fill_value=value, dtype=dtype, order=order)


def full_like(
TEagerTensor: type,
x: TensorType[ElemType.allowed, "T"],
/,
fill_value: ParType[Scalar] = None,
*,
dtype: OptParType[DType] = None,
order: OptParType[str] = "C",
) -> EagerTensor[TensorType[ElemType.allowed, "TR"]]:
if dtype is None:
if isinstance(fill_value, TEagerTensor):
dtype = fill_value.dtype
elif isinstance(x, TEagerTensor):
dtype = x.dtype
return generic_full_like(x, fill_value=fill_value, dtype=dtype, order=order)


def ones(
TEagerTensor: type,
shape: EagerTensor[TensorType[ElemType.int64, "I", (None,)]],
Expand Down
47 changes: 45 additions & 2 deletions onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,9 +275,9 @@ def astype(
if dtype is int:
to = DType(TensorProto.INT64)
elif dtype is float:
to = DType(TensorProto.FLOAT64)
to = DType(TensorProto.DOUBLE)
elif dtype is bool:
to = DType(TensorProto.FLOAT64)
to = DType(TensorProto.BOOL)
elif dtype is str:
to = DType(TensorProto.STRING)
else:
Expand Down Expand Up @@ -511,6 +511,49 @@ def full(
return var(shape, value=value, op="ConstantOfShape")


@npxapi_inline
def full_like(
x: TensorType[ElemType.allowed, "T"],
/,
*,
fill_value: ParType[Scalar] = None,
dtype: OptParType[DType] = None,
order: OptParType[str] = "C",
) -> TensorType[ElemType.numerics, "T"]:
"""
Implements :func:`numpy.zeros`.
"""
if order != "C":
raise RuntimeError(f"order={order!r} != 'C' not supported.")
if fill_value is None:
raise TypeError("fill_value cannot be None.")
if dtype is None:
if isinstance(fill_value, bool):
dtype = DType(TensorProto.BOOL)
elif isinstance(fill_value, int):
dtype = DType(TensorProto.INT64)
elif isinstance(fill_value, float):
dtype = DType(TensorProto.DOUBLE)
else:
raise TypeError(
f"Unexpected type {type(fill_value)} for fill_value={fill_value!r} "
f"and dtype={dtype!r}."
)
if isinstance(fill_value, (float, int, bool)):
value = make_tensor(
name="cst", data_type=dtype.code, dims=[1], vals=[fill_value]
)
else:
raise NotImplementedError(
f"Unexpected type ({type(fill_value)} for fill_value={fill_value!r}."
)

v = var(x.shape, value=value, op="ConstantOfShape")
if dtype is None:
return var(v, x, op="CastLike")
return v


@npxapi_inline
def floor(
x: TensorType[ElemType.numerics, "T"], /
Expand Down
15 changes: 10 additions & 5 deletions onnx_array_api/npx/npx_jit_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def info(
kwargs: Optional[Dict[str, Any]] = None,
key: Optional[Tuple[Any, ...]] = None,
onx: Optional[ModelProto] = None,
output: Optional[Any] = None,
):
"""
Logs a status.
Expand Down Expand Up @@ -93,6 +94,8 @@ def info(
"" if args is None else str(args),
"" if kwargs is None else str(kwargs),
)
if output is not None:
logger.debug("==== [%s]", output)

def status(self, me: str) -> str:
"""
Expand Down Expand Up @@ -517,7 +520,7 @@ def jit_call(self, *values, **kwargs):
f"f={self.f} from module {self.f.__module__!r} "
f"onnx=\n---\n{text}\n---\n{self.onxs[key]}"
) from e
self.info("-", "jit_call")
self.info("-", "jit_call", output=res)
return res


Expand Down Expand Up @@ -737,11 +740,13 @@ def __call__(self, *args, already_eager=False, **kwargs):
try:
res = self.f(*values, **kwargs)
except (AttributeError, TypeError) as e:
inp1 = ", ".join(map(str, map(type, args)))
inp2 = ", ".join(map(str, map(type, values)))
inp1 = ", ".join(map(str, map(lambda a: type(a).__name__, args)))
inp2 = ", ".join(map(str, map(lambda a: type(a).__name__, values)))
raise TypeError(
f"Unexpected types, input types are {inp1} "
f"and {inp2}, kwargs={kwargs}."
f"Unexpected types, input types are args=[{inp1}], "
f"values=[{inp2}], kwargs={kwargs}. "
f"(values = self._preprocess_constants(args)) "
f"args={args}, values={values}"
) from e

if isinstance(res, EagerTensor) or (
Expand Down
3 changes: 1 addition & 2 deletions onnx_array_api/npx/npx_numpy_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from onnx import ModelProto, TensorProto
from ..reference import ExtendedReferenceEvaluator
from .._helpers import np_dtype_to_tensor_dtype
from .npx_numpy_tensors_ops import ConstantOfShape
from .npx_tensors import EagerTensor, JitTensor
from .npx_types import DType, TensorType

Expand Down Expand Up @@ -36,7 +35,7 @@ def __init__(
onx: ModelProto,
f: Callable,
):
self.ref = ExtendedReferenceEvaluator(onx, new_ops=[ConstantOfShape])
self.ref = ExtendedReferenceEvaluator(onx)
self.input_names = input_names
self.tensor_class = tensor_class
self._f = f
Expand Down
2 changes: 1 addition & 1 deletion onnx_array_api/npx/npx_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def __eq__(self, dt: "DType") -> bool:
if dt is bool:
return self.code_ == TensorProto.BOOL
if dt is float:
return self.code_ == TensorProto.FLOAT64
return self.code_ == TensorProto.DOUBLE
if isinstance(dt, list):
return False
if dt in ElemType.numpy_map:
Expand Down
17 changes: 17 additions & 0 deletions onnx_array_api/reference/evaluator.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
from logging import getLogger
from typing import Any, Dict, List, Optional, Union
from onnx import FunctionProto, ModelProto
from onnx.defs import get_schema
from onnx.reference import ReferenceEvaluator
from onnx.reference.op_run import OpRun
from .ops.op_cast_like import CastLike_15, CastLike_19
from .ops.op_constant_of_shape import ConstantOfShape

import onnx

print(onnx.__file__)


logger = getLogger("onnx-array-api-eval")


class ExtendedReferenceEvaluator(ReferenceEvaluator):
Expand All @@ -24,6 +33,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
default_ops = [
CastLike_15,
CastLike_19,
ConstantOfShape,
]

@staticmethod
Expand Down Expand Up @@ -88,3 +98,10 @@ def __init__(
new_ops=new_ops,
**kwargs,
)

def _log(self, level: int, pattern: str, *args: List[Any]) -> None:
if level < self.verbose:
new_args = [self._log_arg(a) for a in args]
print(pattern % tuple(new_args))
else:
logger.debug(pattern, *args)
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
import numpy as np

from onnx.reference.op_run import OpRun


class ConstantOfShape(OpRun):
@staticmethod
def _process(value):
cst = value[0] if isinstance(value, np.ndarray) else value
cst = value[0] if isinstance(value, np.ndarray) and value.size > 0 else value
if isinstance(value, np.ndarray):
if len(value.shape) == 0:
cst = value
elif value.size > 0:
cst = value.ravel()[0]
else:
raise ValueError(f"Unexpected fill_value={value!r}")
if isinstance(cst, bool):
cst = np.bool_(cst)
elif isinstance(cst, int):
Expand Down