Skip to content

Commit e2d4895

Browse files
committed
Refactoring to package transformers.onnx
1 parent 29582cb commit e2d4895

15 files changed

+808
-567
lines changed

src/transformers/convert_graph_to_onnx_v2.py

Lines changed: 0 additions & 557 deletions
This file was deleted.

src/transformers/models/albert/configuration_albert.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# limitations under the License.
1616
""" ALBERT model configuration """
1717

18-
from ...configuration_utils import PretrainedConfig, OnnxConfig, OnnxVariable
18+
from ...configuration_utils import PretrainedConfig
19+
from ...onnx import OnnxConfig, OnnxVariable
1920

2021
ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
2122
"albert-base-v1": "https://huggingface.co/albert-base-v1/resolve/main/config.json",

src/transformers/models/bart/configuration_bart.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
""" BART model configuration """
1616
import warnings
1717

18-
from ...configuration_utils import PretrainedConfig, OnnxConfig, OnnxVariable
18+
from ...configuration_utils import PretrainedConfig
19+
from ...onnx import OnnxConfig, OnnxVariable
1920
from ...utils import logging
2021

2122

src/transformers/models/bert/configuration_bert.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
""" BERT model configuration """
17-
18-
from ...configuration_utils import PretrainedConfig, OnnxConfig, OnnxVariable
17+
from ...onnx import OnnxConfig, OnnxVariable
1918
from ...utils import logging
2019

2120

src/transformers/models/gpt2/configuration_gpt2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,8 @@
1515
# limitations under the License.
1616
""" OpenAI GPT-2 configuration """
1717

18-
from ...configuration_utils import PretrainedConfig, OnnxConfig, OnnxVariable
18+
from ...configuration_utils import PretrainedConfig
19+
from ...onnx import OnnxConfig, OnnxVariable
1920
from ...utils import logging
2021

2122

src/transformers/models/longformer/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
_import_structure = {
25-
"configuration_longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LongformerConfig"],
25+
"configuration_longformer": ["LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP", "LONGFORMER_ONNX_CONFIG", "LongformerConfig"],
2626
"tokenization_longformer": ["LongformerTokenizer"],
2727
}
2828

@@ -57,7 +57,7 @@
5757

5858

5959
if TYPE_CHECKING:
60-
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
60+
from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LONGFORMER_ONNX_CONFIG, LongformerConfig
6161
from .tokenization_longformer import LongformerTokenizer
6262

6363
if is_tokenizers_available():

src/transformers/models/longformer/configuration_longformer.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
from typing import List, Union
1818

19+
from ...onnx import OnnxConfig, OnnxVariable
1920
from ...utils import logging
2021
from ..roberta.configuration_roberta import RobertaConfig
2122

@@ -69,3 +70,33 @@ class LongformerConfig(RobertaConfig):
6970
def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id: int = 2, **kwargs):
7071
super().__init__(sep_token_id=sep_token_id, **kwargs)
7172
self.attention_window = attention_window
73+
74+
75+
# ONNX export description for Longformer: the named inputs/outputs with their
# dynamic axes, plus the settings forwarded to the "bert" graph optimizer.
LONGFORMER_ONNX_CONFIG = OnnxConfig(
    inputs=[
        OnnxVariable(name="input_ids", axes={0: "batch", 1: "sequence"}, repeated=1, value=None),
        OnnxVariable(name="attention_mask", axes={0: "batch", 1: "sequence"}, repeated=1, value=None),
    ],
    outputs=[
        OnnxVariable(name="last_hidden_state", axes={0: "batch", 1: "sequence"}, repeated=1, value=None),
        OnnxVariable(name="pooler_output", axes={0: "batch"}, repeated=1, value=None),
    ],
    runtime_config_overrides=None,
    use_external_data_format=False,
    minimum_required_onnx_opset=12,
    optimizer="bert",
    optimizer_features={
        "enable_gelu": True,
        "enable_layer_norm": True,
        "enable_attention": True,
        "enable_skip_layer_norm": True,
        "enable_embed_layer_norm": True,
        "enable_bias_skip_layer_norm": True,
        "enable_bias_gelu": True,
        "enable_gelu_approximation": False,
    },
    # NOTE(review): the "$config.<attr>" strings look like placeholders that are
    # resolved against the model config at export time -- confirm with the exporter.
    optimizer_additional_args={
        "num_heads": "$config.num_attention_heads",
        "hidden_size": "$config.hidden_size",
    },
)

src/transformers/models/roberta/configuration_roberta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
""" RoBERTa configuration """
17-
from ... import OnnxConfig, OnnxVariable
17+
from ...onnx import OnnxConfig, OnnxVariable
1818
from ...utils import logging
1919
from ..bert.configuration_bert import BertConfig
2020

