Commit 9f831bd

[DocTests Speech] Add doc tests for all speech models (huggingface#15031)
* fix_torch_device_generate_test
* remove @
* doc tests
* up
* up
* fix doctests
* adapt files
* finish refactor
* up
* save intermediate
* add more logic
* new change
* improve
* next try
* next try
* next try
* next try
* fix final spaces
* fix final spaces
* improve
* renaming
* correct more bugs
* finish wavlm
* add comment
* run on test runner
* finish all speech models
* adapt
* finish
1 parent 4df6950 · commit 9f831bd

11 files changed: 410 additions, 120 deletions


.github/workflows/doctests.yml (11 additions, 3 deletions)

```diff
@@ -19,7 +19,7 @@ env:
 
 jobs:
   run_doctests:
-    runs-on: [self-hosted, docker-gpu, single-gpu]
+    runs-on: [self-hosted, docker-gpu-test, single-gpu]
     container:
       image: pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime
      options: --gpus 0 --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/mnt/cache/
@@ -35,8 +35,16 @@ jobs:
         run: |
           apt -y update && apt install -y libsndfile1-dev
           pip install --upgrade pip
-          pip install .[dev]
+          pip install .[testing,torch-speech]
+
+      - name: Prepare files for doctests
+        run: |
+          python utils/prepare_for_doc_test.py src docs
 
       - name: Run doctests
         run: |
-          pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure
+          pytest --doctest-modules $(cat utils/documentation_tests.txt) -sv --doctest-continue-on-failure --doctest-glob="*.mdx"
+
+      - name: Clean files after doctests
+        run: |
+          python utils/prepare_for_doc_test.py src docs --remove_new_line
```
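The new steps bracket the doctest run: docstrings and `.mdx` docs are pre-processed, pytest runs the doctests, and the preparation is reverted. A rough local equivalent of that sequence, sketched in Python under the assumption that it runs from the root of a transformers checkout on a POSIX shell (the script paths and flags are copied verbatim from the workflow above):

```python
# Sketch only: mirrors the three workflow steps from a local checkout.
import subprocess

subprocess.run(["python", "utils/prepare_for_doc_test.py", "src", "docs"], check=True)
try:
    subprocess.run(
        "pytest --doctest-modules $(cat utils/documentation_tests.txt) "
        '-sv --doctest-continue-on-failure --doctest-glob="*.mdx"',
        shell=True,  # $(cat ...) needs a shell
        check=True,
    )
finally:
    # Revert the doc-test preparation even if pytest reports failures.
    subprocess.run(
        ["python", "utils/prepare_for_doc_test.py", "src", "docs", "--remove_new_line"],
        check=True,
    )
```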

src/transformers/file_utils.py (60 additions, 14 deletions)

````diff
@@ -1127,19 +1127,24 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
     ```python
     >>> from transformers import {processor_class}, {model_class}
+    >>> import torch
     >>> from datasets import load_dataset
 
     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+    >>> dataset = dataset.sort("id")
     >>> sampling_rate = dataset.features["audio"].sampling_rate
 
     >>> processor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")
 
     >>> # audio file is decoded on the fly
     >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
-    >>> outputs = model(**inputs)
+    >>> with torch.no_grad():
+    ...     outputs = model(**inputs)
 
     >>> last_hidden_states = outputs.last_hidden_state
+    >>> list(last_hidden_states.shape)
+    {expected_output}
     ```
     """
@@ -1152,24 +1157,32 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
     >>> import torch
 
     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+    >>> dataset = dataset.sort("id")
     >>> sampling_rate = dataset.features["audio"].sampling_rate
 
     >>> processor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")
 
     >>> # audio file is decoded on the fly
     >>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
-    >>> logits = model(**inputs).logits
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits
     >>> predicted_ids = torch.argmax(logits, dim=-1)
 
     >>> # transcribe speech
     >>> transcription = processor.batch_decode(predicted_ids)
+    >>> transcription[0]
+    {expected_output}
+    ```
 
-    >>> # compute loss
+    ```python
     >>> with processor.as_target_processor():
     ...     inputs["labels"] = processor(dataset[0]["text"], return_tensors="pt").input_ids
 
+    >>> # compute loss
     >>> loss = model(**inputs).loss
+    >>> round(loss.item(), 2)
+    {expected_loss}
     ```
     """
@@ -1182,21 +1195,31 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
     >>> import torch
 
     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+    >>> dataset = dataset.sort("id")
     >>> sampling_rate = dataset.features["audio"].sampling_rate
 
     >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")
 
     >>> # audio file is decoded on the fly
-    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
-    >>> logits = model(**inputs).logits
-    >>> predicted_class_ids = torch.argmax(logits, dim=-1)
+    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
+
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits
+
+    >>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
     >>> predicted_label = model.config.id2label[predicted_class_ids]
+    >>> predicted_label
+    {expected_output}
+    ```
 
+    ```python
     >>> # compute loss - target_label is e.g. "down"
     >>> target_label = model.config.id2label[0]
     >>> inputs["labels"] = torch.tensor([model.config.label2id[target_label]])
     >>> loss = model(**inputs).loss
+    >>> round(loss.item(), 2)
+    {expected_loss}
     ```
     """
@@ -1210,17 +1233,22 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
     >>> import torch
 
     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+    >>> dataset = dataset.sort("id")
     >>> sampling_rate = dataset.features["audio"].sampling_rate
 
     >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")
 
     >>> # audio file is decoded on the fly
-    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt")
-    >>> logits = model(**inputs).logits
+    >>> inputs = feature_extractor(dataset[0]["audio"]["array"], return_tensors="pt", sampling_rate=sampling_rate)
+    >>> with torch.no_grad():
+    ...     logits = model(**inputs).logits
+
     >>> probabilities = torch.sigmoid(logits[0])
     >>> # labels is a one-hot array of shape (num_frames, num_speakers)
     >>> labels = (probabilities > 0.5).long()
+    >>> labels[0].tolist()
+    {expected_output}
     ```
     """
@@ -1234,14 +1262,19 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
     >>> import torch
 
     >>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
+    >>> dataset = dataset.sort("id")
     >>> sampling_rate = dataset.features["audio"].sampling_rate
 
     >>> feature_extractor = {processor_class}.from_pretrained("{checkpoint}")
     >>> model = {model_class}.from_pretrained("{checkpoint}")
 
     >>> # audio file is decoded on the fly
-    >>> inputs = feature_extractor(dataset[:2]["audio"]["array"], return_tensors="pt")
-    >>> embeddings = model(**inputs).embeddings
+    >>> inputs = feature_extractor(
+    ...     [d["array"] for d in dataset[:2]["audio"]], sampling_rate=sampling_rate, return_tensors="pt", padding=True
+    ... )
+    >>> with torch.no_grad():
+    ...     embeddings = model(**inputs).embeddings
+
     >>> embeddings = torch.nn.functional.normalize(embeddings, dim=-1).cpu()
 
     >>> # the resulting embeddings can be used for cosine similarity-based retrieval
@@ -1250,6 +1283,8 @@ def _prepare_output_docstrings(output_type, config_class, min_indent=None):
     >>> threshold = 0.7  # the optimal threshold is dataset-dependent
     >>> if similarity < threshold:
     ...     print("Speakers are not the same!")
+    >>> round(similarity.item(), 2)
+    {expected_output}
     ```
     """
@@ -1553,9 +1588,11 @@ def add_code_sample_docstrings(
     checkpoint=None,
     output_type=None,
     config_class=None,
-    mask=None,
+    mask="[MASK]",
     model_cls=None,
-    modality=None
+    modality=None,
+    expected_output="",
+    expected_loss="",
 ):
     def docstring_decorator(fn):
         # model_class defaults to function's class if not specified otherwise
@@ -1568,7 +1605,17 @@ def docstring_decorator(fn):
         else:
             sample_docstrings = PT_SAMPLE_DOCSTRINGS
 
-        doc_kwargs = dict(model_class=model_class, processor_class=processor_class, checkpoint=checkpoint)
+        # putting all kwargs for docstrings in a dict to be used
+        # with the `.format(**doc_kwargs)`. Note that string might
+        # be formatted with non-existing keys, which is fine.
+        doc_kwargs = dict(
+            model_class=model_class,
+            processor_class=processor_class,
+            checkpoint=checkpoint,
+            mask=mask,
+            expected_output=expected_output,
+            expected_loss=expected_loss,
+        )
 
         if "SequenceClassification" in model_class and modality == "audio":
             code_sample = sample_docstrings["AudioClassification"]
@@ -1581,7 +1628,6 @@ def docstring_decorator(fn):
         elif "MultipleChoice" in model_class:
             code_sample = sample_docstrings["MultipleChoice"]
         elif "MaskedLM" in model_class or model_class in ["FlaubertWithLMHeadModel", "XLMWithLMHeadModel"]:
-            doc_kwargs["mask"] = "[MASK]" if mask is None else mask
             code_sample = sample_docstrings["MaskedLM"]
         elif "LMHead" in model_class or "CausalLM" in model_class:
             code_sample = sample_docstrings["LMHead"]
````
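The new comment is terse, so one point is worth spelling out: `str.format` ignores keyword arguments that a template never references, which is why a single `doc_kwargs` dict can be applied to every sample template even though, for example, only the MaskedLM sample uses `{mask}`. A minimal illustration (plain Python, nothing from transformers needed):

```python
# str.format silently drops surplus keyword arguments, so one kwargs dict
# can serve templates that each use a different subset of the keys.
# (The reverse case, a key referenced by the template but absent from
# the kwargs, would raise a KeyError.)
template = ">>> transcription[0]\n{expected_output}"
print(template.format(expected_output="'MISTER QUILTER...'", expected_loss=22.68, mask="[MASK]"))
```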

src/transformers/models/hubert/modeling_hubert.py (22 additions, 4 deletions)

```diff
@@ -40,15 +40,29 @@
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "HubertConfig"
-_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
-_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
 _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 
-_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
 
 _HIDDEN_STATES_START_POSITION = 1
 
+# General docstring
+_CONFIG_FOR_DOC = "HubertConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "facebook/hubert-large-ls960-ft"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 768]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 22.68
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "superb/hubert-base-superb-ks"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 8.53
+
 
 HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "facebook/hubert-base-ls960",
@@ -1098,6 +1112,8 @@ def freeze_feature_encoder(self):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1228,6 +1244,8 @@ def freeze_base_model(self):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
```
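With these constants wired in, the CTC template from `file_utils.py` renders into a concrete, doctest-verified sample. Rendered by hand for `HubertForCTC` (a sketch of the generated docstring, not a verbatim copy of it):

```python
>>> from transformers import Wav2Vec2Processor, HubertForCTC
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate

>>> processor = Wav2Vec2Processor.from_pretrained("facebook/hubert-large-ls960-ft")
>>> model = HubertForCTC.from_pretrained("facebook/hubert-large-ls960-ft")

>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
...     logits = model(**inputs).logits
>>> predicted_ids = torch.argmax(logits, dim=-1)

>>> # transcribe speech
>>> transcription = processor.batch_decode(predicted_ids)
>>> transcription[0]
'MISTER QUILTER IS THE APOSTLE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'
```

The loss branch of the same template is checked against `_CTC_EXPECTED_LOSS`, i.e. `round(loss.item(), 2)` is expected to print `22.68`.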

src/transformers/models/sew/modeling_sew.py (25 additions, 3 deletions)

```diff
@@ -36,16 +36,33 @@
 
 logger = logging.get_logger(__name__)
 
-_CONFIG_FOR_DOC = "SEWConfig"
-_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k"
 _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
 _FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 
-_SEQ_CLASS_CHECKPOINT = "asapp/sew-tiny-100k"
 
 _HIDDEN_STATES_START_POSITION = 1
 
 
+# General docstring
+_CONFIG_FOR_DOC = "SEWConfig"
+_PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
+
+# Base docstring
+_CHECKPOINT_FOR_DOC = "asapp/sew-tiny-100k-ft-ls100h"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 512]
+
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = (
+    "'MISTER QUILTER IS THE APPOSTILE OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPOLLE'"
+)
+_CTC_EXPECTED_LOSS = 0.42
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "anton-l/sew-mid-100k-ft-keyword-spotting"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 9.52
+
 SEW_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "asapp/sew-tiny-100k",
     "asapp/sew-small-100k",
@@ -879,6 +896,7 @@ def _mask_hidden_states(
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -978,6 +996,8 @@ def freeze_feature_encoder(self):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
    )
    def forward(
        self,
@@ -1108,6 +1128,8 @@ def freeze_base_model(self):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
```
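The audio-classification template resolves the same way for SEW. Hand-rendered with the new `_SEQ_CLASS_*` constants (again a sketch, not the literal generated text):

```python
>>> from transformers import Wav2Vec2FeatureExtractor, SEWForSequenceClassification
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate

>>> feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("anton-l/sew-mid-100k-ft-keyword-spotting")
>>> model = SEWForSequenceClassification.from_pretrained("anton-l/sew-mid-100k-ft-keyword-spotting")

>>> # audio file is decoded on the fly
>>> inputs = feature_extractor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")

>>> with torch.no_grad():
...     logits = model(**inputs).logits

>>> predicted_class_ids = torch.argmax(logits, dim=-1).item()
>>> predicted_label = model.config.id2label[predicted_class_ids]
>>> predicted_label
'_unknown_'
```

The corresponding loss block is checked against `_SEQ_CLASS_EXPECTED_LOSS` (`9.52`).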

src/transformers/models/sew_d/modeling_sew_d.py (22 additions, 4 deletions)

```diff
@@ -37,14 +37,27 @@
 
 logger = logging.get_logger(__name__)
 
+
+_HIDDEN_STATES_START_POSITION = 1
+
+
+# General docstring
 _CONFIG_FOR_DOC = "SEWDConfig"
-_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k"
 _PROCESSOR_FOR_DOC = "Wav2Vec2Processor"
-_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
 
-_SEQ_CLASS_CHECKPOINT = "asapp/sew-d-tiny-100k"
+# Base docstring
+_CHECKPOINT_FOR_DOC = "asapp/sew-d-tiny-100k-ft-ls100h"
+_EXPECTED_OUTPUT_SHAPE = [1, 292, 384]
 
-_HIDDEN_STATES_START_POSITION = 1
+# CTC docstring
+_CTC_EXPECTED_OUTPUT = "'MISTER QUILTER IS THE APOSTIL OF THE MIDDLE CLASSES AND WE ARE GLAD TO WELCOME HIS GOSPEL'"
+_CTC_EXPECTED_LOSS = 0.21
+
+# Audio class docstring
+_FEAT_EXTRACTOR_FOR_DOC = "Wav2Vec2FeatureExtractor"
+_SEQ_CLASS_CHECKPOINT = "anton-l/sew-d-mid-400k-ft-keyword-spotting"
+_SEQ_CLASS_EXPECTED_OUTPUT = "'_unknown_'"
+_SEQ_CLASS_EXPECTED_LOSS = 3.16
 
 SEW_D_PRETRAINED_MODEL_ARCHIVE_LIST = [
     "asapp/sew-d-tiny-100k",
@@ -1415,6 +1428,7 @@ def _mask_hidden_states(
         output_type=BaseModelOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_EXPECTED_OUTPUT_SHAPE,
     )
     def forward(
         self,
@@ -1514,6 +1528,8 @@ def freeze_feature_encoder(self):
         checkpoint=_CHECKPOINT_FOR_DOC,
         output_type=CausalLMOutput,
         config_class=_CONFIG_FOR_DOC,
+        expected_output=_CTC_EXPECTED_OUTPUT,
+        expected_loss=_CTC_EXPECTED_LOSS,
     )
     def forward(
         self,
@@ -1644,6 +1660,8 @@ def freeze_base_model(self):
         output_type=SequenceClassifierOutput,
         config_class=_CONFIG_FOR_DOC,
         modality="audio",
+        expected_output=_SEQ_CLASS_EXPECTED_OUTPUT,
+        expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
         self,
```
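For the base `SEWDModel`, the new `expected_output=_EXPECTED_OUTPUT_SHAPE` turns the hidden-state shape into a doctest assertion. Hand-rendered it reads roughly as follows, assuming the `Wav2Vec2Processor` named by `_PROCESSOR_FOR_DOC` is the processor class used in the base-model sample:

```python
>>> from transformers import Wav2Vec2Processor, SEWDModel
>>> import torch
>>> from datasets import load_dataset

>>> dataset = load_dataset("hf-internal-testing/librispeech_asr_demo", "clean", split="validation")
>>> dataset = dataset.sort("id")
>>> sampling_rate = dataset.features["audio"].sampling_rate

>>> processor = Wav2Vec2Processor.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h")
>>> model = SEWDModel.from_pretrained("asapp/sew-d-tiny-100k-ft-ls100h")

>>> # audio file is decoded on the fly
>>> inputs = processor(dataset[0]["audio"]["array"], sampling_rate=sampling_rate, return_tensors="pt")
>>> with torch.no_grad():
...     outputs = model(**inputs)

>>> last_hidden_states = outputs.last_hidden_state
>>> list(last_hidden_states.shape)
[1, 292, 384]
```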
