diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst
index b5e9d88..ec31997 100644
--- a/CHANGELOGS.rst
+++ b/CHANGELOGS.rst
@@ -4,7 +4,10 @@ Change Logs
 0.2.0
 +++++
 
-* :pr:`24`: add ExtendedReferenceEvaluator to support scenario for the Array API onnx does not support
+* :pr:`27`: add function from_array_extended to convert
+  an array to a TensorProto, including bfloat16 and float 8 types
+* :pr:`24`: add ExtendedReferenceEvaluator to support scenario
+  for the Array API onnx does not support
 * :pr:`22`: support OrtValue in function :func:`ort_profile`
 * :pr:`17`: implements ArrayAPI
 * :pr:`3`: fixes Array API with onnxruntime and scikit-learn
diff --git a/_unittests/ut_plotting/test_text_plot.py b/_unittests/ut_plotting/test_text_plot.py
index e36ce2c..963b5cb 100644
--- a/_unittests/ut_plotting/test_text_plot.py
+++ b/_unittests/ut_plotting/test_text_plot.py
@@ -306,6 +306,50 @@ def test_function_plot(self):
         self.assertIn("type=? shape=?", text)
         self.assertIn("LinearRegression[custom]", text)
 
+    def test_function_plot_f8(self):
+        new_domain = "custom"
+        opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)]
+
+        node1 = make_node("MatMul", ["X", "A"], ["XA"])
+        node2 = make_node("Add", ["XA", "B"], ["Y"])
+
+        linear_regression = make_function(
+            new_domain,  # domain name
+            "LinearRegression",  # function name
+            ["X", "A", "B"],  # input names
+            ["Y"],  # output names
+            [node1, node2],  # nodes
+            opset_imports,  # opsets
+            [],
+        )  # attribute names
+
+        X = make_tensor_value_info("X", TensorProto.FLOAT8E4M3FN, [None, None])
+        A = make_tensor_value_info("A", TensorProto.FLOAT8E5M2, [None, None])
+        B = make_tensor_value_info("B", TensorProto.FLOAT8E4M3FNUZ, [None, None])
+        Y = make_tensor_value_info("Y", TensorProto.FLOAT8E5M2FNUZ, None)
+
+        graph = make_graph(
+            [
+                make_node(
+                    "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain
+                ),
+                make_node("Abs", ["Y1"], ["Y"]),
+            ],
+            "example",
+            [X, A, B],
+            [Y],
+        )
+
+        onnx_model = make_model(
+            graph, opset_imports=opset_imports, functions=[linear_regression]
+        )  # functions to add)
+
+        text = onnx_simple_text_plot(onnx_model)
+        self.assertIn("function name=LinearRegression domain=custom", text)
+        self.assertIn("MatMul(X, A) -> XA", text)
+        self.assertIn("type=? shape=?", text)
+        self.assertIn("LinearRegression[custom]", text)
+
     def test_onnx_text_plot_tree_simple(self):
         iris = load_iris()
         X, y = iris.data.astype(numpy.float32), iris.target
diff --git a/_unittests/ut_reference/test_array_tensor.py b/_unittests/ut_reference/test_array_tensor.py
new file mode 100644
index 0000000..59fe5f1
--- /dev/null
+++ b/_unittests/ut_reference/test_array_tensor.py
@@ -0,0 +1,56 @@
+import unittest
+import numpy as np
+from onnx import TensorProto
+from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
+from onnx_array_api.ext_test_case import ExtTestCase
+from onnx_array_api.reference import (
+    to_array_extended,
+    from_array_extended,
+    ExtendedReferenceEvaluator,
+)
+
+
+class TestArrayTensor(ExtTestCase):
+    def test_from_array(self):
+        for dt in (np.float32, np.float16, np.uint16, np.uint8):
+            with self.subTest(dtype=dt):
+                a = np.array([0, 1, 2], dtype=dt)
+                t = from_array_extended(a, "a")
+                b = to_array_extended(t)
+                self.assertEqualArray(a, b)
+                t2 = from_array_extended(b, "a")
+                self.assertEqual(t.SerializeToString(), t2.SerializeToString())
+
+    def test_from_array_f8(self):
+        def make_model_f8(fr, to):
+            model = make_model(
+                make_graph(
+                    [make_node("Cast", ["X"], ["Y"], to=to)],
+                    "cast",
+                    [make_tensor_value_info("X", fr, None)],
+                    [make_tensor_value_info("Y", to, None)],
+                )
+            )
+            return model
+
+        for dt in (np.float32, np.float16, np.uint16, np.uint8):
+            with self.subTest(dtype=dt):
+                a = np.array([0, 1, 2], dtype=dt)
+                b = from_array_extended(a, "a")
+                for to in [
+                    TensorProto.FLOAT8E4M3FN,
+                    TensorProto.FLOAT8E4M3FNUZ,
+                    TensorProto.FLOAT8E5M2,
+                    TensorProto.FLOAT8E5M2FNUZ,
+                    TensorProto.BFLOAT16,
+                ]:
+                    with self.subTest(fr=b.data_type, to=to):
+                        model = make_model_f8(b.data_type, to)
+                        ref = ExtendedReferenceEvaluator(model)
+                        got = ref.run(None, {"X": a})[0]
+                        back = from_array_extended(got, "a")
+                        self.assertEqual(to, back.data_type)
+
+
+if __name__ == "__main__":
+    unittest.main(verbosity=2)
diff --git a/onnx_array_api/npx/npx_functions.py b/onnx_array_api/npx/npx_functions.py
index 94de749..c0f0a7b 100644
--- a/onnx_array_api/npx/npx_functions.py
+++ b/onnx_array_api/npx/npx_functions.py
@@ -3,7 +3,7 @@
 import numpy as np
 from onnx import FunctionProto, ModelProto, NodeProto, TensorProto
 from onnx.helper import make_tensor, tensor_dtype_to_np_dtype
