Extend ExtendedReferenceEvaluator #75

Merged · 10 commits · Feb 15, 2024
1 change: 1 addition & 0 deletions CHANGELOGS.rst
@@ -4,6 +4,7 @@ Change Logs
0.2.0
+++++

* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
* :pr:`71`: adds tools to compare two onnx graphs
* :pr:`61`: adds function to plot onnx model as graphs
* :pr:`60`: supports translation of local functions
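A minimal usage sketch of the new kernels (an illustrative example, not part of the diff; it assumes ExtendedReferenceEvaluator is importable from onnx_array_api.reference, as in the unit tests below):

    import numpy as np
    from onnx import TensorProto
    from onnx.helper import (
        make_graph,
        make_model,
        make_node,
        make_opsetid,
        make_tensor_value_info,
    )
    from onnx_array_api.reference import ExtendedReferenceEvaluator

    # QuickGelu lives in the com.microsoft domain, so that opset is declared too.
    model = make_model(
        make_graph(
            [make_node("QuickGelu", ["X"], ["Z"], domain="com.microsoft", alpha=1.0)],
            "g",
            [make_tensor_value_info("X", TensorProto.FLOAT, None)],
            [make_tensor_value_info("Z", TensorProto.FLOAT, None)],
        ),
        opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
        ir_version=9,
    )

    x = np.arange(4).reshape(-1, 2).astype(np.float32)
    ref = ExtendedReferenceEvaluator(model)
    print(ref.run(None, {"X": x})[0])  # X * sigmoid(alpha * X)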
82 changes: 82 additions & 0 deletions _unittests/ut_reference/test_reference_ops.py
@@ -59,6 +59,88 @@ def test_fused_matmul11(self):
got = ref.run(None, {"X": a, "Y": a})
self.assertEqualArray(a.T @ a.T, got[0])

def test_memcpy(self):
model = make_model(
make_graph(
[
make_node("MemcpyToHost", ["X"], ["Z"]),
make_node("MemcpyFromHost", ["X"], ["Z"]),
],
"name",
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
ir_version=9,
)
a = np.arange(4).reshape(-1, 2).astype(np.float32)
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"X": a})
self.assertEqualArray(a, got[0])

def test_quick_gelu(self):
from onnxruntime import InferenceSession

for alpha in [0.0, 2.0]:
model = make_model(
make_graph(
[
make_node(
"QuickGelu",
["X"],
["Z"],
domain="com.microsoft",
alpha=alpha,
)
],
"name",
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
ir_version=9,
)
sess = InferenceSession(
model.SerializeToString(), providers=["CPUExecutionProvider"]
)
a = np.arange(4).reshape(-1, 2).astype(np.float32)
expected = sess.run(None, {"X": a})
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"X": a})
self.assertEqualArray(expected[0], got[0])

def test_scatter_elements(self):
model = make_model(
make_graph(
[
make_node(
"ScatterElements",
["data", "indices", "updates"],
["Z"],
axis=3,
reduction="add",
)
],
"name",
[
make_tensor_value_info("data", TensorProto.FLOAT, None),
make_tensor_value_info("indices", TensorProto.INT64, None),
make_tensor_value_info("updates", TensorProto.FLOAT, None),
],
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
),
opset_imports=[make_opsetid("", 18)],
)
data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
indices = np.array([[[[0]]]], dtype=np.int64)
updates = np.array([[[[1]]]], dtype=np.float32)
y = np.array(
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
).reshape((2, 2, 2, 2))
ref = ExtendedReferenceEvaluator(model)
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
self.assertEqualArray(y, got[0])


if __name__ == "__main__":
unittest.main(verbosity=2)
7 changes: 7 additions & 0 deletions onnx_array_api/reference/evaluator.py
@@ -8,6 +8,9 @@
from .ops.op_concat import Concat
from .ops.op_constant_of_shape import ConstantOfShape
from .ops.op_fused_matmul import FusedMatMul
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
from .ops.op_quick_gelu import QuickGelu
from .ops.op_scatter_elements import ScatterElements


logger = getLogger("onnx-array-api-eval")
@@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
CastLike_19,
ConstantOfShape,
FusedMatMul,
MemcpyFromHost,
MemcpyToHost,
QuickGelu,
ScatterElements,
]

@staticmethod
11 changes: 11 additions & 0 deletions onnx_array_api/reference/ops/op_memcpy_host.py
@@ -0,0 +1,11 @@
from onnx.reference.op_run import OpRun


class MemcpyFromHost(OpRun):
    "Memory-copy operator: the pure-python reference returns the input unchanged."

    def _run(self, x):
        return (x,)


class MemcpyToHost(OpRun):
    "Memory-copy operator: the pure-python reference returns the input unchanged."

    def _run(self, x):
        return (x,)
23 changes: 23 additions & 0 deletions onnx_array_api/reference/ops/op_quick_gelu.py
@@ -0,0 +1,23 @@
import numpy as np
from onnx.reference.op_run import OpRun


def sigmoid(x):  # type: ignore
    # numerically stable form: exp is only evaluated on non-positive arguments
    if x > 0:
        return 1 / (1 + np.exp(-x))
    return np.exp(x) / (1 + np.exp(x))


class QuickGelu(OpRun):
op_domain = "com.microsoft"

def __init__(self, onnx_node, run_params): # type: ignore
OpRun.__init__(self, onnx_node, run_params)
self.vf = np.vectorize(sigmoid)

def _run(self, X, alpha=1.0):
if len(X.shape) == 0:
return ((X * sigmoid(X * alpha)).astype(X.dtype),)
if X.size == 0:
return (X,)
return ((X * self.vf(X * alpha)).astype(X.dtype),)
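The piecewise sigmoid above avoids overflowing exp for large negative inputs. A quick numerical check (an illustrative sketch, not part of the diff) that the implementation matches the naive formula X * sigmoid(alpha * X):

    import numpy as np

    def stable_sigmoid(x):
        # same piecewise form as the sigmoid above
        return 1 / (1 + np.exp(-x)) if x > 0 else np.exp(x) / (1 + np.exp(x))

    x = np.linspace(-5, 5, 11).astype(np.float32)
    alpha = 1.702  # illustrative value; the attribute defaults to 1.0 above
    quick_gelu = x * np.vectorize(stable_sigmoid)(alpha * x)
    naive = x / (1 + np.exp(-alpha * x))
    assert np.allclose(quick_gelu, naive, atol=1e-6)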
98 changes: 98 additions & 0 deletions onnx_array_api/reference/ops/op_scatter_elements.py
@@ -0,0 +1,98 @@
import numpy as np

from onnx.reference.op_run import OpRun


def scatter_elements(data, indices, updates, axis=0, reduction=None):  # type: ignore
    # pick the accumulation function implied by the reduction attribute
    if reduction == "add":

def f(x, y):
return x + y

elif reduction == "min":

def f(x, y):
return min(x, y)

elif reduction == "max":

def f(x, y):
return max(x, y)

else:

def f(x, y):
return y

if axis < 0:
axis = data.ndim + axis

if len(data.shape) == 1 and axis == 0:
scattered = np.copy(data)
for pos, up in zip(indices, updates):
scattered[pos] = f(scattered[pos], up)
return scattered

if len(indices.shape) == 2:
scattered = np.copy(data)
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
scattered[indices[i, j], j] = f(
scattered[indices[i, j], j], updates[i, j]
)
else:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
scattered[i, indices[i, j]] = f(
scattered[i, indices[i, j]], updates[i, j]
)
return scattered

if len(indices.shape) == 3:
scattered = np.copy(data)
if axis == 0:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in range(indices.shape[2]):
scattered[indices[i, j, k], j, k] = f(
scattered[indices[i, j, k], j, k], updates[i, j, k]
)
elif axis == 1:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in range(indices.shape[2]):
scattered[i, indices[i, j, k], k] = f(
scattered[i, indices[i, j, k], k], updates[i, j, k]
)
elif axis == 2:
for i in range(indices.shape[0]):
for j in range(indices.shape[1]):
for k in range(indices.shape[2]):
scattered[i, j, indices[i, j, k]] = f(
scattered[i, j, indices[i, j, k]], updates[i, j, k]
)
return scattered

if len(indices.shape) == 4:
scattered = np.copy(data)
if axis == 3:
for a in range(indices.shape[0]):
for i in range(indices.shape[1]):
for j in range(indices.shape[2]):
for k in range(indices.shape[3]):
scattered[a, i, j, indices[a, i, j, k]] = f(
scattered[a, i, j, indices[a, i, j, k]],
updates[a, i, j, k],
)
return scattered

raise RuntimeError(
f"Not implemented for indices.shape={indices.shape} and axis={axis}"
)


class ScatterElements(OpRun):
def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore
res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
return (res,)
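For reference, a small 2-D sanity check of the helper above (values taken from the standard ScatterElements example in the ONNX documentation; assumes scatter_elements from this module is in scope):

    import numpy as np

    data = np.zeros((3, 3), dtype=np.float32)
    indices = np.array([[1, 0, 2], [0, 2, 1]], dtype=np.int64)
    updates = np.array([[1.0, 1.1, 1.2], [2.0, 2.1, 2.2]], dtype=np.float32)

    # axis=0: data[indices[i, j], j] receives updates[i, j]
    out = scatter_elements(data, indices, updates, axis=0)
    # out:
    # [[2.0, 1.1, 0.0],
    #  [1.0, 0.0, 2.2],
    #  [0.0, 2.1, 1.2]]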