def test_memcpy(self):
    """Check that MemcpyToHost / MemcpyFromHost behave as identity operators.

    The two nodes are chained (X -> Y -> Z) so that every output name is
    produced exactly once, as required for a valid ONNX graph, and so that
    both operators really take part in computing the final result.
    """
    model = make_model(
        make_graph(
            [
                make_node("MemcpyToHost", ["X"], ["Y"]),
                make_node("MemcpyFromHost", ["Y"], ["Z"]),
            ],
            "name",
            [make_tensor_value_info("X", TensorProto.FLOAT, None)],
            [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
        ),
        opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
        ir_version=9,
    )
    a = np.arange(4).reshape(-1, 2).astype(np.float32)
    ref = ExtendedReferenceEvaluator(model)
    got = ref.run(None, {"X": a})
    # Both operators only copy their input, the output must equal the input.
    self.assertEqualArray(a, got[0])

def test_quick_gelu(self):
    """Compare QuickGelu in the reference evaluator against onnxruntime."""
    # Imported locally so that the module still loads without onnxruntime.
    from onnxruntime import InferenceSession

    for alpha in [0.0, 2.0]:
        model = make_model(
            make_graph(
                [
                    make_node(
                        "QuickGelu",
                        ["X"],
                        ["Z"],
                        domain="com.microsoft",
                        alpha=alpha,
                    )
                ],
                "name",
                [make_tensor_value_info("X", TensorProto.FLOAT, None)],
                [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
            ),
            opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
            ir_version=9,
        )
        sess = InferenceSession(
            model.SerializeToString(), providers=["CPUExecutionProvider"]
        )
        a = np.arange(4).reshape(-1, 2).astype(np.float32)
        # onnxruntime provides the expected values.
        expected = sess.run(None, {"X": a})
        ref = ExtendedReferenceEvaluator(model)
        got = ref.run(None, {"X": a})
        self.assertEqualArray(expected[0], got[0])

def test_scatter_elements(self):
    """Check ScatterElements with reduction="add" on 4-D tensors, last axis."""
    model = make_model(
        make_graph(
            [
                make_node(
                    "ScatterElements",
                    ["data", "indices", "updates"],
                    ["Z"],
                    axis=3,
                    reduction="add",
                )
            ],
            "name",
            [
                make_tensor_value_info("data", TensorProto.FLOAT, None),
                make_tensor_value_info("indices", TensorProto.INT64, None),
                make_tensor_value_info("updates", TensorProto.FLOAT, None),
            ],
            [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
        ),
        opset_imports=[make_opsetid("", 18)],
    )
    data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
    indices = np.array([[[[0]]]], dtype=np.int64)
    updates = np.array([[[[1]]]], dtype=np.float32)
    # A single update of 1 is added at position (0, 0, 0, 0).
    y = np.array(
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
    ).reshape((2, 2, 2, 2))
    ref = ExtendedReferenceEvaluator(model)
    got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
    self.assertEqualArray(y, got[0])
import numpy as np


def scatter_elements(data, indices, updates, axis=0, reduction=None):  # type: ignore
    """Reference implementation of the ONNX operator ScatterElements.

    Returns a copy of *data* in which, for every position ``idx`` of
    *indices*, the element located at ``idx`` with its coordinate along
    *axis* replaced by ``indices[idx]`` is combined with ``updates[idx]``.
    The input *data* is never modified.

    :param data: array to copy and update, rank >= 1
    :param indices: integer array with the same rank as *data*
    :param updates: array holding a value for every position of *indices*
    :param axis: axis along which *indices* are interpreted,
        a negative value counts from the last axis
    :param reduction: ``None`` or ``"none"`` to overwrite,
        ``"add"``, ``"mul"``, ``"min"`` or ``"max"`` to combine the existing
        value with the update
    :return: the updated copy, same shape and dtype as *data*
    :raises ValueError: if *reduction* is unknown or *axis* is out of range
    :raises RuntimeError: if *data* is a scalar or *indices* does not have
        the same rank as *data*
    """
    if reduction is None or reduction == "none":

        def combine(current, update):
            return update

    elif reduction == "add":

        def combine(current, update):
            return current + update

    elif reduction == "mul":

        def combine(current, update):
            return current * update

    elif reduction == "min":

        def combine(current, update):
            return min(current, update)

    elif reduction == "max":

        def combine(current, update):
            return max(current, update)

    else:
        # An unknown reduction used to be silently treated as an overwrite.
        raise ValueError(f"Unexpected value for reduction={reduction!r}")

    indices = np.asarray(indices)
    updates = np.asarray(updates)
    rank = data.ndim

    if axis < 0:
        axis = rank + axis

    if rank == 0 or indices.ndim != rank:
        raise RuntimeError(
            f"Not implemented for indices.shape={indices.shape} and axis={axis}"
        )
    if not 0 <= axis < rank:
        raise ValueError(f"axis={axis} is out of range for data.shape={data.shape}")

    scattered = np.copy(data)
    # np.ndindex walks the positions in row-major order, so when the same
    # element is targeted several times the last update (in that order) wins
    # for the overwrite mode.
    for idx in np.ndindex(*indices.shape):
        target = idx[:axis] + (indices[idx],) + idx[axis + 1 :]
        scattered[target] = combine(scattered[target], updates[idx])
    return scattered