-from onnx.numpy_helper import from_array
+from ..reference import from_array_extended as from_array
 from .npx_constants import FUNCTION_DOMAIN
 from .npx_core_api import cst, make_tuple, npxapi_inline, npxapi_no_inline, var
 from .npx_types import (
diff --git a/onnx_array_api/npx/npx_graph_builder.py b/onnx_array_api/npx/npx_graph_builder.py
index 396cf39..e8e49a2 100644
--- a/onnx_array_api/npx/npx_graph_builder.py
+++ b/onnx_array_api/npx/npx_graph_builder.py
@@ -24,11 +24,11 @@
     make_opsetid,
     make_tensor_value_info,
 )
-from onnx.numpy_helper import from_array
 from onnx.onnx_cpp2py_export.checker import ValidationError
 from onnx.onnx_cpp2py_export.shape_inference import InferenceError
 from onnx.shape_inference import infer_shapes
 
+from ..reference import from_array_extended as from_array
 from .npx_constants import _OPSET_TO_IR_VERSION, FUNCTION_DOMAIN, ONNX_DOMAIN
 from .npx_function_implementation import get_function_implementation
 from .npx_helper import (
diff --git a/onnx_array_api/npx/npx_helper.py b/onnx_array_api/npx/npx_helper.py
index 13375ab..b49ab02 100644
--- a/onnx_array_api/npx/npx_helper.py
+++ b/onnx_array_api/npx/npx_helper.py
@@ -9,8 +9,8 @@
     make_operatorsetid,
     make_value_info,
 )
-from onnx.numpy_helper import from_array
 from onnx.version_converter import convert_version
+from ..reference import from_array_extended as from_array
 
 
 def rename_in_onnx_graph(
diff --git a/onnx_array_api/plotting/_helper.py b/onnx_array_api/plotting/_helper.py
index 48e65d9..a4c1915 100644
--- a/onnx_array_api/plotting/_helper.py
+++ b/onnx_array_api/plotting/_helper.py
@@ -10,7 +10,7 @@
     ValueInfoProto,
 )
 from onnx.helper import tensor_dtype_to_np_dtype
-from onnx.numpy_helper import to_array
+from ..reference import to_array_extended as to_array
 from ..npx.npx_types import DType
 
 
@@ -136,12 +136,25 @@ def _get_type(obj0):
             return tensor_dtype_to_np_dtype(TensorProto.DOUBLE)
         if obj.data_type == TensorProto.INT64 and hasattr(obj, "int64_data"):
             return tensor_dtype_to_np_dtype(TensorProto.INT64)
-        if obj.data_type == TensorProto.INT32 and hasattr(obj, "int32_data"):
+        if obj.data_type in (
+            TensorProto.INT8,
+            TensorProto.UINT8,
+            TensorProto.UINT16,
+            TensorProto.INT16,
+            TensorProto.INT32,
+            TensorProto.FLOAT8E4M3FN,
+            TensorProto.FLOAT8E4M3FNUZ,
+            TensorProto.FLOAT8E5M2,
+            TensorProto.FLOAT8E5M2FNUZ,
+        ) and hasattr(obj, "int32_data"):
             return tensor_dtype_to_np_dtype(TensorProto.INT32)
         if hasattr(obj, "raw_data") and len(obj.raw_data) > 0:
             arr = to_array(obj)
             return arr.dtype
