
Commit 4516171

Use OrderedDict for dummy inputs
1 parent 4eebfda commit 4516171

File tree

1 file changed: +23 -3 lines


src/transformers/onnx/config.py

Lines changed: 23 additions & 3 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
+from collections import OrderedDict
 from typing import Any, Mapping, Optional

 from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
@@ -84,6 +85,25 @@ def values_override(self) -> Optional[Mapping[str, Any]]:

         return None

+    @property
+    def default_batch_size(self) -> int:
+        """
+        The default batch size to use if no other indication
+        Returns:
+            Integer > 0
+        """
+        # Using 2 avoid ONNX making assumption about single sample batch
+        return OnnxConfig.DEFAULT_FIXED_BATCH
+
+    @property
+    def default_sequence_length(self) -> int:
+        """
+        The default sequence length to use if no other indication
+        Returns:
+            Integer > 0
+        """
+        return OnnxConfig.DEFAULT_FIXED_SEQUENCE
+
     @property
     def default_onnx_opset(self) -> int:
         """
@@ -184,18 +204,18 @@ def generate_dummy_inputs(
     ) -> Mapping[str, Any]:
         # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
         batch_size = compute_effective_axis_dimension(
-            batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0
+            batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0
         )

         # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
         token_to_add = tokenizer.num_special_tokens_to_add(is_pair)

         # When use_past the caching mechanism requires inputs to be only 1 single token
-        fixed_sequence_length = 1 if self.use_past else OnnxConfig.DEFAULT_FIXED_SEQUENCE
+        fixed_sequence_length = 1 if self.use_past else self.default_sequence_length
         seq_length = compute_effective_axis_dimension(
             seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add
         )

         # Generate dummy inputs according to compute batch and sequence
         dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
-        return dict(tokenizer(dummy_input, return_tensors=framework))
+        return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework)))
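Why the OrderedDict matters: when the dummy inputs are handed to the ONNX exporter, input names are bound to tensors by position, so the mapping returned by generate_dummy_inputs must iterate in a stable, well-defined order. A minimal sketch of that downstream contract (the checkpoint name, output file, and opset version below are illustrative assumptions, not part of this commit):

from collections import OrderedDict

import torch
from transformers import BertModel, BertTokenizer

# Illustrative only: build an ordered name -> tensor mapping, the same
# shape of result that generate_dummy_inputs now returns.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertModel.from_pretrained("bert-base-uncased").eval()

dummy_inputs = OrderedDict(tokenizer("hello world", return_tensors="pt"))

# torch.onnx.export treats a trailing dict as keyword inputs and binds
# input_names positionally, so the key order of the mapping is load-bearing.
torch.onnx.export(
    model,
    (dict(dummy_inputs),),
    "model.onnx",
    input_names=list(dummy_inputs.keys()),
    output_names=["last_hidden_state"],
    opset_version=12,
)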

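The two new properties also give model-specific configs a single override point for the dummy-input shapes. A hypothetical sketch (MyModelOnnxConfig and the value 4 are invented for illustration; a real subclass must also implement whatever abstract members OnnxConfig declares, such as inputs):

from collections import OrderedDict
from typing import Mapping

from transformers.onnx.config import OnnxConfig

class MyModelOnnxConfig(OnnxConfig):
    # Hypothetical subclass, not part of this commit.
    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        # Axes 0 and 1 stay dynamic ("batch"/"sequence") in the exported graph
        return OrderedDict(
            [
                ("input_ids", {0: "batch", 1: "sequence"}),
                ("attention_mask", {0: "batch", 1: "sequence"}),
            ]
        )

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        return OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])

    @property
    def default_batch_size(self) -> int:
        # Keep > 1 so ONNX cannot specialize the graph for single-sample batches
        return 4

Because generate_dummy_inputs now reads self.default_batch_size and self.default_sequence_length instead of the class-level constants, an override like this flows through without re-implementing the whole method.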