src/transformers/models/t5/configuration_t5.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
# limitations under the License.
1515
""" T5 model configuration """
1616

17-
from ...configuration_utils import PretrainedConfig, OnnxConfig, OnnxVariable
17+
from ...configuration_utils import PretrainedConfig
18+
from ...onnx import OnnxConfig, OnnxVariable
1819
from ...utils import logging
1920

2021

src/transformers/models/xlm_roberta/configuration_xlm_roberta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616
""" XLM-RoBERTa configuration """
17-
from ... import OnnxConfig, OnnxVariable
17+
from ...onnx import OnnxConfig, OnnxVariable
1818
from ...utils import logging
1919
from ..roberta.configuration_roberta import RobertaConfig
2020

src/transformers/onnx/__init__.py

Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
from argparse import ArgumentParser
2+
from pathlib import Path
3+
from typing import Union, Tuple
4+
5+
from onnxruntime import GraphOptimizationLevel
6+
7+
from transformers import is_torch_available, is_tf_available, AutoTokenizer
8+
from transformers.models.albert import ALBERT_ONNX_CONFIG
9+
from transformers.models.bart import BART_ONNX_CONFIG, BART_ONNX_CONFIG_WITH_PAST
10+
from transformers.models.bert import BERT_ONNX_CONFIG
11+
from transformers.models.distilbert import DISTILBERT_ONNX_CONFIG, DISTILBERT_TOKEN_CLASSIFICATION_ONNX_CONFIG
12+
from transformers.models.gpt2 import GPT2_ONNX_CONFIG, GPT2_ONNX_CONFIG_WITH_PAST
13+
# from transformers.models.longformer import LONGFORMER_ONNX_CONFIG
14+
from transformers.models.roberta import ROBERTA_ONNX_CONFIG
15+
from transformers.models.t5 import T5_ONNX_CONFIG
16+
from transformers.models.xlm_roberta import XLM_ROBERTA_ONNX_CONFIG
17+
18+
# from .config import OnnxConfig, OnnxVariable
19+
# from .convert import convert_pytorch, ensure_model_and_config_inputs_match, optimize, validate_model_outputs
20+
# from .interpolate import evaluate_expr_to_int, expand_repeated_onnx_variables, interpolate_expression, \
21+
# insert_additional_onnx_value_within_inputs
22+
# from .utils import flatten_output_collection_property, generate_identified_filename
23+
24+
25+
# Set of frameworks we can export from
FRAMEWORK_NAME_PT = "pytorch"
FRAMEWORK_NAME_TF = "tensorflow"
# Both names must be present: this set is used as the CLI ``--framework`` choices.
FRAMEWORK_CHOICES = {FRAMEWORK_NAME_PT, FRAMEWORK_NAME_TF}

# Feature name -> auto-model class used to load the checkpoint (PyTorch).
# The keys have to match the values accepted by the ``--features`` CLI argument
# ("default", "with_past", "token_classification").
if is_torch_available():
    from transformers import AutoModel, AutoModelForTokenClassification, PreTrainedModel

    FEATURES_TO_AUTOMODELS = {
        "default": AutoModel,
        "with_past": AutoModel,
        "token_classification": AutoModelForTokenClassification,
    }

# Same mapping for TensorFlow.
if is_tf_available():
    from transformers import TFAutoModel, TFAutoModelForTokenClassification, TFPreTrainedModel

    FEATURES_TO_TF_AUTOMODELS = {
        "default": TFAutoModel,
        "with_past": TFAutoModel,
        "token_classification": TFAutoModelForTokenClassification,
    }
