Skip to content

Commit d248c16

Browse files
authored
Better handling of float 8 in onnx_simple_text_plot (#27)
* better handling of float 8 in onnx_simple_text_plot * add function from_array_extended * doc * refactoring
1 parent c6a3718 commit d248c16

File tree

11 files changed

+170
-12
lines changed

11 files changed

+170
-12
lines changed

CHANGELOGS.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@ Change Logs
44
0.2.0
55
+++++
66

7-
* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support
7+
* :pr:`27`: add function from_array_extended to convert
8+
an array to a TensorProto, including bfloat16 and float 8 types
9+
* :pr:`24`: add ExtendedReferenceEvaluator to support scenario
10+
for the Array API onnx does not support
811
* :pr:`22`: support OrtValue in function :func:`ort_profile`
912
* :pr:`17`: implements ArrayAPI
1013
* :pr:`3`: fixes Array API with onnxruntime and scikit-learn

_unittests/ut_plotting/test_text_plot.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -306,6 +306,50 @@ def test_function_plot(self):
306306
self.assertIn("type=? shape=?", text)
307307
self.assertIn("LinearRegression[custom]", text)
308308

309+
def test_function_plot_f8(self):
310+
new_domain = "custom"
311+
opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]
312+
313+
node1 = make_node("MatMul", ["X", "A"], ["XA"])
314+
node2 = make_node("Add", ["XA", "B"], ["Y"])
315+
316+
linear_regression = make_function(
317+
new_domain, # domain name
318+
"LinearRegression", # function name
319+
["X", "A", "B"], # input names
320+
["Y"], # output names
321+
[node1, node2], # nodes
322+
opset_imports, # opsets
323+
[],
324+
) # attribute names
325+
326+
X = make_tensor_value_info("X", TensorProto.FLOAT8E4M3FN, [None, None])
327+
A = make_tensor_value_info("A", TensorProto.FLOAT8E5M2, [None, None])
328+
B = make_tensor_value_info("B", TensorProto.FLOAT8E4M3FNUZ, [None, None])
329+
Y = make_tensor_value_info("Y", TensorProto.FLOAT8E5M2FNUZ, None)
330+
331+
graph = make_graph(
332+
[
333+
make_node(
334+
"LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
335+
),
336+
make_node("Abs", ["Y1"], ["Y"]),
337+
],
338+
"example",
339+
[X, A, B],
340+
[Y],
341+
)
342+
343+
onnx_model = make_model(
344+
graph, opset_imports=opset_imports, functions=[linear_regression]
345+
) # functions to add)
346+
347+
text = onnx_simple_text_plot(onnx_model)
348+
self.assertIn("function name=LinearRegression domain=custom", text)
349+
self.assertIn("MatMul(X, A) -> XA", text)
350+
self.assertIn("type=? shape=?", text)
351+
self.assertIn("LinearRegression[custom]", text)
352+
309353
def test_onnx_text_plot_tree_simple(self):
310354
iris = load_iris()
311355
X, y = iris.data.astype(numpy.float32), iris.target
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
import unittest
2+
import numpy as np
3+
from onnx import TensorProto
4+
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
5+
from onnx_array_api.ext_test_case import ExtTestCase
6+
from onnx_array_api.reference import (
7+
to_array_extended,
8+
from_array_extended,
9+
ExtendedReferenceEvaluator,
10+
)
11+
12+
13+
class TestArrayTensor(ExtTestCase):
14+
def test_from_array(self):
15+
for dt in (np.float32, np.float16, np.uint16, np.uint8):
16+
with self.subTest(dtype=dt):
17+
a = np.array([0, 1, 2], dtype=dt)
18+
t = from_array_extended(a, "a")
19+
b = to_array_extended(t)
20+
self.assertEqualArray(a, b)
21+
t2 = from_array_extended(b, "a")
22+
self.assertEqual(t.SerializeToString(), t2.SerializeToString())
23+
24+
def test_from_array_f8(self):
25+
def make_model_f8(fr, to):
26+
model = make_model(
27+
make_graph(
28+
[make_node("Cast", ["X"], ["Y"], to=to)],
29+
"cast",
30+
[make_tensor_value_info("X", fr, None)],
31+
[make_tensor_value_info("Y", to, None)],
32+
)
33+
)
34+
return model
35+
36+
for dt in (np.float32, np.float16, np.uint16, np.uint8):
37+
with self.subTest(dtype=dt):
38+
a = np.array([0, 1, 2], dtype=dt)
39+
b = from_array_extended(a, "a")
40+
for to in [
41+
TensorProto.FLOAT8E4M3FN,
42+
TensorProto.FLOAT8E4M3FNUZ,
43+
TensorProto.FLOAT8E5M2,
44+
TensorProto.FLOAT8E5M2FNUZ,
45+
TensorProto.BFLOAT16,
46+
]:
47+
with self.subTest(fr=b.data_type, to=to):
48+
model = make_model_f8(b.data_type, to)
49+
ref = ExtendedReferenceEvaluator(model)
50+
got = ref.run(None, {"X": a})[0]
51+
back = from_array_extended(got, "a")
52+
self.assertEqual(to, back.data_type)
53+
54+
55+
if __name__ == "__main__":
56+
unittest.main(verbosity=2)

