
Commit b6f332e

Add Wav2Vec2 & Hubert ForSequenceClassification (huggingface#13153)
* Add hubert classifier + tests
* Add hubert classifier + tests
* Dummies for all classification tests
* Wav2Vec2 classifier + ER test
* Fix hubert integration tests
* Add hubert IC
* Pass tests for all classification tasks on Hubert
* Pass all tests + copies
* Move models to the SUPERB org
1 parent 2bef343 commit b6f332e

16 files changed: 823 additions, 36 deletions

docs/source/model_doc/hubert.rst

Lines changed: 8 additions & 0 deletions
@@ -64,6 +64,14 @@ HubertForCTC
 .. autoclass:: transformers.HubertForCTC
     :members: forward
 
+
+HubertForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.HubertForSequenceClassification
+    :members: forward
+
+
 TFHubertModel
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

docs/source/model_doc/wav2vec2.rst

Lines changed: 8 additions & 0 deletions
@@ -96,6 +96,14 @@ Wav2Vec2ForCTC
 .. autoclass:: transformers.Wav2Vec2ForCTC
     :members: forward
 
+
+Wav2Vec2ForSequenceClassification
+~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+
+.. autoclass:: transformers.Wav2Vec2ForSequenceClassification
+    :members: forward
+
+
 Wav2Vec2ForPreTraining
 ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

src/transformers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -818,6 +818,7 @@
         [
             "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
             "HubertForCTC",
+            "HubertForSequenceClassification",
             "HubertModel",
             "HubertPreTrainedModel",
         ]
@@ -1128,6 +1129,7 @@
             "Wav2Vec2ForCTC",
             "Wav2Vec2ForMaskedLM",
             "Wav2Vec2ForPreTraining",
+            "Wav2Vec2ForSequenceClassification",
             "Wav2Vec2Model",
             "Wav2Vec2PreTrainedModel",
         ]
@@ -2424,6 +2426,7 @@
     from .models.hubert import (
         HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
        HubertForCTC,
+        HubertForSequenceClassification,
        HubertModel,
        HubertPreTrainedModel,
    )
@@ -2681,6 +2684,7 @@
        Wav2Vec2ForCTC,
        Wav2Vec2ForMaskedLM,
        Wav2Vec2ForPreTraining,
+        Wav2Vec2ForSequenceClassification,
        Wav2Vec2Model,
        Wav2Vec2PreTrainedModel,
    )
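With these names registered in the top-level lazy import structure, both new heads become importable directly from the package:

from transformers import HubertForSequenceClassification, Wav2Vec2ForSequenceClassification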

src/transformers/models/hubert/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,7 @@
     _import_structure["modeling_hubert"] = [
         "HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST",
         "HubertForCTC",
+        "HubertForSequenceClassification",
         "HubertModel",
         "HubertPreTrainedModel",
     ]
@@ -48,6 +49,7 @@
     from .modeling_hubert import (
         HUBERT_PRETRAINED_MODEL_ARCHIVE_LIST,
         HubertForCTC,
+        HubertForSequenceClassification,
         HubertModel,
         HubertPreTrainedModel,
     )

src/transformers/models/hubert/configuration_hubert.py

Lines changed: 9 additions & 0 deletions
@@ -115,6 +115,11 @@ class HubertConfig(PretrainedConfig):
             Whether to zero infinite losses and the associated gradients of ``torch.nn.CTCLoss``. Infinite losses
             mainly occur when the inputs are too short to be aligned to the targets. Only relevant when training an
             instance of :class:`~transformers.HubertForCTC`.
+        use_weighted_layer_sum (:obj:`bool`, `optional`, defaults to :obj:`False`):
+            Whether to use a weighted average of layer outputs with learned weights. Only relevant when using an
+            instance of :class:`~transformers.HubertForSequenceClassification`.
+        classifier_proj_size (:obj:`int`, `optional`, defaults to 256):
+            Dimensionality of the projection before token mean-pooling for classification.
         gradient_checkpointing (:obj:`bool`, `optional`, defaults to :obj:`False`):
             If True, use gradient checkpointing to save memory at the expense of slower backward pass.
 
@@ -165,6 +170,8 @@ def __init__(
         mask_feature_length=10,
         ctc_loss_reduction="sum",
         ctc_zero_infinity=False,
+        use_weighted_layer_sum=False,
+        classifier_proj_size=256,
         gradient_checkpointing=False,
         pad_token_id=0,
         bos_token_id=1,
@@ -197,6 +204,8 @@ def __init__(
         self.vocab_size = vocab_size
         self.do_stable_layer_norm = do_stable_layer_norm
         self.gradient_checkpointing = gradient_checkpointing
+        self.use_weighted_layer_sum = use_weighted_layer_sum
+        self.classifier_proj_size = classifier_proj_size
 
         if (
             (len(self.conv_stride) != self.num_feat_extract_layers)
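For context, a minimal sketch of how these two new config options might be set when building a classifier; the base checkpoint name and label count below are placeholders, not part of this diff:

from transformers import HubertConfig, HubertForSequenceClassification

config = HubertConfig.from_pretrained(
    "facebook/hubert-base-ls960",   # placeholder base checkpoint
    use_weighted_layer_sum=True,    # average all encoder layers with learned weights
    classifier_proj_size=256,       # projection dimension before mean-pooling
    num_labels=12,                  # placeholder label count, e.g. a keyword-spotting label set
)
model = HubertForSequenceClassification.from_pretrained("facebook/hubert-base-ls960", config=config)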
New file: Hubert s3prl classifier checkpoint conversion script

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+# coding=utf-8
+# Copyright 2021 The HuggingFace Inc. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert Hubert checkpoint."""
+
+
+import argparse
+
+import torch
+
+from transformers import HubertConfig, HubertForSequenceClassification, Wav2Vec2FeatureExtractor, logging
+
+
+logging.set_verbosity_info()
+logger = logging.get_logger(__name__)
+
+SUPPORTED_MODELS = ["UtteranceLevel"]
+
+
+@torch.no_grad()
+def convert_s3prl_checkpoint(base_model_name, config_path, checkpoint_path, model_dump_path):
+    """
+    Copy/paste/tweak model's weights to transformers design.
+    """
+    checkpoint = torch.load(checkpoint_path, map_location="cpu")
+    if checkpoint["Config"]["downstream_expert"]["modelrc"]["select"] not in SUPPORTED_MODELS:
+        raise NotImplementedError(f"The supported s3prl models are {SUPPORTED_MODELS}")
+
+    downstream_dict = checkpoint["Downstream"]
+
+    hf_config = HubertConfig.from_pretrained(config_path)
+    hf_model = HubertForSequenceClassification.from_pretrained(base_model_name, config=hf_config)
+    hf_feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
+        base_model_name, return_attention_mask=True, do_normalize=False
+    )
+
+    if hf_config.use_weighted_layer_sum:
+        hf_model.layer_weights.data = checkpoint["Featurizer"]["weights"]
+
+    hf_model.projector.weight.data = downstream_dict["projector.weight"]
+    hf_model.projector.bias.data = downstream_dict["projector.bias"]
+    hf_model.classifier.weight.data = downstream_dict["model.post_net.linear.weight"]
+    hf_model.classifier.bias.data = downstream_dict["model.post_net.linear.bias"]
+
+    hf_feature_extractor.save_pretrained(model_dump_path)
+    hf_model.save_pretrained(model_dump_path)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--base_model_name", default=None, type=str, help="Name of the huggingface pretrained base model."
+    )
+    parser.add_argument("--config_path", default=None, type=str, help="Path to the huggingface classifier config.")
+    parser.add_argument("--checkpoint_path", default=None, type=str, help="Path to the s3prl checkpoint.")
+    parser.add_argument("--model_dump_path", default=None, type=str, help="Path to the final converted model.")
+    args = parser.parse_args()
+    convert_s3prl_checkpoint(args.base_model_name, args.config_path, args.checkpoint_path, args.model_dump_path)
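Once the script has run, the directory passed as --model_dump_path holds both the classifier weights and the feature extractor config, so the converted model loads like any other checkpoint. A small sketch; the dump path below is a placeholder:

from transformers import HubertForSequenceClassification, Wav2Vec2FeatureExtractor

# placeholder directory, the one given earlier as --model_dump_path
model = HubertForSequenceClassification.from_pretrained("./converted_hubert_classifier")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("./converted_hubert_classifier")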

src/transformers/models/hubert/modeling_hubert.py

Lines changed: 141 additions & 14 deletions
@@ -20,12 +20,13 @@
 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import CrossEntropyLoss
 
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_hubert import HubertConfig
@@ -735,6 +736,18 @@ def _conv_out_length(input_length, kernel_size, stride):
 
         return input_lengths
 
+    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output length indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
 
 HUBERT_START_DOCSTRING = r"""
     Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units
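The flip/cumsum/flip pattern in the new helper marks every frame up to each sequence's reduced length as attended. A self-contained sketch of the same trick with made-up lengths, separate from the diff itself:

import torch

# two sequences whose feature extractor output is 5 frames long,
# with real (unpadded) lengths of 5 and 3 frames
output_lengths = torch.tensor([5, 3])
mask = torch.zeros((2, 5), dtype=torch.long)

# put a 1 at the last valid frame of each sequence ...
mask[(torch.arange(2), output_lengths - 1)] = 1
# ... then flip, cumulative-sum and flip back so every earlier frame is also marked
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()

print(mask)
# tensor([[ True,  True,  True,  True,  True],
#         [ True,  True,  True, False, False]])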
@@ -904,19 +917,8 @@ def forward(
         extract_features = extract_features.transpose(1, 2)
 
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
-
-            attention_mask = torch.zeros(
-                extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
-            )
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask[
-                (torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
-            ] = 1
-            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
 
         hidden_states = self.feature_projection(extract_features)
         hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
@@ -1070,3 +1072,128 @@ def forward(
         return CausalLMOutput(
             loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
         )
+
+
+@add_start_docstrings(
+    """
+    Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+    SUPERB Keyword Spotting.
+    """,
+    HUBERT_START_DOCSTRING,
+)
+class HubertForSequenceClassification(HubertPreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.hubert = HubertModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        self.init_weights()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature extractor so that its parameters
+        will not be updated during training.
+        """
+        self.hubert.feature_extractor._freeze_parameters()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.hubert.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Example::
+
+            >>> import torch
+            >>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
+            >>> from datasets import load_dataset
+
+            >>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
+            >>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
+
+            >>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
+
+            >>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values  # Batch size 1
+            >>> logits = model(input_values).logits
+            >>> predicted_class_ids = torch.argmax(logits, dim=-1)
+
+            >>> # compute loss
+            >>> target_label = "down"
+            >>> labels = torch.tensor([model.config.label2id[target_label]])
+
+            >>> loss = model(input_values, labels=labels).loss
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.hubert(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[1]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        if attention_mask is None:
+            pooled_output = hidden_states.mean(dim=1)
+        else:
+            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+            hidden_states[~padding_mask] = 0.0
+            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
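When use_weighted_layer_sum is enabled, the head pools over every encoder layer plus the input embeddings with softmax-normalized learned weights before projecting and mean-pooling. A standalone sketch of just that pooling step, with illustrative shapes that are not taken from the diff:

import torch
from torch import nn

batch, seq_len, hidden, num_layers = 2, 50, 768, 13  # 12 transformer layers + input embeddings

# stand-ins for the `hidden_states` tuple returned with output_hidden_states=True
layer_outputs = [torch.randn(batch, seq_len, hidden) for _ in range(num_layers)]
layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)

stacked = torch.stack(layer_outputs, dim=1)                    # (batch, num_layers, seq_len, hidden)
norm_weights = nn.functional.softmax(layer_weights, dim=-1)    # one weight per layer, summing to 1
weighted = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)  # (batch, seq_len, hidden)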

src/transformers/models/wav2vec2/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -33,6 +33,7 @@
         "Wav2Vec2ForCTC",
         "Wav2Vec2ForMaskedLM",
         "Wav2Vec2ForPreTraining",
+        "Wav2Vec2ForSequenceClassification",
         "Wav2Vec2Model",
         "Wav2Vec2PreTrainedModel",
     ]
@@ -66,6 +67,7 @@
         Wav2Vec2ForCTC,
         Wav2Vec2ForMaskedLM,
         Wav2Vec2ForPreTraining,
+        Wav2Vec2ForSequenceClassification,
         Wav2Vec2Model,
         Wav2Vec2PreTrainedModel,
    )
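The Wav2Vec2 head mirrors the Hubert one above. A hedged end-to-end sketch with dummy audio; the SUPERB checkpoint name is an assumption based on the commit message's "Move models to the SUPERB org", not something stated in this diff:

import torch
from transformers import Wav2Vec2FeatureExtractor, Wav2Vec2ForSequenceClassification

# checkpoint name assumed, swap in whichever classifier checkpoint you actually use
extractor = Wav2Vec2FeatureExtractor.from_pretrained("superb/wav2vec2-base-superb-ks")
model = Wav2Vec2ForSequenceClassification.from_pretrained("superb/wav2vec2-base-superb-ks")

speech = [0.0] * 16000  # one second of silence standing in for real 16 kHz audio
inputs = extractor(speech, sampling_rate=16000, return_tensors="pt")

with torch.no_grad():
    logits = model(**inputs).logits
predicted_id = torch.argmax(logits, dim=-1).item()
print(model.config.id2label[predicted_id])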
