|
13 | 13 | # See the License for the specific language governing permissions and
|
14 | 14 | # limitations under the License.
|
15 | 15 | """ Wav2Vec2 model configuration """
|
| 16 | +from collections import OrderedDict |
| 17 | +from typing import Any, Mapping, Optional |
| 18 | + |
| 19 | +from transformers import PreTrainedTokenizer, TensorType |
16 | 20 |
|
17 | 21 | from ...configuration_utils import PretrainedConfig
|
| 22 | +from ...onnx import OnnxConfigWithPast |
18 | 23 | from ...utils import logging
|
19 | 24 |
|
20 | 25 |
|
@@ -256,3 +261,41 @@ def __init__(
|
256 | 261 | # ctc loss
|
257 | 262 | self.ctc_loss_reduction = ctc_loss_reduction
|
258 | 263 | self.ctc_zero_infinity = ctc_zero_infinity
|
| 264 | + |
| 265 | + |
class Wav2Vec2OnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for Wav2Vec2.

    Describes the model's input/output signature (names and dynamic axes)
    and produces dummy inputs for tracing the export.
    """

    # Fixed sequence length used when the export cannot keep the axis dynamic.
    DEFAULT_FIXED_SEQUENCE = 1024

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Input names mapped to their dynamic axes (axis index -> axis name)."""
        return {
            "input_values": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
        }

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Output names mapped to their dynamic axes (axis index -> axis name)."""
        return {
            "last_hidden_state": {0: "batch", 1: "sequence"},
            "extract_features": {0: "batch", 1: "sequence"},
        }

    @property
    def default_sequence_length(self) -> int:
        """Sequence length to use when a fixed size is required."""
        # Resolve through `self` so subclasses can override the constant.
        return self.DEFAULT_FIXED_SEQUENCE

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs for the ONNX export.

        Reuses the base implementation, then renames ``input_ids`` to
        ``input_values`` to match Wav2Vec2's forward signature.
        """
        encodings = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)
        ordered_encodings = OrderedDict()

        # Wav2Vec2 consumes raw waveform values ("input_values"), which its
        # forward pass documents as a FloatTensor — an int64 (.long()) dummy
        # would bake the wrong input dtype into the exported graph.
        # NOTE(review): `.float()` assumes a PyTorch tensor here; confirm the
        # TensorFlow export path is handled upstream.
        ordered_encodings["input_values"] = encodings["input_ids"].float()
        ordered_encodings["attention_mask"] = encodings["attention_mask"]
        return ordered_encodings
0 commit comments