45+
46+
# Registry of exportable architectures: model_type -> {feature name -> ONNX config}
SUPPORTED_MODEL_KIND = {
    "albert": {"default": ALBERT_ONNX_CONFIG},
    "bart": {"default": BART_ONNX_CONFIG, "with_past": BART_ONNX_CONFIG_WITH_PAST},
    "bert": {"default": BERT_ONNX_CONFIG},
    "distilbert": {
        "default": DISTILBERT_ONNX_CONFIG,
        "token_classification": DISTILBERT_TOKEN_CLASSIFICATION_ONNX_CONFIG,
    },
    "gpt2": {"default": GPT2_ONNX_CONFIG, "with_past": GPT2_ONNX_CONFIG_WITH_PAST},
    # Longformer is not registered yet:
    # "longformer": {"default": LONGFORMER_ONNX_CONFIG},
    "roberta": {"default": ROBERTA_ONNX_CONFIG},
    "t5": {"default": T5_ONNX_CONFIG},
    "xlm-roberta": {"default": XLM_ROBERTA_ONNX_CONFIG},
}
79+
80+
# Human-friendly names for the ONNX Runtime graph optimization levels
ONNX_OPTIMIZATION_LEVELS = dict(
    disabled=GraphOptimizationLevel.ORT_DISABLE_ALL,
    default=GraphOptimizationLevel.ORT_ENABLE_BASIC,
    extended=GraphOptimizationLevel.ORT_ENABLE_EXTENDED,
    all=GraphOptimizationLevel.ORT_ENABLE_ALL,
)
87+
88+
89+
def get_model_from_framework_and_features(framework: str, features: str, model: str):
    """
    Attempt to retrieve a model from a model's name and the features to be enabled.

    Args:
        framework: The framework we are targeting
        features: The features required
        model: The name of the model to export

    Returns:
        The object returned by ``from_pretrained(model)`` on the auto-model class
        registered for ``features`` in the selected framework.

    Raises:
        ValueError: If ``framework`` is neither the PyTorch nor the TensorFlow name.
        KeyError: If ``features`` is not registered for the selected framework.
    """
    # Pick the feature -> auto-model mapping of the requested framework
    if framework == FRAMEWORK_NAME_PT:
        features_to_automodels = FEATURES_TO_AUTOMODELS
    elif framework == FRAMEWORK_NAME_TF:
        features_to_automodels = FEATURES_TO_TF_AUTOMODELS
    else:
        raise ValueError(f"Unknown framework: {framework}")

    if features not in features_to_automodels:
        # Report the valid feature *names* (keys) of the mapping that was actually searched
        raise KeyError(
            f"Unknown feature: {features}. "
            f"Possible values are {list(features_to_automodels.keys())}"
        )

    return features_to_automodels[features].from_pretrained(model)
117+
118+
119+
def check_supported_model_or_raise(
    model: Union["PreTrainedModel", "TFPreTrainedModel"],
    features: str = "default"
) -> Tuple[str, "OnnxConfig"]:
    """
    Check whether or not the model has the requested features

    Args:
        model: The model to export
        features: The name of the features to check if they are available

    Returns:
        (str) The type of the model
        (OnnxConfig) The OnnxConfig instance holding the model export properties

    Raises:
        KeyError: If the model type is not registered in ``SUPPORTED_MODEL_KIND``.
        ValueError: If the model type is registered but not for ``features``.
    """
    # The model classes are referenced as strings in the signature because each of
    # them is only imported when its framework is available.
    model_type = model.config.model_type

    if model_type not in SUPPORTED_MODEL_KIND:
        # Fall back to the class name when the model has no ``name`` attribute, so
        # building the message can never raise and hide the KeyError below.
        model_name = getattr(model, "name", type(model).__name__)
        raise KeyError(
            f"{model_type} ({model_name}) is not supported yet. "
            f"Only {list(SUPPORTED_MODEL_KIND.keys())} are supported. "
            f"If you want to support ({model_type}) please propose a PR or open up an issue."
        )

    # Look for the features
    model_features = SUPPORTED_MODEL_KIND[model_type]
    if features not in model_features:
        raise ValueError(
            f"{model_type} doesn't support features {features}. "
            f"Supported values are: {list(model_features.keys())}"
        )

    return model_type, model_features[features]