onnx_array_api/npx/npx_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import numpy as np
44
from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
55
from onnx.helper import make_tensor, tensor_dtype_to_np_dtype
6-
from onnx.numpy_helper import from_array
6+
from ..reference import from_array_extended as from_array
77
from .npx_constants import FUNCTION_DOMAIN
88
from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var
99
from .npx_types import (

onnx_array_api/npx/npx_graph_builder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@
2424
make_opsetid,
2525
make_tensor_value_info,
2626
)
27-
from onnx.numpy_helper import from_array
2827
from onnx.onnx_cpp2py_export.checker import ValidationError
2928
from onnx.onnx_cpp2py_export.shape_inference import InferenceError
3029
from onnx.shape_inference import infer_shapes
3130

31+
from ..reference import from_array_extended as from_array
3232
from .npx_constants import _OPSET_TO_IR_VERSION, FUNCTION_DOMAIN, ONNX_DOMAIN
3333
from .npx_function_implementation import get_function_implementation
3434
from .npx_helper import (

onnx_array_api/npx/npx_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
make_operatorsetid,
1010
make_value_info,
1111
)
12-
from onnx.numpy_helper import from_array
1312
from onnx.version_converter import convert_version
13+
from ..reference import from_array_extended as from_array
1414

1515

1616
def rename_in_onnx_graph(

onnx_array_api/plotting/_helper.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
ValueInfoProto,
1111
)
1212
from onnx.helper import tensor_dtype_to_np_dtype
13-
from onnx.numpy_helper import to_array
13+
from ..reference import to_array_extended as to_array
1414
from ..npx.npx_types import DType
1515

1616

@@ -136,12 +136,25 @@ def _get_type(obj0):
136136
return tensor_dtype_to_np_dtype(TensorProto.DOUBLE)
137137
if obj.data_type == TensorProto.INT64 and hasattr(obj, "int64_data"):
138138
return tensor_dtype_to_np_dtype(TensorProto.INT64)
139-
if obj.data_type == TensorProto.INT32 and hasattr(obj, "int32_data"):
139+
if obj.data_type in (
140+
TensorProto.INT8,
141+
TensorProto.UINT8,
142+
TensorProto.UINT16,
143+
TensorProto.INT16,
144+
TensorProto.INT32,
145+
TensorProto.FLOAT8E4M3FN,
146+
TensorProto.FLOAT8E4M3FNUZ,
147+
TensorProto.FLOAT8E5M2,
148+
TensorProto.FLOAT8E5M2FNUZ,
149+
) and hasattr(obj, "int32_data"):
140150
return tensor_dtype_to_np_dtype(TensorProto.INT32)
141151
if hasattr(obj, "raw_data") and len(obj.raw_data) > 0:
142152
arr = to_array(obj)
143153
return arr.dtype
144-
raise RuntimeError(f"Unable to guess type from {obj0!r}.")
154+
raise RuntimeError(
155+
f"Unable to guess type from obj.data_type={obj.data_type} "
156+
f"and obj={obj0!r} - {TensorProto.__dict__}."
157+
)
145158
if hasattr(obj, "type"):
146159
obj = obj.type
147160
if hasattr(obj, "tensor_type"):

onnx_array_api/plotting/dot_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,8 @@
33

44
from onnx import GraphProto, ModelProto
55
from onnx.helper import tensor_dtype_to_string
6-
from onnx.numpy_helper import to_array
76

7+
from ..reference import to_array_extended as to_array
88
from ._helper import Graph, _get_shape, attributes_as_dict
99

1010

onnx_array_api/plotting/text_plot.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
import pprint
22
from collections import OrderedDict
3-
43
import numpy
54
from onnx import AttributeProto
6-
from onnx.numpy_helper import to_array
7-
5+
from ..reference import to_array_extended as to_array
86
from ._helper import _get_shape, _get_type, attributes_as_dict
97

108

onnx_array_api/reference/__init__.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,45 @@
1+
from typing import Optional
2+
import numpy as np
3+
from onnx import TensorProto
4+
from onnx.numpy_helper import from_array as onnx_from_array
5+
from onnx.reference.ops.op_cast import (
6+
bfloat16,
7+
float8e4m3fn,
8+
float8e4m3fnuz,
9+
float8e5m2,
10+
float8e5m2fnuz,
11+
)
12+
from onnx.reference.op_run import to_array_extended
113
from .evaluator import ExtendedReferenceEvaluator
14+
15+
16+
def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto:
17+
"""
18+
Converts an array into a TensorProto.
19+
20+
:param tensor: numpy array
21+
:param name: name
22+
:return: TensorProto
23+
"""
24+
dt = tensor.dtype
25+
if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
26+
to = TensorProto.FLOAT8E4M3FN
27+
dt_to = np.uint8
28+
elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
29+
to = TensorProto.FLOAT8E4M3FNUZ
30+
dt_to = np.uint8
31+
elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
32+
to = TensorProto.FLOAT8E5M2
33+
dt_to = np.uint8
34+
elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
35+
to = TensorProto.FLOAT8E5M2FNUZ
36+
dt_to = np.uint8
37+
elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
38+
to = TensorProto.BFLOAT16
39+
dt_to = np.uint16
40+
else:
41+
return onnx_from_array(tensor, name)
42+
43+
t = onnx_from_array(tensor.astype(dt_to), name)
44+
t.data_type = to
45+
return t

onnx_array_api/validation/tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
make_node,
1717
set_model_props,
1818
)
19-
from onnx.numpy_helper import from_array, to_array
19+
from ..reference import from_array_extended as from_array, to_array_extended as to_array
2020

2121

2222
def randomize_proto(

0 commit comments

Comments
 (0)