|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | from abc import ABC, abstractmethod
|
| 15 | +from collections import OrderedDict |
15 | 16 | from typing import Any, Mapping, Optional
|
16 | 17 |
|
17 | 18 | from transformers import PretrainedConfig, PreTrainedTokenizer, TensorType
|
@@ -84,6 +85,25 @@ def values_override(self) -> Optional[Mapping[str, Any]]:
|
84 | 85 |
|
85 | 86 | return None
|
86 | 87 |
|
| 88 | + @property |
| 89 | + def default_batch_size(self) -> int: |
| 90 | + """ |
| 91 | + The default batch size to use if no other indication |
| 92 | + Returns: |
| 93 | + Integer > 0 |
| 94 | + """ |
| 95 | + # Using 2 avoid ONNX making assumption about single sample batch |
| 96 | + return OnnxConfig.DEFAULT_FIXED_BATCH |
| 97 | + |
| 98 | + @property |
| 99 | + def default_sequence_length(self) -> int: |
| 100 | + """ |
| 101 | + The default sequence length to use if no other indication |
| 102 | + Returns: |
| 103 | + Integer > 0 |
| 104 | + """ |
| 105 | + return OnnxConfig.DEFAULT_FIXED_SEQUENCE |
| 106 | + |
87 | 107 | @property
|
88 | 108 | def default_onnx_opset(self) -> int:
|
89 | 109 | """
|
@@ -184,18 +204,18 @@ def generate_dummy_inputs(
|
184 | 204 | ) -> Mapping[str, Any]:
|
185 | 205 | # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
|
186 | 206 | batch_size = compute_effective_axis_dimension(
|
187 |
| - batch_size, fixed_dimension=OnnxConfig.DEFAULT_FIXED_BATCH, num_token_to_add=0 |
| 207 | + batch_size, fixed_dimension=self.default_batch_size, num_token_to_add=0 |
188 | 208 | )
|
189 | 209 |
|
190 | 210 | # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
|
191 | 211 | token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
|
192 | 212 |
|
193 | 213 | # When use_past the caching mechanism requires inputs to be only 1 single token
|
194 |
| - fixed_sequence_length = 1 if self.use_past else OnnxConfig.DEFAULT_FIXED_SEQUENCE |
| 214 | + fixed_sequence_length = 1 if self.use_past else self.default_sequence_length |
195 | 215 | seq_length = compute_effective_axis_dimension(
|
196 | 216 | seq_length, fixed_dimension=fixed_sequence_length, num_token_to_add=token_to_add
|
197 | 217 | )
|
198 | 218 |
|
199 | 219 | # Generate dummy inputs according to compute batch and sequence
|
200 | 220 | dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
|
201 |
| - return dict(tokenizer(dummy_input, return_tensors=framework)) |
| 221 | + return OrderedDict(dict(tokenizer(dummy_input, return_tensors=framework))) |
0 commit comments