Skip to content

Commit d00d823

Browse files
committed
better constant
1 parent 7330b58 commit d00d823

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

onnx_array_api/graph_api/graph_builder.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,9 @@ def _get_tensor_shape(
173173
return tuple(att.floats)
174174
if att.name == "value_ints":
175175
return (len(att.ints),)
176+
if att.name == "value":
177+
t = onh.to_array(att.t)
178+
return t.shape
176179
raise TypeError(
177180
f"Unexpected or unsupported scenario type {type(proto)}: {proto}."
178181
)
@@ -190,6 +193,9 @@ def _get_tensor_type(self, proto: Union[NodeProto, TensorProto]) -> int:
190193
return TensorProto.FLOAT
191194
if att.name == "value_ints":
192195
return TensorProto.INT64
196+
if att.name == "value":
197+
t = onh.to_array(att.t)
198+
return oh.np_dtype_to_tensor_dtype(t.dtype)
193199
raise ValueError(f"Unexpected type or value {type(proto)}: {proto}.")
194200

195201
def is_constant(self, name: str) -> bool:

0 commit comments

Comments
 (0)