6
6
from onnx import AttributeProto , FunctionProto , ModelProto , NodeProto , TensorProto
7
7
from onnx .reference import ReferenceEvaluator
8
8
9
+ T = "TENSOR"
10
+
9
11
10
12
class Opset :
11
13
# defined for opset >= 18
@@ -78,8 +80,8 @@ def make_node(
78
80
class OptimizationOptions :
79
81
def __init__ (
80
82
self ,
81
- remove_unused : bool = False ,
82
- constant_folding : bool = True ,
83
+ remove_unused : bool = True ,
84
+ constant_folding : bool = False ,
83
85
constant_size : int = 1024 ,
84
86
):
85
87
self .remove_unused = remove_unused
@@ -205,10 +207,6 @@ def get_constant(self, name: str) -> np.ndarray:
205
207
if isinstance (value , np .ndarray ):
206
208
return value
207
209
208
- import torch
209
-
210
- if isinstance (value , torch .Tensor ):
211
- return value .detach ().numpy ()
212
210
raise TypeError (f"Unable to convert type { type (value )} into numpy array." )
213
211
214
212
def set_shape (self , name : str , shape : Tuple [int , ...]):
@@ -513,9 +511,7 @@ def make_nodes(
513
511
return output_names [0 ]
514
512
return output_names
515
513
516
- def from_array (
517
- self , arr : "torch.Tensor" , name : str = None # noqa: F821
518
- ) -> TensorProto :
514
+ def from_array (self , arr : T , name : str = None ) -> TensorProto : # noqa: F821
519
515
import sys
520
516
import torch
521
517
@@ -552,15 +548,8 @@ def from_array(
552
548
return tensor
553
549
554
550
def _build_initializers (self ) -> List [TensorProto ]:
555
- import torch
556
-
557
551
res = []
558
552
for k , v in sorted (self .initializers_dict .items ()):
559
- if isinstance (v , torch .Tensor ):
560
- # no string tensor
561
- t = self .from_array (v , name = k )
562
- res .append (t )
563
- continue
564
553
if isinstance (v , np .ndarray ):
565
554
if self .verbose and np .prod (v .shape ) > 100 :
566
555
print (f"[GraphBuilder] onh.from_array:{ k } :{ v .dtype } [{ v .shape } ]" )
@@ -575,7 +564,7 @@ def _build_initializers(self) -> List[TensorProto]:
575
564
576
565
def process (
577
566
self ,
578
- graph_module : "torch.f.GraphModule" , # noqa: F821
567
+ graph_module : Any ,
579
568
interpreter : "Interpreter" , # noqa: F821
580
569
):
581
570
for node in graph_module .graph .nodes :
@@ -656,19 +645,15 @@ def remove_unused(self):
656
645
self .constants_ = {k : v for k , v in self .constants_ .items () if k in marked }
657
646
self .nodes = [node for i , node in enumerate (self .nodes ) if i not in removed ]
658
647
659
- def _apply_transpose (
660
- self , node : NodeProto , feeds : Dict [str , "torch.Tensor" ] # noqa: F821
661
- ) -> "torch.Tensor" : # noqa: F821
662
- import torch
663
-
648
+ def _apply_transpose (self , node : NodeProto , feeds : Dict [str , T ]) -> T : # noqa: F821
664
649
perm = None
665
650
for att in node .attribute :
666
651
if att .name == "perm" :
667
652
perm = tuple (att .ints )
668
653
break
669
654
assert perm , f"perm not here in node { node } "
670
655
assert len (perm ) == 2 , f"perm={ perm } is not supported with torch"
671
- return [torch .transpose (feeds [node .input [0 ]], * perm )]
656
+ return [np .transpose (feeds [node .input [0 ]], * perm )]
672
657
673
658
def constant_folding (self ):
674
659
"""
0 commit comments