-        raise RuntimeError(f"Unable to guess type from {obj0!r}.")
+        raise RuntimeError(
+            f"Unable to guess type from obj.data_type={obj.data_type} "
+            f"and obj={obj0!r} - {TensorProto.__dict__}."
+        )
     if hasattr(obj, "type"):
         obj = obj.type
     if hasattr(obj, "tensor_type"):
diff --git a/onnx_array_api/plotting/dot_plot.py b/onnx_array_api/plotting/dot_plot.py
index 2bb69d1..fd23f79 100644
--- a/onnx_array_api/plotting/dot_plot.py
+++ b/onnx_array_api/plotting/dot_plot.py
@@ -3,8 +3,8 @@
 
 from onnx import GraphProto, ModelProto
 from onnx.helper import tensor_dtype_to_string
-from onnx.numpy_helper import to_array
 
+from ..reference import to_array_extended as to_array
 from ._helper import Graph, _get_shape, attributes_as_dict
 
 
diff --git a/onnx_array_api/plotting/text_plot.py b/onnx_array_api/plotting/text_plot.py
index dfb9be0..a570175 100644
--- a/onnx_array_api/plotting/text_plot.py
+++ b/onnx_array_api/plotting/text_plot.py
@@ -1,10 +1,8 @@
 import pprint
 from collections import OrderedDict
-
 import numpy
 from onnx import AttributeProto
-from onnx.numpy_helper import to_array
-
+from ..reference import to_array_extended as to_array
 from ._helper import _get_shape, _get_type, attributes_as_dict
 
 
diff --git a/onnx_array_api/reference/__init__.py b/onnx_array_api/reference/__init__.py
index e4db27c..d8c5aa5 100644
--- a/onnx_array_api/reference/__init__.py
+++ b/onnx_array_api/reference/__init__.py
@@ -1 +1,45 @@
+from typing import Optional
+import numpy as np
+from onnx import TensorProto
+from onnx.numpy_helper import from_array as onnx_from_array
+from onnx.reference.ops.op_cast import (
+    bfloat16,
+    float8e4m3fn,
+    float8e4m3fnuz,
+    float8e5m2,
+    float8e5m2fnuz,
+)
+from onnx.reference.op_run import to_array_extended
 from .evaluator import ExtendedReferenceEvaluator
+
+
+def from_array_extended(tensor: np.ndarray, name: Optional[str] = None) -> TensorProto:
+    """
+    Converts an array into a TensorProto, including bfloat16 and float 8 arrays.
+
+    :param tensor: numpy array, possibly using one of onnx's custom float dtypes
+    :param name: name given to the tensor, optional
+    :return: TensorProto, its data_type is overridden for the custom float dtypes
+    """
+    dt = tensor.dtype
+    if dt == float8e4m3fn and dt.descr[0][0] == "e4m3fn":
+        to = TensorProto.FLOAT8E4M3FN
+        dt_to = np.uint8
+    elif dt == float8e4m3fnuz and dt.descr[0][0] == "e4m3fnuz":
+        to = TensorProto.FLOAT8E4M3FNUZ
+        dt_to = np.uint8
+    elif dt == float8e5m2 and dt.descr[0][0] == "e5m2":
+        to = TensorProto.FLOAT8E5M2
+        dt_to = np.uint8
+    elif dt == float8e5m2fnuz and dt.descr[0][0] == "e5m2fnuz":
+        to = TensorProto.FLOAT8E5M2FNUZ
+        dt_to = np.uint8
+    elif dt == bfloat16 and dt.descr[0][0] == "bfloat16":
+        to = TensorProto.BFLOAT16
+        dt_to = np.uint16
+    else:
+        return onnx_from_array(tensor, name)
+
+    t = onnx_from_array(tensor.astype(dt_to), name)
+    t.data_type = to
+    return t
diff --git a/onnx_array_api/validation/tools.py b/onnx_array_api/validation/tools.py
index 9bedef2..f4628db 100644
--- a/onnx_array_api/validation/tools.py
+++ b/onnx_array_api/validation/tools.py
@@ -16,7 +16,7 @@
     make_node,
     set_model_props,
 )
-from onnx.numpy_helper import from_array, to_array
+from ..reference import from_array_extended as from_array, to_array_extended as to_array
 
 
 def randomize_proto(