Skip to content

Commit 01da1b2

Browse files
committed
Remove framework argument
1 parent ea01db4 commit 01da1b2

File tree

1 file changed

+8
-41
lines changed

1 file changed

+8
-41
lines changed

src/transformers/onnx/__main__.py

Lines changed: 8 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,18 @@
2727
from transformers.models.t5 import T5OnnxConfig
2828
from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
2929

30-
from .. import is_tf_available, is_torch_available
30+
from .. import is_torch_available
3131
from ..utils import logging
3232
from .convert import convert_pytorch, validate_model_outputs
3333

3434

35-
# Set of frameworks we can export from
36-
FRAMEWORK_NAME_PT = "pytorch"
37-
FRAMEWORK_NAME_TF = "tensorflow"
38-
FRAMEWORK_CHOICES = {FRAMEWORK_NAME_PT, FRAMEWORK_NAME_PT}
39-
4035
if is_torch_available():
4136
from transformers import AutoModel, PreTrainedModel
4237

4338
FEATURES_TO_AUTOMODELS = {
4439
"default": AutoModel,
4540
}
4641

47-
if is_tf_available():
48-
from transformers import TFAutoModel, TFPreTrainedModel
49-
50-
FEATURES_TO_TF_AUTOMODELS = {
51-
"default": TFAutoModel,
52-
}
5342

5443
# Set of model topologies we support associated to the features supported by each topology and the factory
5544
SUPPORTED_MODEL_KIND = {
@@ -65,33 +54,23 @@
6554
}
6655

6756

68-
def get_model_from_framework_and_features(framework: str, features: str, model: str):
57+
def get_model_from_features(features: str, model: str):
6958
"""
7059
Attempt to retrieve a model from a model's name and the features to be enabled.
7160
7261
Args:
73-
framework: The framework we are targeting
7462
features: The features required
7563
model: The name of the model to export
7664
7765
Returns:
7866
7967
"""
80-
if framework == FRAMEWORK_NAME_PT:
81-
if features not in FEATURES_TO_AUTOMODELS:
82-
raise KeyError(
83-
f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}"
68+
if features not in FEATURES_TO_AUTOMODELS:
69+
raise KeyError(
70+
f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}"
8471
)
8572

86-
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
87-
elif framework == FRAMEWORK_NAME_TF:
88-
if features not in FEATURES_TO_TF_AUTOMODELS:
89-
raise KeyError(
90-
f"Unknown feature: {features}." f"Possible values are {list(FEATURES_TO_AUTOMODELS.values())}"
91-
)
92-
return FEATURES_TO_TF_AUTOMODELS[features].from_pretrained(model)
93-
else:
94-
raise ValueError(f"Unknown framework: {framework}")
73+
return FEATURES_TO_AUTOMODELS[features].from_pretrained(model)
9574

9675

9776
def check_supported_model_or_raise(
@@ -129,13 +108,6 @@ def check_supported_model_or_raise(
129108
def main():
130109
parser = ArgumentParser("Hugging Face ONNX Exporter tool")
131110
parser.add_argument("-m", "--model", type=str, required=True, help="Model's name of path on disk to load.")
132-
parser.add_argument(
133-
"-f",
134-
"--framework",
135-
choices=FRAMEWORK_CHOICES,
136-
required=True,
137-
help=f"Framework to use when exporting. Possible values are: {FRAMEWORK_CHOICES}",
138-
)
139111
parser.add_argument(
140112
"--features",
141113
choices=["default"],
@@ -157,11 +129,9 @@ def main():
157129
if not args.output.parent.exists():
158130
args.output.parent.mkdir(parents=True)
159131

160-
logger.info(f"About to export model: {args.model} using framework: {args.framework}")
161-
162132
# Allocate the model
163133
tokenizer = AutoTokenizer.from_pretrained(args.model)
164-
model = get_model_from_framework_and_features(args.framework, args.features, args.model)
134+
model = get_model_from_features(args.features, args.model)
165135
model_kind, model_onnx_config = check_supported_model_or_raise(model, features=args.features)
166136
onnx_config = model_onnx_config(model.config)
167137

@@ -172,10 +142,7 @@ def main():
172142
f"At least {onnx_config.default_onnx_opset} is required."
173143
)
174144

175-
if args.framework == FRAMEWORK_NAME_PT:
176-
onnx_inputs, onnx_outputs = convert_pytorch(tokenizer, model, onnx_config, args.opset, args.output)
177-
else:
178-
raise NotImplementedError()
145+
onnx_inputs, onnx_outputs = convert_pytorch(tokenizer, model, onnx_config, args.opset, args.output)
179146

180147
validate_model_outputs(onnx_config, tokenizer, model, args.output, onnx_outputs, args.atol)
181148
logger.info(f"All good, model saved at: {args.output.as_posix()}")

0 commit comments

Comments
 (0)