150+
151+
152+
if __name__ == '__main__':
    # Command line entry point: load a model, export it to ONNX, validate the
    # exported graph and optionally produce an optimized copy next to it.
    #
    # NOTE(review): convert_pytorch, validate_model_outputs, optimize and
    # generate_identified_filename are used below but no active import in this
    # module binds them (the matching imports are commented out) -- confirm
    # where they are supposed to come from before running this script.

    parser = ArgumentParser("Hugging Face ONNX Exporter tool")
    parser.add_argument("-m", "--model", type=str, required=True, help="Model's name or path on disk to load.")
    parser.add_argument("-f", "--framework", choices=FRAMEWORK_CHOICES, required=True, help=f"Framework to use when exporting. Possible values are: {FRAMEWORK_CHOICES}")
    parser.add_argument("--features", choices=["default", "with_past", "token_classification"], default="default", help="Export the model with some additional features.")
    parser.add_argument("--opset", type=int, default=12, help="ONNX opset version to export the model with (default 12).")
    parser.add_argument("--optimize", action="store_true", help="Flag indicating if we should try to optimize the model.")
    parser.add_argument("--use-gpu", action="store_true", help="Flag indicating if we should try to optimize the model for GPU inference.")
    parser.add_argument("--optimization-level", choices=ONNX_OPTIMIZATION_LEVELS.keys(), default="disabled", help="ONNX Runtime graph optimization level to apply when --optimize is set (default: disabled).")
    parser.add_argument("--atol", type=float, default=1e-4, help="Absolute difference tolerance when validating the model.")
    parser.add_argument("output", type=Path, help="Path indicating where to store generated ONNX model.")

    # Retrieve CLI arguments
    args = parser.parse_args()
    # Anything that is not an already existing file is treated as a directory
    # in which "model.onnx" will be written.
    args.output = args.output if args.output.is_file() else args.output.joinpath("model.onnx")

    # Create the destination folder if needed (no error when it already exists)
    args.output.parent.mkdir(parents=True, exist_ok=True)

    print(f"About to export model: {args.model} using framework: {args.framework}")

    # Allocate the model
    tokenizer = AutoTokenizer.from_pretrained(args.model)
    model = get_model_from_framework_and_features(args.framework, args.features, args.model)
    model_kind, onnx_config = check_supported_model_or_raise(model, features=args.features)

    # Override model's config if needed
    if onnx_config.runtime_config_overrides is not None:
        print("Overriding model's config values:")
        for config_key, config_value in onnx_config.runtime_config_overrides.items():
            print(f"\t- {config_key} => {config_value}")
            setattr(model.config, config_key, config_value)

    # Ensure the requested opset is sufficient
    if args.opset < onnx_config.minimum_required_onnx_opset:
        raise ValueError(
            f"Opset {args.opset} is not sufficient to export {model_kind}. "
            f"At least {onnx_config.minimum_required_onnx_opset} is required."
        )

    # Only the PyTorch export path is implemented
    if args.framework == FRAMEWORK_NAME_PT:
        onnx_inputs, onnx_outputs = convert_pytorch(tokenizer, model, onnx_config, args.opset, args.output)
    else:
        raise NotImplementedError()

    validate_model_outputs(tokenizer, model, args.output, onnx_inputs, onnx_outputs, args.atol)
    print(f"All good, model saved at: {args.output.as_posix()}")

    # Optimization only happens when --optimize is given AND a level other than
    # "disabled" (the default) is selected.
    if args.optimize and args.optimization_level != "disabled":
        print(f"About to optimize model with optimization_level: {args.optimization_level}")

        args.opt_model_output = generate_identified_filename(args.output, f"_optimized_{args.optimization_level}")
        # Swap the human-friendly level name for the ONNX Runtime enum value
        args.optimization_level = ONNX_OPTIMIZATION_LEVELS[args.optimization_level]
        optimize(args.output, model, onnx_config, args.optimization_level, args.use_gpu, args.opt_model_output)

        if not args.use_gpu:
            validate_model_outputs(tokenizer, model, args.opt_model_output, onnx_inputs, onnx_outputs, args.atol)
        else:
            print(
                "Validating model targeting GPU is not supported yet. "
                "Please, fill an issue or submit a PR if it's something you need."
            )

        print(f"Optimized model saved at: {args.opt_model_output.as_posix()}")

src/transformers/onnx/config.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
from dataclasses import dataclass
2+
from typing import List, Optional, Dict, Any, Union, NamedTuple
3+
4+
class OnnxVariable(NamedTuple):
    """
    Description of a single named input or output of an ONNX exported model.
    """

    # Name of the variable (e.g. "input_ids")
    name: str
    # Axis index -> axis name (e.g. {0: "batch", 1: "sequence"})
    axes: Dict[int, str]
    # NOTE(review): presumably how many times the variable is repeated, either as a
    # literal count or as an expression evaluated at export time -- confirm.
    repeated: Union[int, str]
    # NOTE(review): optional list of ints; looks like a fixed value attached to the
    # variable when not None -- confirm against the exporter.
    value: Optional[List[int]]
10+
11+
12+
@dataclass
class OnnxConfig:
    """
    Base class for ONNX exportable model describing metadata on how to export the model
    through the ONNX format.
    """

    # Model inputs, one OnnxVariable per input carrying its name and {axis_id: "axis_name"}
    # example: OnnxVariable("input_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None)
    # We use a list because the ordering of the items is VERY important
    inputs: List[OnnxVariable]

    # Model outputs, one OnnxVariable per output carrying its name and {axis_id: "axis_name"}
    # example: OnnxVariable("last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None)
    # We use a list because the ordering of the items is VERY important
    outputs: List[OnnxVariable]

    # Define all the configuration keys we need to override before forwarding through the model
    runtime_config_overrides: Optional[Dict[str, Any]]

    # Does the model require using the external data format (i.e. model size > 2Gb)
    use_external_data_format: bool

    # Minimum required ONNX opset
    minimum_required_onnx_opset: int

    # ONNXRuntime provides model specific optimizers for some topologies
    # This one indicates which provider (if any) to use
    optimizer: Optional[str]

    # If optimizer is present, this mapping indicates which features to enable/disable when optimizing
    optimizer_features: Optional[Dict[str, bool]]

    # Optimizer parameters which can only be known at runtime
    optimizer_additional_args: Optional[Dict[str, Union[int, str]]]

0 commit comments

Comments
 (0)