
Commit 198c335

[Doctests] Fix ignore bug and add more doc tests (huggingface#15911)

* finish speech doc tests
* finish
* boom
* Update src/transformers/models/speech_to_text/modeling_speech_to_text.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 8529a85 commit 198c335

File tree: 10 files changed (+115, -74 lines)

conftest.py

Lines changed: 3 additions & 1 deletion
@@ -67,10 +67,12 @@ def pytest_sessionfinish(session, exitstatus):
 
 OutputChecker = doctest.OutputChecker
 
+
 class CustomOutputChecker(OutputChecker):
     def check_output(self, want, got, optionflags):
-        if IGNORE_RESULT and optionflags:
+        if IGNORE_RESULT & optionflags:
             return True
         return OutputChecker.check_output(self, want, got, optionflags)
 
+
 doctest.OutputChecker = CustomOutputChecker
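This one-character change is the "ignore bug" of the commit title: doctest option flags are bits in a single integer bitmask, so the old truthiness test `IGNORE_RESULT and optionflags` was true whenever any option flag at all was set, and output checking was silently skipped for unrelated doctests. The bitwise test fires only when the `IGNORE_RESULT` bit itself is present. A minimal sketch of the difference, assuming `IGNORE_RESULT` is registered elsewhere in conftest.py via `doctest.register_optionflag`:

import doctest

# Assumed registration of the custom flag, as done elsewhere in conftest.py.
IGNORE_RESULT = doctest.register_optionflag("IGNORE_RESULT")

flags = doctest.NORMALIZE_WHITESPACE     # some unrelated flag is set
print(bool(IGNORE_RESULT and flags))     # True  -> old check wrongly skipped the comparison
print(bool(IGNORE_RESULT & flags))       # False -> fixed check still compares output

flags |= IGNORE_RESULT                   # the flag is explicitly requested
print(bool(IGNORE_RESULT & flags))       # True  -> output is ignored only now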

src/transformers/models/data2vec/modeling_data2vec_audio.py

Lines changed: 8 additions & 8 deletions
@@ -55,21 +55,21 @@
 
 # CTC docstring
 _CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
-_CTC_EXPECTED_LOSS = 53.48
+_CTC_EXPECTED_LOSS = 66.95
 
 # Audio class docstring
 _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
-_SEQ_CLASS_CHECKPOINT = "superb/data2vec-audio-base-superb-ks"
-_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
-_SEQ_CLASS_EXPECTED_LOSS = 6.54
+_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-seq-class"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
+_SEQ_CLASS_EXPECTED_LOSS = 0.69
 
 # Frame class docstring
-_FRAME_CLASS_CHECKPOINT = "anton-l/data2vec-audio-base-superb-sd"
-_FRAME_EXPECTED_OUTPUT = [0, 0]
+_FRAME_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-audio-frame"
+_FRAME_EXPECTED_OUTPUT = [1, 1]
 
 # Speaker Verification docstring
-_XVECTOR_CHECKPOINT = "anton-l/data2vec-audio-base-superb-sv"
-_XVECTOR_EXPECTED_OUTPUT = 0.98
+_XVECTOR_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-xvector"
+_XVECTOR_EXPECTED_OUTPUT = 1.0
 
 
 DATA2VEC_AUDIO_PRETRAINED_MODEL_ARCHIVE_LIST = [
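These constants are consumed by the shared `add_code_sample_docstrings` templates rather than used directly, which is why pointing them at tiny random test checkpoints with deterministic expected outputs is enough to make the doctests fast and stable. A rough sketch of the templating idea, with a hypothetical `add_doc_example` standing in for the real transformers decorator:

_SEQ_CLASS_CHECKPOINT = "hf-internal-testing/tiny-random-data2vec-seq-class"
_SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"

_TEMPLATE = """
    Example:

    >>> model = Data2VecAudioForSequenceClassification.from_pretrained("{checkpoint}")
    >>> predicted_label
    {expected_output}
"""

def add_doc_example(checkpoint, expected_output):
    # Hypothetical stand-in: append the filled-in template to fn's docstring.
    def decorator(fn):
        fn.__doc__ = (fn.__doc__ or "") + _TEMPLATE.format(
            checkpoint=checkpoint, expected_output=expected_output
        )
        return fn
    return decorator

@add_doc_example(_SEQ_CLASS_CHECKPOINT, _SEQ_CLASS_EXPECTED_OUTPUT)
def forward():  # stand-in for the model's forward method
    ...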

src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py

Lines changed: 14 additions & 8 deletions
@@ -465,22 +465,28 @@ def forward(
         Examples:
 
         ```python
-        >>> from transformers import SpeechEncoderDecoderModel, Speech2Text2Processor
+        >>> from transformers import SpeechEncoderDecoderModel, Wav2Vec2Processor
        >>> from datasets import load_dataset
         >>> import torch
 
-        >>> processor = Speech2Text2Processor.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
-        >>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/s2t-wav2vec2-large-en-de")
+        >>> processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15")
+        >>> model = SpeechEncoderDecoderModel.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15")
 
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 
         >>> input_values = processor(ds[0]["audio"]["array"], return_tensors="pt").input_values
-        >>> decoder_input_ids = torch.tensor([[model.config.decoder.decoder_start_token_id]])
-        >>> outputs = model(input_values=input_values, decoder_input_ids=decoder_input_ids)
-
-        >>> # inference (generation)
+        >>> # Inference: Translate English speech to German
         >>> generated = model.generate(input_values)
-        >>> translation = processor.batch_decode(generated)
+        >>> decoded = processor.batch_decode(generated, skip_special_tokens=True)[0]
+        >>> decoded
+        'Mr. Quilter ist der Apostel der Mittelschicht und wir freuen uns, sein Evangelium willkommen heißen zu können.'
+
+        >>> # Training: Train model on English transcription
+        >>> with processor.as_target_processor():
+        ...     labels = processor(ds[0]["text"], return_tensors="pt").input_ids
+
+        >>> loss = model(input_values, labels=labels).loss
+        >>> loss.backward()
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
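The training half of the updated example hinges on `as_target_processor()`: inside that context manager the processor's `__call__` encodes text with its tokenizer instead of extracting features from audio, so a single object prepares both inputs and labels. A minimal sketch reusing the checkpoint from the example above:

from transformers import Wav2Vec2Processor

processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15")

# Outside the context manager, processor(...) feature-extracts raw audio;
# inside it, processor(...) tokenizes text, which is what the labels need.
with processor.as_target_processor():
    labels = processor("MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES", return_tensors="pt").input_ids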

src/transformers/models/speech_to_text/modeling_speech_to_text.py

Lines changed: 31 additions & 26 deletions
@@ -24,12 +24,7 @@
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
-from ...file_utils import (
-    add_code_sample_docstrings,
-    add_start_docstrings,
-    add_start_docstrings_to_model_forward,
-    replace_return_docstrings,
-)
+from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
 from ...modeling_outputs import (
     BaseModelOutput,
     BaseModelOutputWithPastAndCrossAttentions,

@@ -44,8 +39,6 @@
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Speech2TextConfig"
-_TOKENIZER_FOR_DOC = "Speech2TextTokenizer"
-_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr"
 
 
 SPEECH_TO_TEXT_PRETRAINED_MODEL_ARCHIVE_LIST = [

@@ -780,7 +773,7 @@ def forward(
             attention_mask = self._get_feature_vector_attention_mask(inputs_embeds.shape[1], attention_mask)
             padding_mask = attention_mask.ne(1).long()
         else:
-            padding_mask = torch.zeros_like(inputs_embeds, dtype=torch.long)
+            padding_mask = torch.zeros(inputs_embeds.shape[:2], dtype=torch.long, device=inputs_embeds.device)
 
         embed_pos = self.embed_positions(padding_mask)
 

@@ -1144,12 +1137,7 @@ def get_decoder(self):
         return self.decoder
 
     @add_start_docstrings_to_model_forward(SPEECH_TO_TEXT_INPUTS_DOCSTRING)
-    @add_code_sample_docstrings(
-        processor_class=_TOKENIZER_FOR_DOC,
-        checkpoint=_CHECKPOINT_FOR_DOC,
-        output_type=Seq2SeqModelOutput,
-        config_class=_CONFIG_FOR_DOC,
-    )
+    @replace_return_docstrings(output_type=Seq2SeqLMOutput, config_class=_CONFIG_FOR_DOC)
     def forward(
         self,
         input_features=None,

@@ -1167,6 +1155,28 @@ def forward(
         output_hidden_states=None,
         return_dict=None,
     ):
+        r"""
+        Returns:
+
+        Example:
+
+        ```python
+        >>> import torch
+        >>> from transformers import Speech2TextModel, Speech2TextFeatureExtractor
+        >>> from datasets import load_dataset
+
+        >>> model = Speech2TextModel.from_pretrained("facebook/s2t-small-librispeech-asr")
+        >>> feature_extractor = Speech2TextFeatureExtractor.from_pretrained("facebook/s2t-small-librispeech-asr")
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> input_features = feature_extractor(
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        >>> ).input_features
+        >>> decoder_input_ids = torch.tensor([[1, 1]]) * model.config.decoder_start_token_id
+        >>> last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+        >>> list(last_hidden_state.shape)
+        [1, 2, 256]
+        ```"""
+
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
         output_hidden_states = (
             output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states

@@ -1305,27 +1315,22 @@ def forward(
         >>> import torch
         >>> from transformers import Speech2TextProcessor, Speech2TextForConditionalGeneration
         >>> from datasets import load_dataset
-        >>> import soundfile as sf
 
         >>> model = Speech2TextForConditionalGeneration.from_pretrained("facebook/s2t-small-librispeech-asr")
         >>> processor = Speech2TextProcessor.from_pretrained("facebook/s2t-small-librispeech-asr")
 
 
-        >>> def map_to_array(batch):
-        ...     speech, _ = sf.read(batch["file"])
-        ...     batch["speech"] = speech
-        ...     return batch
-
-
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        >>> ds = ds.map(map_to_array)
 
         >>> input_features = processor(
-        ...     ds["speech"][0], sampling_rate=16000, return_tensors="pt"
-        >>> ).input_features  # Batch size 1
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        >>> ).input_features
+
         >>> generated_ids = model.generate(inputs=input_features)
 
-        >>> transcription = processor.batch_decode(generated_ids)
+        >>> transcription = processor.batch_decode(generated_ids)[0]
+        >>> transcription
+        'mister quilter is the apostle of the middle classes and we are glad to welcome his gospel'
         ```"""
         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
 
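A note on the decoder seed in the new `Speech2TextModel.forward` doctest: `torch.tensor([[1, 1]])` has shape (batch=1, seq=2), and multiplying by `decoder_start_token_id` fills both positions with the start token, which is why `last_hidden_state` comes back as `[1, 2, 256]` (256 is the hidden size of the s2t-small checkpoint, per the doctest itself). A standalone sketch of that shape arithmetic, with an illustrative token id:

import torch

decoder_start_token_id = 2  # illustrative value; the doctest reads it from model.config
decoder_input_ids = torch.tensor([[1, 1]]) * decoder_start_token_id
print(decoder_input_ids)        # tensor([[2, 2]])
print(decoder_input_ids.shape)  # torch.Size([1, 2]) -> batch of 1, two decoder steps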

src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py

Lines changed: 25 additions & 5 deletions
@@ -35,13 +35,12 @@
 logger = logging.get_logger(__name__)
 
 _CONFIG_FOR_DOC = "Speech2Text2Config"
-_TOKENIZER_FOR_DOC = "Speech2Text2Tokenizer"
-_CHECKPOINT_FOR_DOC = "facebook/s2t-small-librispeech-asr"
+_CHECKPOINT_FOR_DOC = "facebook/s2t-wav2vec2-large-en-de"
 
 
 SPEECH_TO_TEXT_2_PRETRAINED_MODEL_ARCHIVE_LIST = [
-    "facebook/s2t-small-librispeech-asr",
-    # See all Speech2Text2 models at https://huggingface.co/models?filter=speech_to_text
+    "facebook/s2t-wav2vec2-large-en-de",
+    # See all Speech2Text2 models at https://huggingface.co/models?filter=speech2text2
 ]
 
 

@@ -865,13 +864,34 @@ def forward(
         ...     Wav2Vec2Model,
         ...     Speech2Text2Config,
         ...     Wav2Vec2Config,
+        ...     Wav2Vec2FeatureExtractor,
+        ...     Speech2Text2Tokenizer,
         ... )
+        >>> from datasets import load_dataset
+
+        >>> feature_extractor = Wav2Vec2FeatureExtractor()
+        >>> tokenizer = Speech2Text2Tokenizer.from_pretrained(_CHECKPOINT_FOR_DOC)
 
         >>> encoder = Wav2Vec2Model(Wav2Vec2Config())
         >>> decoder = Speech2Text2ForCausalLM(Speech2Text2Config())
-        # init speech2text model
+        # init random speech2text model
 
         >>> model = SpeechEncoderDecoderModel(encoder=encoder, decoder=decoder)
+        >>> model.config.pad_token_id = tokenizer.pad_token_id
+        >>> model.config.decoder_start_token_id = tokenizer.bos_token_id
+        # pre-process inputs and labels
+
+        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+        >>> input_values = feature_extractor(
+        ...     ds[0]["audio"]["array"], sampling_rate=ds[0]["audio"]["sampling_rate"], return_tensors="pt"
+        >>> ).input_values  # Batch size 1
+        >>> decoder_input_ids = tokenizer(ds[0]["text"], return_tensors="pt").input_ids
+        # compute loss
+
+        >>> loss = model(inputs=input_values, labels=decoder_input_ids).loss
+        # backprop loss
+
+        >>> loss.backward()
         ```"""
 
         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
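The two config assignments added to this example matter because, when `labels` are passed, the encoder-decoder model derives `decoder_input_ids` by shifting the labels one position to the right: `decoder_start_token_id` becomes the first decoder token and `pad_token_id` replaces loss-masked positions. A sketch modeled on the usual transformers shift helper (illustrative, not the exact library code):

import torch

def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    # Prepend the start token and drop the last label position.
    shifted = labels.new_zeros(labels.shape)
    shifted[:, 1:] = labels[:, :-1].clone()
    shifted[:, 0] = decoder_start_token_id
    # Loss-masked positions (-100) must become real pad tokens for the decoder.
    shifted.masked_fill_(shifted == -100, pad_token_id)
    return shifted

labels = torch.tensor([[21, 7, 42, -100]])
print(shift_tokens_right(labels, pad_token_id=1, decoder_start_token_id=0))
# tensor([[ 0, 21,  7, 42]])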

src/transformers/models/wav2vec2/modeling_wav2vec2.py

Lines changed: 1 addition & 10 deletions
@@ -1478,17 +1478,8 @@ def forward(
         >>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("patrickvonplaten/wav2vec2-base")
         >>> model = Wav2Vec2ForPreTraining.from_pretrained("patrickvonplaten/wav2vec2-base")
 
-
-        >>> def map_to_array(batch):
-        ...     speech, _ = sf.read(batch["file"])
-        ...     batch["speech"] = speech
-        ...     return batch
-
-
         >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
-        >>> ds = ds.map(map_to_array)
-
-        >>> input_values = feature_extractor(ds["speech"][0], return_tensors="pt").input_values  # Batch size 1
+        >>> input_values = feature_extractor(ds[0]["audio"]["array"], return_tensors="pt").input_values  # Batch size 1
 
         >>> # compute masked indices
         >>> batch_size, raw_sequence_length = input_values.shape
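The deleted `map_to_array` helper (and the `soundfile` import it required) is obsolete because the dummy LibriSpeech split exposes an `audio` column of the datasets `Audio` feature type, which decodes the file on access. A short sketch of what the updated doctests rely on:

from datasets import load_dataset

ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")

# Indexing a row decodes the audio lazily; no manual soundfile read is needed.
sample = ds[0]["audio"]
print(sample["sampling_rate"])  # 16000 for this dataset
print(sample["array"][:4])      # 1-D numpy waveform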

src/transformers/models/wav2vec2/tokenization_wav2vec2.py

Lines changed: 4 additions & 6 deletions
@@ -566,17 +566,15 @@ def decode(
         >>> word_offsets = [
         ...     {
         ...         "word": d["word"],
-        ...         "start_time": d["start_offset"] * time_offset,
-        ...         "end_time": d["end_offset"] * time_offset,
+        ...         "start_time": round(d["start_offset"] * time_offset, 2),
+        ...         "end_time": round(d["end_offset"] * time_offset, 2),
         ...     }
         ...     for d in outputs.word_offsets
         ... ]
         >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
         >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
-        >>> word_offset
-        >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
-        >>> # 'start_time': 1.64, 'end_time': 1.90}, {'word': 'MILISANDRA',
-        >>> # 'start_time': 2.26, 'end_time': 2.9}, {'word': 'LOOK', 'start_time': 3.0, 'end_time': 3.16}, ...
+        >>> word_offsets[:3]
+        [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES', 'start_time': 1.64, 'end_time': 1.9}, {'word': 'MILISANDRA', 'start_time': 2.26, 'end_time': 2.9}]
         ```"""
         # Convert inputs to python lists
         token_ids = to_py_obj(token_ids)
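The added `round(..., 2)` pins the doctest output to two decimals. The underlying arithmetic: `time_offset = inputs_to_logits_ratio / sampling_rate` converts logit-frame offsets into seconds, and for a wav2vec2-base-style model the convolutional encoder downsamples 16 kHz audio by a factor of 320, so each frame spans 0.02 s. A worked check against the first word of the expected output (the 320 ratio is the standard wav2vec2 value, assumed here):

inputs_to_logits_ratio = 320  # assumed wav2vec2-base downsampling factor
sampling_rate = 16_000
time_offset = inputs_to_logits_ratio / sampling_rate  # 0.02 s per logit frame

# 'WHY' at frame offset 71 lands at 1.42 s, matching the doctest output.
print(round(71 * time_offset, 2))  # 1.42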

src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py

Lines changed: 8 additions & 10 deletions
@@ -401,7 +401,7 @@ def decode(
 
         ```python
         >>> # Let's see how to retrieve time steps for a model
-        >>> from transformers import AutoTokenizer, AutoFeatureExtractor, AutoModelForCTC
+        >>> from transformers import AutoTokenizer, AutoProcessor, AutoModelForCTC
         >>> from datasets import load_dataset
         >>> import datasets
         >>> import torch

@@ -417,29 +417,27 @@ def decode(
         >>> sample = next(dataset_iter)
 
         >>> # forward sample through model to get greedily predicted transcription ids
-        >>> input_values = feature_extractor(sample["audio"]["array"], return_tensors="pt").input_values
+        >>> input_values = processor(sample["audio"]["array"], return_tensors="pt").input_values
         >>> with torch.no_grad():
         ...     logits = model(input_values).logits[0].cpu().numpy()
 
         >>> # retrieve word stamps (analogous commands for `output_char_offsets`)
-        >>> outputs = tokenizer.decode(logits, output_word_offsets=True)
+        >>> outputs = processor.decode(logits, output_word_offsets=True)
         >>> # compute `time_offset` in seconds as product of downsampling ratio and sampling_rate
-        >>> time_offset = model.config.inputs_to_logits_ratio / feature_extractor.sampling_rate
+        >>> time_offset = model.config.inputs_to_logits_ratio / processor.feature_extractor.sampling_rate
 
         >>> word_offsets = [
         ...     {
         ...         "word": d["word"],
-        ...         "start_time": d["start_offset"] * time_offset,
-        ...         "end_time": d["end_offset"] * time_offset,
+        ...         "start_time": round(d["start_offset"] * time_offset, 2),
+        ...         "end_time": round(d["end_offset"] * time_offset, 2),
         ...     }
         ...     for d in outputs.word_offsets
         ... ]
         >>> # compare word offsets with audio `common_voice_en_100038.mp3` online on the dataset viewer:
         >>> # https://huggingface.co/datasets/common_voice/viewer/en/train
-        >>> word_offset
-        >>> # [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES',
-        >>> # 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A',
-        >>> # 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}, ...
+        >>> word_offsets[:4]
+        [{'word': 'WHY', 'start_time': 1.42, 'end_time': 1.54}, {'word': 'DOES', 'start_time': 1.64, 'end_time': 1.88}, {'word': 'A', 'start_time': 2.12, 'end_time': 2.14}, {'word': 'MILE', 'start_time': 2.26, 'end_time': 2.46}]
         ```"""
 
         from pyctcdecode.constants import (
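Switching the import to `AutoProcessor` works because a Wav2Vec2 checkpoint that ships a KenLM head resolves to `Wav2Vec2ProcessorWithLM`, which bundles the feature extractor, the tokenizer, and the pyctcdecode decoder in one object; that is why the example can call `processor(...)` on raw audio and `processor.decode(logits, ...)` on logits. A minimal sketch (the checkpoint name here is illustrative; the doctest sets its own earlier):

from transformers import AutoProcessor

# Assumption: any Wav2Vec2 repo with an attached LM resolves the same way.
processor = AutoProcessor.from_pretrained("patrickvonplaten/wav2vec2-base-100h-with-lm")

print(type(processor).__name__)                    # Wav2Vec2ProcessorWithLM
print(type(processor.feature_extractor).__name__)  # Wav2Vec2FeatureExtractor
print(type(processor.tokenizer).__name__)          # Wav2Vec2CTCTokenizer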

tests/speech_to_text/test_modeling_speech_to_text.py

Lines changed: 15 additions & 0 deletions
@@ -185,6 +185,17 @@ def get_subsampled_output_lengths(self, input_lengths):
 
         return input_lengths
 
+    def create_and_check_model_forward(self, config, inputs_dict):
+        model = Speech2TextModel(config=config).to(torch_device).eval()
+
+        input_features = inputs_dict["input_features"]
+        decoder_input_ids = inputs_dict["decoder_input_ids"]
+
+        # first forward pass
+        last_hidden_state = model(input_features, decoder_input_ids=decoder_input_ids).last_hidden_state
+
+        self.parent.assertEqual(last_hidden_state.shape, (13, 7, 16))
+
     def create_and_check_decoder_model_past_large_inputs(self, config, inputs_dict):
         model = Speech2TextModel(config=config).get_decoder().to(torch_device).eval()
         input_ids = inputs_dict["decoder_input_ids"]

@@ -284,6 +295,10 @@ def test_save_load_strict(self):
             model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
         self.assertEqual(info["missing_keys"], [])
 
+    def test_model_forward(self):
+        config_and_inputs = self.model_tester.prepare_config_and_inputs()
+        self.model_tester.create_and_check_model_forward(*config_and_inputs)
+
     def test_decoder_model_past_with_large_inputs(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_decoder_model_past_large_inputs(*config_and_inputs)

utils/documentation_tests.txt

Lines changed: 6 additions & 0 deletions
@@ -1,9 +1,15 @@
 src/transformers/models/wav2vec2/modeling_wav2vec2.py
+src/transformers/models/wav2vec2/tokenization_wav2vec2.py
+src/transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py
 src/transformers/models/hubert/modeling_hubert.py
 src/transformers/models/wavlm/modeling_wavlm.py
 src/transformers/models/unispeech/modeling_unispeech.py
 src/transformers/models/unispeech_sat/modeling_unispeech_sat.py
 src/transformers/models/sew/modeling_sew.py
 src/transformers/models/sew_d/modeling_sew_d.py
+src/transformers/models/speech_to_text_2/modeling_speech_to_text_2.py
+src/transformers/models/speech_to_text/modeling_speech_to_text.py
+src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+src/transformers/models/data2vec/modeling_data2vec_audio.py
 docs/source/quicktour.mdx
 docs/source/task_summary.mdx
