27
27
from transformers .models .t5 import T5OnnxConfig
28
28
from transformers .models .xlm_roberta import XLMRobertaOnnxConfig
29
29
30
- from .. import is_tf_available , is_torch_available
30
+ from .. import is_torch_available
31
31
from ..utils import logging
32
32
from .convert import convert_pytorch , validate_model_outputs
33
33
34
34
35
- # Set of frameworks we can export from
36
- FRAMEWORK_NAME_PT = "pytorch"
37
- FRAMEWORK_NAME_TF = "tensorflow"
38
- FRAMEWORK_CHOICES = {FRAMEWORK_NAME_PT , FRAMEWORK_NAME_PT }
39
-
40
35
if is_torch_available ():
41
36
from transformers import AutoModel , PreTrainedModel
42
37
43
38
FEATURES_TO_AUTOMODELS = {
44
39
"default" : AutoModel ,
45
40
}
46
41
47
- if is_tf_available ():
48
- from transformers import TFAutoModel , TFPreTrainedModel
49
-
50
- FEATURES_TO_TF_AUTOMODELS = {
51
- "default" : TFAutoModel ,
52
- }
53
42
54
43
# Set of model topologies we support associated to the features supported by each topology and the factory
55
44
SUPPORTED_MODEL_KIND = {
65
54
}
66
55
67
56
68
- def get_model_from_framework_and_features ( framework : str , features : str , model : str ):
57
+ def get_model_from_features ( features : str , model : str ):
69
58
"""
70
59
Attempt to retrieve a model from a model's name and the features to be enabled.
71
60
72
61
Args:
73
- framework: The framework we are targeting
74
62
features: The features required
75
63
model: The name of the model to export
76
64
77
65
Returns:
78
66
79
67
"""
80
- if framework == FRAMEWORK_NAME_PT :
81
- if features not in FEATURES_TO_AUTOMODELS :
82
- raise KeyError (
83
- f"Unknown feature: { features } ." f"Possible values are { list (FEATURES_TO_AUTOMODELS .values ())} "
68
+ if features not in FEATURES_TO_AUTOMODELS :
69
+ raise KeyError (
70
+ f"Unknown feature: { features } ." f"Possible values are { list (FEATURES_TO_AUTOMODELS .values ())} "
84
71
)
85
72
86
- return FEATURES_TO_AUTOMODELS [features ].from_pretrained (model )
87
- elif framework == FRAMEWORK_NAME_TF :
88
- if features not in FEATURES_TO_TF_AUTOMODELS :
89
- raise KeyError (
90
- f"Unknown feature: { features } ." f"Possible values are { list (FEATURES_TO_AUTOMODELS .values ())} "
91
- )
92
- return FEATURES_TO_TF_AUTOMODELS [features ].from_pretrained (model )
93
- else :
94
- raise ValueError (f"Unknown framework: { framework } " )
73
+ return FEATURES_TO_AUTOMODELS [features ].from_pretrained (model )
95
74
96
75
97
76
def check_supported_model_or_raise (
@@ -129,13 +108,6 @@ def check_supported_model_or_raise(
129
108
def main ():
130
109
parser = ArgumentParser ("Hugging Face ONNX Exporter tool" )
131
110
parser .add_argument ("-m" , "--model" , type = str , required = True , help = "Model's name of path on disk to load." )
132
- parser .add_argument (
133
- "-f" ,
134
- "--framework" ,
135
- choices = FRAMEWORK_CHOICES ,
136
- required = True ,
137
- help = f"Framework to use when exporting. Possible values are: { FRAMEWORK_CHOICES } " ,
138
- )
139
111
parser .add_argument (
140
112
"--features" ,
141
113
choices = ["default" ],
@@ -157,11 +129,9 @@ def main():
157
129
if not args .output .parent .exists ():
158
130
args .output .parent .mkdir (parents = True )
159
131
160
- logger .info (f"About to export model: { args .model } using framework: { args .framework } " )
161
-
162
132
# Allocate the model
163
133
tokenizer = AutoTokenizer .from_pretrained (args .model )
164
- model = get_model_from_framework_and_features ( args . framework , args .features , args .model )
134
+ model = get_model_from_features ( args .features , args .model )
165
135
model_kind , model_onnx_config = check_supported_model_or_raise (model , features = args .features )
166
136
onnx_config = model_onnx_config (model .config )
167
137
@@ -172,10 +142,7 @@ def main():
172
142
f"At least { onnx_config .default_onnx_opset } is required."
173
143
)
174
144
175
- if args .framework == FRAMEWORK_NAME_PT :
176
- onnx_inputs , onnx_outputs = convert_pytorch (tokenizer , model , onnx_config , args .opset , args .output )
177
- else :
178
- raise NotImplementedError ()
145
+ onnx_inputs , onnx_outputs = convert_pytorch (tokenizer , model , onnx_config , args .opset , args .output )
179
146
180
147
validate_model_outputs (onnx_config , tokenizer , model , args .output , onnx_outputs , args .atol )
181
148
logger .info (f"All good, model saved at: { args .output .as_posix ()} " )
0 commit comments