
Commit 7330b58

remove some torch issues
1 parent 395e281 commit 7330b58

File tree: 1 file changed (+8, -23 lines changed)

onnx_array_api/graph_api/graph_builder.py

Lines changed: 8 additions & 23 deletions
@@ -6,6 +6,8 @@
 from onnx import AttributeProto, FunctionProto, ModelProto, NodeProto, TensorProto
 from onnx.reference import ReferenceEvaluator
 
+T = "TENSOR"
+
 
 class Opset:
     # defined for opset >= 18
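For context, the new module-level `T = "TENSOR"` is used later in this diff as a string placeholder in annotations (hence the `# noqa: F821` markers), so that signatures such as `from_array` and `_apply_transpose` no longer spell out `torch.Tensor`. A minimal sketch of the pattern with a hypothetical function, not code from the library:

T = "TENSOR"  # annotation placeholder; the real tensor type is not imported here

def scale(x: T, factor: float) -> T:  # noqa: F821
    # Hypothetical function, shown only to illustrate the annotation style above.
    return x * factor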
@@ -78,8 +80,8 @@ def make_node(
 class OptimizationOptions:
     def __init__(
         self,
-        remove_unused: bool = False,
-        constant_folding: bool = True,
+        remove_unused: bool = True,
+        constant_folding: bool = False,
         constant_size: int = 1024,
     ):
         self.remove_unused = remove_unused
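This hunk flips the defaults: `remove_unused` now defaults to `True` and `constant_folding` to `False`. A caller who relied on the previous behaviour can pass the values explicitly; a minimal sketch, assuming the class is imported from the module path matching the file shown above:

from onnx_array_api.graph_api.graph_builder import OptimizationOptions

# Defaults after this commit: unused results are pruned, constants are not folded.
opts = OptimizationOptions()

# Restoring the pre-commit behaviour explicitly:
legacy_opts = OptimizationOptions(remove_unused=False, constant_folding=True)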
@@ -205,10 +207,6 @@ def get_constant(self, name: str) -> np.ndarray:
         if isinstance(value, np.ndarray):
             return value
 
-        import torch
-
-        if isinstance(value, torch.Tensor):
-            return value.detach().numpy()
         raise TypeError(f"Unable to convert type {type(value)} into numpy array.")
 
     def set_shape(self, name: str, shape: Tuple[int, ...]):
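With this hunk, `get_constant` only handles `numpy` arrays and no longer converts `torch.Tensor` values itself, so any conversion has to happen before a tensor reaches the builder. A rough caller-side sketch with a hypothetical helper, not part of the library:

import numpy as np

def as_numpy(value):
    # Hypothetical pre-conversion step: keeps the numpy-only path of
    # get_constant working without importing torch inside graph_builder.
    if isinstance(value, np.ndarray):
        return value
    if hasattr(value, "detach"):  # duck-typed torch.Tensor
        return value.detach().cpu().numpy()
    raise TypeError(f"Unable to convert type {type(value)} into numpy array.")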
@@ -513,9 +511,7 @@ def make_nodes(
             return output_names[0]
         return output_names
 
-    def from_array(
-        self, arr: "torch.Tensor", name: str = None  # noqa: F821
-    ) -> TensorProto:
+    def from_array(self, arr: T, name: str = None) -> TensorProto:  # noqa: F821
         import sys
         import torch
 
@@ -552,15 +548,8 @@ def from_array(
         return tensor
 
     def _build_initializers(self) -> List[TensorProto]:
-        import torch
-
         res = []
         for k, v in sorted(self.initializers_dict.items()):
-            if isinstance(v, torch.Tensor):
-                # no string tensor
-                t = self.from_array(v, name=k)
-                res.append(t)
-                continue
             if isinstance(v, np.ndarray):
                 if self.verbose and np.prod(v.shape) > 100:
                     print(f"[GraphBuilder] onh.from_array:{k}:{v.dtype}[{v.shape}]")
@@ -575,7 +564,7 @@ def _build_initializers(self) -> List[TensorProto]:
 
     def process(
         self,
-        graph_module: "torch.f.GraphModule",  # noqa: F821
+        graph_module: Any,
         interpreter: "Interpreter",  # noqa: F821
     ):
         for node in graph_module.graph.nodes:
@@ -656,19 +645,15 @@ def remove_unused(self):
         self.constants_ = {k: v for k, v in self.constants_.items() if k in marked}
         self.nodes = [node for i, node in enumerate(self.nodes) if i not in removed]
 
-    def _apply_transpose(
-        self, node: NodeProto, feeds: Dict[str, "torch.Tensor"]  # noqa: F821
-    ) -> "torch.Tensor":  # noqa: F821
-        import torch
-
+    def _apply_transpose(self, node: NodeProto, feeds: Dict[str, T]) -> T:  # noqa: F821
         perm = None
         for att in node.attribute:
             if att.name == "perm":
                 perm = tuple(att.ints)
                 break
         assert perm, f"perm not here in node {node}"
         assert len(perm) == 2, f"perm={perm} is not supported with torch"
-        return [torch.transpose(feeds[node.input[0]], *perm)]
+        return [np.transpose(feeds[node.input[0]], *perm)]
 
     def constant_folding(self):
         """
