Skip to content

Commit f665efb

Browse files
committed
WIP Wav2Vec2
1 parent 4516171 commit f665efb

File tree

1 file changed

+43
-0
lines changed

1 file changed

+43
-0
lines changed

src/transformers/models/wav2vec2/configuration_wav2vec2.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,13 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
""" Wav2Vec2 model configuration """
16+
from collections import OrderedDict
17+
from typing import Any, Mapping, Optional
18+
19+
from transformers import PreTrainedTokenizer, TensorType
1620

1721
from ...configuration_utils import PretrainedConfig
22+
from ...onnx import OnnxConfigWithPast
1823
from ...utils import logging
1924

2025

@@ -256,3 +261,41 @@ def __init__(
256261
# ctc loss
257262
self.ctc_loss_reduction = ctc_loss_reduction
258263
self.ctc_zero_infinity = ctc_zero_infinity
264+
265+
266+
class Wav2Vec2OnnxConfig(OnnxConfigWithPast):
    """ONNX export configuration for Wav2Vec2.

    Describes the dynamic-axis layout of the model's inputs/outputs and builds
    dummy inputs for tracing. Unlike text models, Wav2Vec2 consumes a raw
    audio waveform under the key ``input_values`` (a float tensor), not token
    ids.
    """

    # Sequence length used when the export pins a fixed sequence dimension.
    DEFAULT_FIXED_SEQUENCE = 1024

    @property
    def inputs(self) -> Mapping[str, Mapping[int, str]]:
        """Input names mapped to {axis index: dynamic axis name}."""
        return {
            "input_values": {0: "batch", 1: "sequence"},
            "attention_mask": {0: "batch", 1: "sequence"},
        }

    @property
    def outputs(self) -> Mapping[str, Mapping[int, str]]:
        """Output names mapped to {axis index: dynamic axis name}."""
        return {
            "last_hidden_state": {0: "batch", 1: "sequence"},
            "extract_features": {0: "batch", 1: "sequence"},
        }

    @property
    def default_sequence_length(self) -> int:
        """Fallback sequence length when none is requested by the caller."""
        return Wav2Vec2OnnxConfig.DEFAULT_FIXED_SEQUENCE

    def generate_dummy_inputs(
        self,
        tokenizer: PreTrainedTokenizer,
        batch_size: int = -1,
        seq_length: int = -1,
        is_pair: bool = False,
        framework: Optional[TensorType] = None,
    ) -> Mapping[str, Any]:
        """Build dummy inputs for ONNX tracing.

        Delegates to the base implementation (which produces tokenizer-style
        encodings), then renames ``input_ids`` to ``input_values`` with the
        dtype Wav2Vec2 actually expects.

        Returns:
            An ordered mapping with ``input_values`` first, then
            ``attention_mask``, matching the order declared in :attr:`inputs`.
        """
        encodings = super().generate_dummy_inputs(tokenizer, batch_size, seq_length, is_pair, framework)

        ordered_encodings = OrderedDict()
        # Wav2Vec2's `input_values` is documented as a FloatTensor holding the
        # raw waveform, so cast the integer token ids to float — casting to
        # long (as the original did) would feed the model the wrong dtype.
        ordered_encodings["input_values"] = encodings["input_ids"].float()
        ordered_encodings["attention_mask"] = encodings["attention_mask"]
        return ordered_encodings

0 commit comments

Comments
 (0)