
Commit c30eff6 (1 parent: 17a9081)

Use @Narsil's suggestion to forward the model's configuration to the ONNXConfig to avoid interpolation.

25 files changed: +769 / −673 lines
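Only six of the 25 changed files are shown below. The `OnnxConfig` base class that makes the pattern work is among the files not shown; here is a minimal sketch of the idea, assuming the constructor does nothing more than store the forwarded `PretrainedConfig` on `self._config` for subclass properties to read:

    # Minimal sketch only -- the real base class lives in src/transformers/onnx
    # and is not part of this excerpt. Assumption: the constructor stores the
    # forwarded model config on `self._config`.
    from abc import ABC, abstractmethod
    from typing import Any, Mapping, Optional


    class OnnxConfig(ABC):
        def __init__(self, config):
            # Keep a handle on the model's PretrainedConfig so properties can
            # read concrete values (e.g. num_attention_heads) instead of
            # returning "$config.xxx" placeholder strings to interpolate later.
            self._config = config

        @property
        @abstractmethod
        def inputs(self) -> Mapping[str, Mapping[int, str]]:
            raise NotImplementedError()

        @property
        @abstractmethod
        def outputs(self) -> Mapping[str, Mapping[int, str]]:
            raise NotImplementedError()

        @property
        def optimizer_additional_args(self) -> Optional[Mapping[str, Any]]:
            return None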

src/transformers/models/albert/__init__.py (2 additions & 2 deletions)

@@ -28,7 +28,7 @@
 
 
 _import_structure = {
-    "configuration_albert": ["ALBERT_ONNX_CONFIG", "ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig"],
+    "configuration_albert": ["ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "AlbertConfig", "AlbertOnnxConfig"],
 }
 
 if is_sentencepiece_available():
@@ -67,7 +67,7 @@
 
 
 if TYPE_CHECKING:
-    from .configuration_albert import ALBERT_ONNX_CONFIG, ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig
+    from .configuration_albert import ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP, AlbertConfig, AlbertOnnxConfig
 
     if is_sentencepiece_available():
         from .tokenization_albert import AlbertTokenizer
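Each `__init__.py` in this commit changes in two places because the exported names are declared twice: `_import_structure` drives lazy loading at runtime, while the `TYPE_CHECKING` branch gives static type checkers real imports. A rough sketch of how such a lazy module resolves an attribute (illustrative only; the actual helper in transformers has more machinery):

    import importlib
    from types import ModuleType


    # Illustrative sketch of the lazy-import pattern used by these __init__
    # files: attribute access triggers the submodule import, so renaming an
    # export (ALBERT_ONNX_CONFIG -> AlbertOnnxConfig) must be mirrored in
    # _import_structure.
    class LazyModule(ModuleType):
        def __init__(self, name: str, import_structure: dict):
            super().__init__(name)
            # Reverse map: exported name -> submodule that defines it.
            self._name_to_module = {
                attr: mod
                for mod, attrs in import_structure.items()
                for attr in attrs
            }

        def __getattr__(self, attr: str):
            if attr not in self._name_to_module:
                raise AttributeError(f"module {self.__name__} has no attribute {attr}")
            submodule = importlib.import_module(
                "." + self._name_to_module[attr], self.__name__
            )
            value = getattr(submodule, attr)
            setattr(self, attr, value)  # cache for subsequent lookups
            return value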

src/transformers/models/albert/configuration_albert.py (35 additions & 28 deletions)

@@ -14,10 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ ALBERT model configuration """
+from typing import Mapping, Optional, Any
 
 from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfig, OnnxVariable
-
+from ...onnx import OnnxConfig, DEFAULT_BERT_OPTIMIZER_FEATURES
 
 ALBERT_PRETRAINED_CONFIG_ARCHIVE_MAP = {
     "albert-base-v1": "https://huggingface.co/albert-base-v1/resolve/main/config.json",
@@ -154,29 +154,36 @@ def __init__(
         self.position_embedding_type = position_embedding_type
 
 
-ALBERT_ONNX_CONFIG = OnnxConfig(
-    inputs=[
-        OnnxVariable("input_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("attention_mask", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("token_type_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-    ],
-    outputs=[
-        OnnxVariable("last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("pooler_output", {0: "batch"}, repeated=1, value=None),
-    ],
-    runtime_config_overrides=None,
-    use_external_data_format=False,
-    minimum_required_onnx_opset=12,
-    optimizer="bert",
-    optimizer_features={
-        "enable_gelu": True,
-        "enable_layer_norm": True,
-        "enable_attention": True,
-        "enable_skip_layer_norm": True,
-        "enable_embed_layer_norm": True,
-        "enable_bias_skip_layer_norm": True,
-        "enable_bias_gelu": True,
-        "enable_gelu_approximation": False,
-    },
-    optimizer_additional_args={"num_heads": "$config.num_attention_heads", "hidden_size": "$config.hidden_size"},
-)
+# Copied from transformers.models.bert.configuration_bert.BertOnnxConfig with Bert->Albert
+class AlbertOnnxConfig(OnnxConfig):
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "input_ids": {0: "batch", 1: "sequence"},
+            "attention_mask": {0: "batch", 1: "sequence"},
+            "token_type_ids": {0: "batch", 1: "sequence"},
+        }
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "last_hidden_state": {0: "batch", 1: "sequence"},
+            "pooler_output": {0: "batch"}
+        }
+
+    @property
+    def optimizer(self) -> Optional[str]:
+        return "bert"
+
+    @property
+    def optimizer_features(self) -> Optional[Mapping[str, bool]]:
+        return DEFAULT_BERT_OPTIMIZER_FEATURES
+
+    @property
+    def optimizer_additional_args(self) -> Optional[Mapping[str, Any]]:
+        return {
+            "num_heads": self._config.num_attention_heads,
+            "hidden_size": self._config.hidden_size
+        }
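The payoff shows in `optimizer_additional_args`: the deleted constant carried placeholder strings such as "$config.num_attention_heads" that had to be interpolated later, while the class reads concrete values from the forwarded config. A hypothetical usage sketch (assuming the constructor takes the model config, as in the base-class sketch above):

    from transformers import AlbertConfig
    from transformers.models.albert.configuration_albert import AlbertOnnxConfig

    config = AlbertConfig()
    onnx_config = AlbertOnnxConfig(config)  # assumed constructor signature

    # Real integers come back -- no "$config.*" strings left to interpolate.
    assert onnx_config.optimizer_additional_args == {
        "num_heads": config.num_attention_heads,
        "hidden_size": config.hidden_size,
    }
    assert set(onnx_config.inputs) == {"input_ids", "attention_mask", "token_type_ids"}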

src/transformers/models/bart/__init__.py (2 additions & 4 deletions)

@@ -28,10 +28,9 @@
 
 _import_structure = {
     "configuration_bart": [
-        "BART_ONNX_CONFIG",
-        "BART_ONNX_CONFIG_WITH_PAST",
         "BART_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "BartConfig",
+        "BartOnnxConfig"
     ],
     "tokenization_bart": ["BartTokenizer"],
 }
@@ -65,10 +64,9 @@
 
 if TYPE_CHECKING:
     from .configuration_bart import (
-        BART_ONNX_CONFIG,
-        BART_ONNX_CONFIG_WITH_PAST,
         BART_PRETRAINED_CONFIG_ARCHIVE_MAP,
         BartConfig,
+        BartOnnxConfig
     )
     from .tokenization_bart import BartTokenizer
 

src/transformers/models/bart/configuration_bart.py (39 additions & 35 deletions)

@@ -14,9 +14,10 @@
 # limitations under the License.
 """ BART model configuration """
 import warnings
+from typing import Mapping, Optional, Any
 
 from ...configuration_utils import PretrainedConfig
-from ...onnx import OnnxConfig, OnnxVariable
+from ...onnx import OnnxConfigWithPast, DEFAULT_BERT_OPTIMIZER_FEATURES
 from ...utils import logging
 
 
@@ -189,37 +190,40 @@ def hidden_size(self) -> int:
         return self.d_model
 
 
-BART_ONNX_CONFIG = OnnxConfig(
-    inputs=[
-        OnnxVariable("input_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("attention_mask", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-    ],
-    outputs=[
-        OnnxVariable("last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("encoder_last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-    ],
-    runtime_config_overrides={"use_cache": False},
-    use_external_data_format=False,
-    minimum_required_onnx_opset=11,
-    optimizer="bert",
-    optimizer_features=None,
-    optimizer_additional_args={"num_heads": "$config.decoder_attention_heads", "hidden_size": "$config.d_model"},
-)
-
-BART_ONNX_CONFIG_WITH_PAST = OnnxConfig(
-    inputs=[
-        OnnxVariable("input_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("attention_mask", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-    ],
-    outputs=[
-        OnnxVariable("last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("past_keys", {0: "batch", 2: "sequence"}, repeated="$config.decoder_layers * 4", value=None),
-        OnnxVariable("encoder_last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-    ],
-    runtime_config_overrides={"use_cache": True},
-    use_external_data_format=False,
-    minimum_required_onnx_opset=11,
-    optimizer="bert",
-    optimizer_features=None,
-    optimizer_additional_args={"num_heads": "$config.decoder_attention_heads", "hidden_size": "$config.d_model"},
-)
+class BartOnnxConfig(OnnxConfigWithPast):
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "input_ids": {0: "batch", 1: "sequence"},
+            "attention_mask": {0: "batch", 1: "sequence"},
+        }
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        if self.use_past:
+            return {
+                "last_hidden_state": {0: "batch", 1: "sequence"},
+                "past_keys": {0: "batch", 2: "sequence"},
+                "encoder_last_hidden_state": {0: "batch", 1: "sequence"},
+            }
+        else:
+            return {
+                "last_hidden_state": {0: "batch", 1: "sequence"},
+                "encoder_last_hidden_state": {0: "batch", 1: "sequence"},
+            }
+
+    @property
+    def optimizer(self) -> Optional[str]:
+        return "bert"
+
+    @property
+    def optimizer_features(self) -> Optional[Mapping[str, bool]]:
+        return DEFAULT_BERT_OPTIMIZER_FEATURES
+
+    @property
+    def optimizer_additional_args(self) -> Optional[Mapping[str, Any]]:
+        return {
+            "num_heads": self._config.decoder_attention_heads,
+            "hidden_size": self._config.d_model
+        }
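`BartOnnxConfig` extends `OnnxConfigWithPast`, collapsing the old `BART_ONNX_CONFIG` / `BART_ONNX_CONFIG_WITH_PAST` pair into a single class whose `outputs` branch on a `use_past` flag. A runnable toy sketch of that mechanism (how the flag is actually plumbed through the real base class is an assumption):

    from typing import Mapping


    # Toy stand-in for OnnxConfigWithPast; assumption: `use_past` is set at
    # construction time and stored on the instance for `outputs` to consult.
    class OnnxConfigWithPast:
        def __init__(self, config, use_past: bool = False):
            self._config = config
            self.use_past = use_past


    class BartLikeOnnxConfig(OnnxConfigWithPast):
        @property
        def outputs(self) -> Mapping[str, Mapping[int, str]]:
            outputs = {
                "last_hidden_state": {0: "batch", 1: "sequence"},
                "encoder_last_hidden_state": {0: "batch", 1: "sequence"},
            }
            if self.use_past:
                # The dynamic sequence axis of past_keys sits at position 2.
                outputs["past_keys"] = {0: "batch", 2: "sequence"}
            return outputs


    plain = BartLikeOnnxConfig(config=None)
    with_past = BartLikeOnnxConfig(config=None, use_past=True)
    assert "past_keys" not in plain.outputs
    assert "past_keys" in with_past.outputs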

src/transformers/models/bert/__init__.py (2 additions & 2 deletions)

@@ -28,7 +28,7 @@
 
 
 _import_structure = {
-    "configuration_bert": ["BERT_ONNX_CONFIG", "BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig"],
+    "configuration_bert": ["BERT_PRETRAINED_CONFIG_ARCHIVE_MAP", "BertConfig", "BertOnnxConfig"],
     "tokenization_bert": ["BasicTokenizer", "BertTokenizer", "WordpieceTokenizer"],
 }
 
@@ -83,7 +83,7 @@
 ]
 
 if TYPE_CHECKING:
-    from .configuration_bert import BERT_ONNX_CONFIG, BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig
+    from .configuration_bert import BERT_PRETRAINED_CONFIG_ARCHIVE_MAP, BertConfig, BertOnnxConfig
     from .tokenization_bert import BasicTokenizer, BertTokenizer, WordpieceTokenizer
 
 if is_tokenizers_available():

src/transformers/models/bert/configuration_bert.py (32 additions & 27 deletions)

@@ -14,9 +14,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ BERT model configuration """
+from typing import Mapping, Optional, Any
 
 from ...configuration_utils import PretrainedConfig
-from ...onnx.config import OnnxConfig, OnnxVariable
+from ...onnx import OnnxConfig, DEFAULT_BERT_OPTIMIZER_FEATURES
 from ...utils import logging
 
 
@@ -157,29 +158,33 @@ def __init__(
         self.use_cache = use_cache
 
 
-BERT_ONNX_CONFIG = OnnxConfig(
-    inputs=[
-        OnnxVariable("input_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("attention_mask", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("token_type_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-    ],
-    outputs=[
-        OnnxVariable("last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-        OnnxVariable("pooler_output", {0: "batch"}, repeated=1, value=None),
-    ],
-    runtime_config_overrides=None,
-    use_external_data_format=False,
-    minimum_required_onnx_opset=12,
-    optimizer="bert",
-    optimizer_features={
-        "enable_gelu": True,
-        "enable_layer_norm": True,
-        "enable_attention": True,
-        "enable_skip_layer_norm": True,
-        "enable_embed_layer_norm": True,
-        "enable_bias_skip_layer_norm": True,
-        "enable_bias_gelu": True,
-        "enable_gelu_approximation": False,
-    },
-    optimizer_additional_args={"num_heads": "$config.num_attention_heads", "hidden_size": "$config.hidden_size"},
-)
+class BertOnnxConfig(OnnxConfig):
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "input_ids": {0: "batch", 1: "sequence"},
+            "attention_mask": {0: "batch", 1: "sequence"},
+            "token_type_ids": {0: "batch", 1: "sequence"},
+        }
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "last_hidden_state": {0: "batch", 1: "sequence"},
+            "pooler_output": {0: "batch"}
+        }
+
+    @property
+    def optimizer(self) -> Optional[str]:
+        return "bert"
+
+    @property
+    def optimizer_features(self) -> Optional[Mapping[str, bool]]:
+        return DEFAULT_BERT_OPTIMIZER_FEATURES
+
+    @property
+    def optimizer_additional_args(self) -> Optional[Mapping[str, Any]]:
+        return {
+            "num_heads": self._config.num_attention_heads,
+            "hidden_size": self._config.hidden_size
+        }
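The eight `enable_*` flags inlined in the old `BERT_ONNX_CONFIG` (and duplicated in the ALBERT one above) are now shared through `DEFAULT_BERT_OPTIMIZER_FEATURES`. The file defining that constant is not in this excerpt; presumably it captures the same dict, reconstructed here from the deleted lines:

    # Presumed contents of DEFAULT_BERT_OPTIMIZER_FEATURES, taken verbatim from
    # the optimizer_features dict deleted in this diff.
    DEFAULT_BERT_OPTIMIZER_FEATURES = {
        "enable_gelu": True,
        "enable_layer_norm": True,
        "enable_attention": True,
        "enable_skip_layer_norm": True,
        "enable_embed_layer_norm": True,
        "enable_bias_skip_layer_norm": True,
        "enable_bias_gelu": True,
        "enable_gelu_approximation": False,
    }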

src/transformers/models/distilbert/__init__.py (2 additions & 4 deletions)

@@ -23,10 +23,9 @@
 
 _import_structure = {
     "configuration_distilbert": [
-        "DISTILBERT_ONNX_CONFIG",
         "DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP",
-        "DISTILBERT_TOKEN_CLASSIFICATION_ONNX_CONFIG",
         "DistilBertConfig",
+        "DistilBertOnnxConfig"
     ],
     "tokenization_distilbert": ["DistilBertTokenizer"],
 }
@@ -62,10 +61,9 @@
 
 if TYPE_CHECKING:
     from .configuration_distilbert import (
-        DISTILBERT_ONNX_CONFIG,
         DISTILBERT_PRETRAINED_CONFIG_ARCHIVE_MAP,
-        DISTILBERT_TOKEN_CLASSIFICATION_ONNX_CONFIG,
        DistilBertConfig,
+        DistilBertOnnxConfig
     )
     from .tokenization_distilbert import DistilBertTokenizer
 
