Commit d5d2744

Support T5 Distillation w/hidden state supervision (huggingface#7599)
1 parent 818c294 commit d5d2744

File tree: 2 files changed (+37 additions, -30 deletions)

examples/seq2seq/distillation.py

Lines changed: 37 additions & 27 deletions
@@ -28,7 +28,7 @@
 class BartSummarizationDistiller(SummarizationModule):
     """Supports Bart, Pegasus and other models that inherit from Bart."""
 
-    loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
+    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
 
     def __init__(self, hparams):
         assert Path(hparams.data_dir).exists()
@@ -46,9 +46,19 @@ def __init__(self, hparams):
         if hparams.length_penalty != -1:
             student.config.length_penalty = hparams.length_penalty
         super().__init__(hparams, model=student, config=student.config)
+        model_type = student.config.model_type
         self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]
-        self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
-        self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
+
+        if model_type == "t5":
+            teacher_encoder_layers = len(teacher.get_encoder().block)
+            teacher_decoder_layers = len(teacher.get_decoder().block)
+        else:
+            teacher_encoder_layers = teacher.config.encoder_layers
+            teacher_decoder_layers = teacher.config.decoder_layers
+
+        self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
+        self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
+
         self.teacher = teacher
         freeze_params(self.teacher)
 
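
Note (editorial, not part of the diff): Bart-style configs expose encoder_layers and decoder_layers directly, while T5's config does not, so the patch reads the teacher's depth off the module stacks instead (T5 keeps its layers in .block). A minimal standalone sketch of that branch, assuming only a transformers install that provides AutoModelForSeq2SeqLM:

from transformers import AutoModelForSeq2SeqLM


def teacher_layer_counts(teacher):
    """Return (n_encoder_layers, n_decoder_layers) for Bart-like or T5-like teachers."""
    if teacher.config.model_type == "t5":
        # T5 stores its layers in T5Stack.block; its config has no encoder_layers field.
        return len(teacher.get_encoder().block), len(teacher.get_decoder().block)
    # Bart, Pegasus, mBART, etc. expose the counts directly on the config.
    return teacher.config.encoder_layers, teacher.config.decoder_layers


# Example: t5-small should report (6, 6).
# teacher = AutoModelForSeq2SeqLM.from_pretrained("t5-small")
# print(teacher_layer_counts(teacher))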
@@ -59,17 +69,17 @@ def __init__(self, hparams):
             del self.teacher.encoder
         # Intermediate supervision: Decide which layers to supervise
         if hparams.supervise_forward:
-            self.d_matches = get_layers_to_supervise(
-                n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
-            )
-        else:
+            self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
+            self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
+        else:  # student layer should emulate hidden states of the teacher layer it was copied from
+            self.e_matches = self.e_layer_ids
             self.d_matches = self.d_layer_ids
+
         self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
         self.temperature = 2.0
         self.alpha_mlm = hparams.alpha_mlm
         self.alpha_ce = hparams.alpha_ce
         self.alpha_hid = hparams.alpha_hid
-        self.alpha_encoder_loss = hparams.alpha_encoder_loss
         gc.collect()
         torch.cuda.empty_cache()
 
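
Note (editorial, not part of the diff): get_layers_to_supervise comes from the example's utils and its exact mapping is not shown in this commit. Purely to illustrate the idea behind hparams.supervise_forward (supervise a spread of teacher layers rather than only the copied ones), here is a hypothetical stand-in; it is not the repository's implementation:

def evenly_spaced_matches(n_student: int, n_teacher: int) -> list:
    """Hypothetical layer-matching heuristic: pick n_student teacher layers,
    roughly evenly spaced, ending at the last teacher layer."""
    assert n_student <= n_teacher, "student cannot be deeper than the teacher"
    step = n_teacher / n_student
    return [round((i + 1) * step) - 1 for i in range(n_student)]


# e.g. a 3-layer student against a 12-layer teacher supervises teacher layers [3, 7, 11]:
# print(evenly_spaced_matches(3, 12))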
@@ -129,7 +139,7 @@ def _step(self, batch):
             output_hidden_states=True,
             output_attentions=False,
             use_cache=False,
-        )  # TODO(@sshleifer): return_dict=True cleanup
+        )
 
         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
@@ -146,30 +156,32 @@ def zero_tensor():
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)
 
-        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
-        if self.different_encoder:
+        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
+        if self.different_encoder:  # compute encoder hidden state loss
             with torch.no_grad():
-                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
-                    input_ids, attention_mask=src_mask, output_hidden_states=True
-                )
-            # DEPRECATE THIS
-            if self.hparams.alpha_encoder_loss > 0:
-                loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
-
-            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
-
-        teacher_enc_outputs = (enc_outputs,)
-        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
+                teacher_enc_hid = self.teacher.get_encoder()(
+                    input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
+                ).hidden_states
+
+            hid_loss_enc = self.calc_hidden_loss(
+                src_mask,
+                enc_hidden_state,
+                teacher_enc_hid,
+                self.e_matches,
+                normalize_hidden=self.hparams.normalize_hidden,
+            )
 
         with torch.no_grad():
-            tloss, tlogits, tdec_hidden, _ = self.teacher(
+            outputs = self.teacher(
                 input_ids,
                 attention_mask=src_mask,
-                encoder_outputs=teacher_enc_outputs,
+                encoder_outputs=(enc_outputs,),
                 decoder_input_ids=decoder_input_ids,
                 lm_labels=labels,
                 output_hidden_states=True,
+                return_dict=True,
             )
+            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
         dec_mask = decoder_input_ids.ne(pad_token_id)
         loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
@@ -180,10 +192,9 @@ def zero_tensor():
         blended_loss = (
             self.alpha_ce * loss_ce
             + self.alpha_mlm * student_lm_loss
-            + self.hparams.alpha_encoder_loss * loss_encoder
             + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
         )
-        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
+        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
 
     @staticmethod
     def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
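
Note (editorial, not part of the diff): calc_hidden_loss itself is untouched by this commit, so only its signature appears above. As a hedged approximation of what a loss with that signature typically computes, not the repository's code, the sketch below takes a masked MSE between each student hidden state and the teacher hidden state selected by matches, optionally layer-normalizing both; details such as embedding-layer offsets are glossed over:

import torch
import torch.nn.functional as F


def masked_hidden_mse(attention_mask, student_hid, teacher_hid, matches, normalize_hidden):
    """Approximate sketch of intermediate supervision: MSE between student layer i
    and teacher layer matches[i], averaged over non-padding positions."""
    mask = attention_mask.to(student_hid[0])          # (batch, seq_len), cast to hidden dtype
    valid = mask.sum() * student_hid[0].size(-1)      # number of unmasked scalars per layer
    s = torch.stack([student_hid[i] for i in range(len(matches))])
    t = torch.stack([teacher_hid[j] for j in matches])
    if normalize_hidden:
        s = F.layer_norm(s, s.shape[1:])
        t = F.layer_norm(t, t.shape[1:])
    mse = F.mse_loss(s, t, reduction="none")          # (n_matches, batch, seq, hidden)
    return (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid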
@@ -207,7 +218,6 @@ def add_distill_args(parser):
     parser.add_argument("--teacher", type=str)
     parser.add_argument("--alpha_ce", default=0.8, type=float)
     parser.add_argument("--alpha_mlm", default=0.2, type=float)
-    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
     parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
     parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
     parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)

examples/seq2seq/test_seq2seq_examples.py

Lines changed: 0 additions & 3 deletions
@@ -86,7 +86,6 @@
     "n_val": -1,
     "n_test": -1,
     "student_encoder_layers": 1,
-    "alpha_encoder_loss": 0.0,
     "freeze_encoder": False,
     "auto_scale_batch_size": False,
 }
@@ -230,7 +229,6 @@ def test_distill_mbart(self):
 
         evaluate_checkpoint(ckpts[0], dest_dir=Path(tempfile.mkdtemp()))
 
-    @unittest.skip("T5 distillation is broken at the moment")
     def test_distill_t5(self):
         updates = dict(
             student_encoder_layers=1,
@@ -255,7 +253,6 @@ def _test_distiller_cli(self, updates, check_contents=True):
             model_name_or_path="sshleifer/tinier_bart",
             teacher=CHEAP_ARGS["model_name_or_path"],
             val_check_interval=0.5,
-            alpha_encoder_loss=0.4,
         )
         default_updates.update(updates)
         args_d: dict = CHEAP_ARGS.copy()
