Commit ca0ad12

LongFormer ONNX config.
1 parent 49916c9 commit ca0ad12

4 files changed: +44 -28 lines

src/transformers/models/longformer/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -25,6 +25,7 @@
     "configuration_longformer": [
         "LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP",
         "LongformerConfig",
+        "LongformerOnnxConfig"
     ],
     "tokenization_longformer": ["LongformerTokenizer"],
 }
@@ -61,6 +62,11 @@
 
 if TYPE_CHECKING:
     from .configuration_longformer import LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP, LongformerConfig
+    from .configuration_longformer import (
+        LONGFORMER_PRETRAINED_CONFIG_ARCHIVE_MAP,
+        LongformerConfig,
+        LongformerOnnxConfig
+    )
     from .tokenization_longformer import LongformerTokenizer
 
 if is_tokenizers_available():
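
The entry in `_import_structure` and the mirrored `TYPE_CHECKING` import are what expose the new class from the subpackage, which is how `onnx/__main__.py` imports it below. A quick check of the export, assuming a source checkout that includes this commit:

# Sanity check (assumes a checkout with this commit): the new symbol resolves
# through the lazy _import_structure the same way the eager import does.
from transformers.models.longformer import LongformerOnnxConfig
print(LongformerOnnxConfig)  # <class 'transformers.models.longformer.configuration_longformer.LongformerOnnxConfig'>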

src/transformers/models/longformer/configuration_longformer.py

Lines changed: 34 additions & 26 deletions

@@ -14,8 +14,9 @@
 # limitations under the License.
 """ Longformer configuration """
 
-from typing import List, Union
+from typing import List, Union, Mapping, Optional, Any
 
+from ...onnx import OnnxConfig, DEFAULT_BERT_OPTIMIZER_FEATURES
 from ...utils import logging
 from ..roberta.configuration_roberta import RobertaConfig
 
@@ -71,28 +72,35 @@ def __init__(self, attention_window: Union[List[int], int] = 512, sep_token_id:
         self.attention_window = attention_window
 
 
-# LONGFORMER_ONNX_CONFIG = OnnxConfig(
-#     inputs=[
-#         OnnxVariable("input_ids", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-#         OnnxVariable("attention_mask", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-#     ],
-#     outputs=[
-#         OnnxVariable("last_hidden_state", {0: "batch", 1: "sequence"}, repeated=1, value=None),
-#         OnnxVariable("pooler_output", {0: "batch"}, repeated=1, value=None),
-#     ],
-#     runtime_config_overrides=None,
-#     use_external_data_format=False,
-#     minimum_required_onnx_opset=12,
-#     optimizer="bert",
-#     optimizer_features={
-#         "enable_gelu": True,
-#         "enable_layer_norm": True,
-#         "enable_attention": True,
-#         "enable_skip_layer_norm": True,
-#         "enable_embed_layer_norm": True,
-#         "enable_bias_skip_layer_norm": True,
-#         "enable_bias_gelu": True,
-#         "enable_gelu_approximation": False,
-#     },
-#     optimizer_additional_args={"num_heads": "$config.num_attention_heads", "hidden_size": "$config.hidden_size"},
-# )
+class LongformerOnnxConfig(OnnxConfig):
+
+    @property
+    def inputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "input_ids": {0: "batch", 1: "sequence"},
+            "attention_mask": {0: "batch", 1: "sequence"},
+        }
+
+    @property
+    def outputs(self) -> Mapping[str, Mapping[int, str]]:
+        return {
+            "last_hidden_state": {0: "batch", 1: "sequence"},
+            "pooler_output": {0: "batch"}
+        }
+
+    @property
+    def optimizer(self) -> Optional[str]:
+        return "bert"
+
+    @property
+    def optimizer_features(self) -> Optional[Mapping[str, bool]]:
+        return DEFAULT_BERT_OPTIMIZER_FEATURES
+
+    @property
+    def optimizer_additional_args(self) -> Optional[Mapping[str, Any]]:
+        return {
+            "num_heads": self._config.num_attention_heads,
+            "hidden_size": self._config.hidden_size
+        }
+
+
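
A minimal sketch of how these properties read once the class is instantiated; it assumes `OnnxConfig`'s constructor accepts the model config and stores it as `self._config`, which is how the properties above reach it. The printed values reflect `LongformerConfig`'s defaults.

# Minimal sketch, not part of the diff. Assumes OnnxConfig(config) stores the
# model config as self._config, as the properties above imply.
from transformers import LongformerConfig
from transformers.models.longformer import LongformerOnnxConfig

onnx_config = LongformerOnnxConfig(LongformerConfig())
print(onnx_config.inputs)                     # {'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}}
print(onnx_config.outputs)                    # {'last_hidden_state': {0: 'batch', 1: 'sequence'}, 'pooler_output': {0: 'batch'}}
print(onnx_config.optimizer_additional_args)  # {'num_heads': 12, 'hidden_size': 768} with the default config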

src/transformers/onnx/__main__.py

Lines changed: 2 additions & 0 deletions

@@ -24,6 +24,7 @@
 from transformers.models.bert import BertOnnxConfig
 from transformers.models.distilbert import DistilBertOnnxConfig
 from transformers.models.gpt2 import GPT2OnnxConfig
+from transformers.models.longformer import LongformerOnnxConfig
 from transformers.models.roberta import RobertaOnnxConfig
 from transformers.models.t5 import T5OnnxConfig
 from transformers.models.xlm_roberta import XLMRobertaOnnxConfig
@@ -64,6 +65,7 @@
     "bert": {"default": BertOnnxConfig.default},
     "distilbert": {"default": DistilBertOnnxConfig.default},
     "gpt2": {"default": GPT2OnnxConfig.default, "with_past": GPT2OnnxConfig.with_past},
+    "longformer": {"default": LongformerOnnxConfig.default},
     "roberta": {"default": RobertaOnnxConfig},
     "t5": {"default": T5OnnxConfig.default, "with_past": T5OnnxConfig.with_past},
     "xlm-roberta": {"default": XLMRobertaOnnxConfig.default},

src/transformers/onnx/config.py

Lines changed: 2 additions & 2 deletions

@@ -184,8 +184,8 @@ def generate_dummy_inputs(
         seq_length = compute_effective_axis_dimension(seq_length, fixed_dimension=8, num_token_to_add=token_to_add)
 
         # Generate dummy inputs according to compute batch and sequence
-        dummy_input = [[tokenizer.unk_token] * seq_length] * batch_size
-        return dict(tokenizer(dummy_input, is_split_into_words=True, return_tensors=framework))
+        dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
+        return dict(tokenizer(dummy_input, return_tensors=framework))
 
 
 class OnnxConfigWithPast(OnnxConfig, ABC):
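
For context on what the rewritten lines produce, here is the shape of the dummy batch before and after the change, with illustrative values (seq_length=8, batch_size=2, unk_token="<unk>") that are not taken from the diff:

# Illustration only.
unk, seq_length, batch_size = "<unk>", 8, 2

# Before: a batch of token lists, tokenized with is_split_into_words=True.
old_dummy = [[unk] * seq_length] * batch_size             # [['<unk>', ..., '<unk>'], ['<unk>', ..., '<unk>']]

# After: a batch of plain strings; " ".join over a one-element list is just the
# token itself, so the repetition concatenates without separators.
new_dummy = [" ".join([unk]) * seq_length] * batch_size    # ['<unk><unk>...<unk>', '<unk><unk>...<unk>']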
