 import torch
 import torch.utils.checkpoint
 from torch import nn
+from torch.nn import CrossEntropyLoss
 
 from transformers.deepspeed import is_deepspeed_zero3_enabled
 
 from ...activations import ACT2FN
 from ...file_utils import add_start_docstrings, add_start_docstrings_to_model_forward, replace_return_docstrings
-from ...modeling_outputs import BaseModelOutput, CausalLMOutput
+from ...modeling_outputs import BaseModelOutput, CausalLMOutput, SequenceClassifierOutput
 from ...modeling_utils import PreTrainedModel
 from ...utils import logging
 from .configuration_hubert import HubertConfig
@@ -735,6 +736,18 @@ def _conv_out_length(input_length, kernel_size, stride):
 
         return input_lengths
 
+    def _get_feature_vector_attention_mask(self, feature_vector_length: int, attention_mask: torch.LongTensor):
+        output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
+        batch_size = attention_mask.shape[0]
+
+        attention_mask = torch.zeros(
+            (batch_size, feature_vector_length), dtype=attention_mask.dtype, device=attention_mask.device
+        )
+        # these two operations make sure that all values before the output lengths indices are attended to
+        attention_mask[(torch.arange(attention_mask.shape[0], device=attention_mask.device), output_lengths - 1)] = 1
+        attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+        return attention_mask
+
 
 HUBERT_START_DOCSTRING = r"""
     Hubert was proposed in `HuBERT: Self-Supervised Speech Representation Learning by Masked Prediction of Hidden Units
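The `_get_feature_vector_attention_mask` helper added above builds the frame-level mask by writing a 1 at each example's last valid frame and then applying a reversed cumulative sum. Below is a minimal standalone sketch of that trick; `toy_lengths` and `max_frames` are made-up illustration values rather than lengths computed from a real attention mask:

```python
import torch

# hypothetical numbers of valid feature frames per example after the conv feature extractor
toy_lengths = torch.tensor([3, 5])  # batch of 2
max_frames = 6                      # padded feature_vector_length

mask = torch.zeros((toy_lengths.shape[0], max_frames), dtype=torch.long)
# mark the last valid frame of each example with a 1 ...
mask[(torch.arange(mask.shape[0]), toy_lengths - 1)] = 1
# ... then flip/cumsum/flip spreads that 1 backwards, so every frame up to and
# including the last valid one becomes non-zero, i.e. attended to
mask = mask.flip([-1]).cumsum(-1).flip([-1]).bool()

print(mask)
# tensor([[ True,  True,  True, False, False, False],
#         [ True,  True,  True,  True,  True, False]])
```

The same reversed-cumsum idiom was previously inlined in `HubertModel.forward`; the refactor in the next hunk routes that code path through the shared helper.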
@@ -904,19 +917,8 @@ def forward(
         extract_features = extract_features.transpose(1, 2)
 
         if attention_mask is not None:
-            # compute real output lengths according to convolution formula
-            output_lengths = self._get_feat_extract_output_lengths(attention_mask.sum(-1)).to(torch.long)
-
-            attention_mask = torch.zeros(
-                extract_features.shape[:2], dtype=extract_features.dtype, device=extract_features.device
-            )
-
-            # these two operations makes sure that all values
-            # before the output lengths indices are attended to
-            attention_mask[
-                (torch.arange(attention_mask.shape[0], device=extract_features.device), output_lengths - 1)
-            ] = 1
-            attention_mask = attention_mask.flip([-1]).cumsum(-1).flip([-1]).bool()
+            # compute reduced attention_mask corresponding to feature vectors
+            attention_mask = self._get_feature_vector_attention_mask(extract_features.shape[1], attention_mask)
 
         hidden_states = self.feature_projection(extract_features)
         hidden_states = self._mask_hidden_states(hidden_states, mask_time_indices=mask_time_indices)
@@ -1070,3 +1072,128 @@ def forward(
         return CausalLMOutput(
             loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
         )
+
+
+@add_start_docstrings(
+    """
+    Hubert Model with a sequence classification head on top (a linear layer over the pooled output) for tasks like
+    SUPERB Keyword Spotting.
+    """,
+    HUBERT_START_DOCSTRING,
+)
+class HubertForSequenceClassification(HubertPreTrainedModel):
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.__init__ with Wav2Vec2->Hubert, wav2vec2->hubert
+    def __init__(self, config):
+        super().__init__(config)
+
+        self.hubert = HubertModel(config)
+        num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+        if config.use_weighted_layer_sum:
+            self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+        self.projector = nn.Linear(config.hidden_size, config.classifier_proj_size)
+        self.classifier = nn.Linear(config.classifier_proj_size, config.num_labels)
+
+        self.init_weights()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_feature_extractor with wav2vec2->hubert
+    def freeze_feature_extractor(self):
+        """
+        Calling this function will disable the gradient computation for the feature extractor so that its parameters
+        will not be updated during training.
+        """
+        self.hubert.feature_extractor._freeze_parameters()
+
+    # Copied from transformers.models.wav2vec2.modeling_wav2vec2.Wav2Vec2ForSequenceClassification.freeze_base_model with wav2vec2->hubert
+    def freeze_base_model(self):
+        """
+        Calling this function will disable the gradient computation for the base model so that its parameters will not
+        be updated during training. Only the classification head will be updated.
+        """
+        for param in self.hubert.parameters():
+            param.requires_grad = False
+
+    @add_start_docstrings_to_model_forward(HUBERT_INPUTS_DOCSTRING)
+    @replace_return_docstrings(output_type=SequenceClassifierOutput, config_class=_CONFIG_FOR_DOC)
+    def forward(
+        self,
+        input_values,
+        attention_mask=None,
+        output_attentions=None,
+        output_hidden_states=None,
+        return_dict=None,
+        labels=None,
+    ):
+        r"""
+        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
+            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
+            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss).
+            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+
+        Returns:
+
+        Example::
+
+            >>> import torch
+            >>> from transformers import Wav2Vec2FeatureExtractor, HubertForSequenceClassification
+            >>> from datasets import load_dataset
+
+            >>> processor = Wav2Vec2FeatureExtractor.from_pretrained("superb/hubert-base-superb-ks")
+            >>> model = HubertForSequenceClassification.from_pretrained("superb/hubert-base-superb-ks")
+
+            >>> ds = load_dataset("anton-l/superb_dummy", "ks", split="test")
+
+            >>> input_values = processor(ds["speech"][4], return_tensors="pt").input_values  # Batch size 1
+            >>> logits = model(input_values).logits
+            >>> predicted_class_ids = torch.argmax(logits, dim=-1)
+
+            >>> # compute loss
+            >>> target_label = "down"
+            >>> labels = torch.tensor([model.config.label2id[target_label]])
+
+            >>> loss = model(input_values, labels=labels).loss
+        """
+
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+        output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+        outputs = self.hubert(
+            input_values,
+            attention_mask=attention_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict,
+        )
+
+        if self.config.use_weighted_layer_sum:
+            hidden_states = outputs[1]
+            hidden_states = torch.stack(hidden_states, dim=1)
+            norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+            hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+        else:
+            hidden_states = outputs[0]
+
+        hidden_states = self.projector(hidden_states)
+        if attention_mask is None:
+            pooled_output = hidden_states.mean(dim=1)
+        else:
+            padding_mask = self._get_feature_vector_attention_mask(hidden_states.shape[1], attention_mask)
+            hidden_states[~padding_mask] = 0.0
+            pooled_output = hidden_states.sum(dim=1) / padding_mask.sum(dim=1).view(-1, 1)
+
+        logits = self.classifier(pooled_output)
+
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
+
+        if not return_dict:
+            output = (logits,) + outputs[1:]
+            return ((loss,) + output) if loss is not None else output
+
+        return SequenceClassifierOutput(
+            loss=loss,
+            logits=logits,
+            hidden_states=outputs.hidden_states,
+            attentions=outputs.attentions,
+        )
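For illustration, the pooling performed in `HubertForSequenceClassification.forward` can be reproduced with plain tensors. This is a rough sketch under assumed toy shapes, with a hand-written `frame_mask` standing in for the output of `_get_feature_vector_attention_mask` (none of these values come from an actual checkpoint):

```python
import torch
from torch import nn

batch, frames, hidden, num_hidden_layers = 2, 6, 4, 3  # toy sizes

# stand-in for outputs.hidden_states: (num_hidden_layers + 1) tensors of shape (batch, frames, hidden)
hidden_states = [torch.randn(batch, frames, hidden) for _ in range(num_hidden_layers + 1)]
# stand-in for self.layer_weights (initialized uniformly, as in __init__ above)
layer_weights = torch.ones(num_hidden_layers + 1) / (num_hidden_layers + 1)

# weighted layer sum: softmax-normalize the weights and combine all layers
stacked = torch.stack(hidden_states, dim=1)                      # (batch, layers, frames, hidden)
norm_weights = nn.functional.softmax(layer_weights, dim=-1)
summed = (stacked * norm_weights.view(-1, 1, 1)).sum(dim=1)      # (batch, frames, hidden)

# masked mean pooling: zero out padded frames, then average over the valid ones only
frame_mask = torch.tensor([[1, 1, 1, 0, 0, 0],
                           [1, 1, 1, 1, 1, 0]]).bool()
summed[~frame_mask] = 0.0
pooled = summed.sum(dim=1) / frame_mask.sum(dim=1).view(-1, 1)   # (batch, hidden)
```

When `config.use_weighted_layer_sum` is enabled, the classifier therefore sees a learned softmax-weighted combination of all hidden states rather than only the top layer, and padded frames never contribute to the pooled representation.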