 class BartSummarizationDistiller(SummarizationModule):
     """Supports Bart, Pegasus and other models that inherit from Bart."""
 
-    loss_names = ["loss", "ce_loss", "mlm_loss", "enc_mse_loss", "hid_loss_enc", "hid_loss_dec"]
+    loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
 
     def __init__(self, hparams):
         assert Path(hparams.data_dir).exists()
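Note on the loss_names change: dropping "enc_mse_loss" leaves five names, which line up one-for-one with the five tensors _step returns after this change (see the @@ -180 hunk further down). A minimal illustration of that pairing, assuming the usual zip-based logging; this snippet is not part of the diff:

```python
# Illustration only (not part of this diff): the five loss names map onto the
# 5-tuple that _step returns after this change.
loss_names = ["loss", "ce_loss", "mlm_loss", "hid_loss_enc", "hid_loss_dec"]
loss_tensors = distiller._step(batch)  # assumed: returns the 5-tuple shown later in this diff
logs = {name: loss.item() for name, loss in zip(loss_names, loss_tensors)}
```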
@@ -46,9 +46,19 @@ def __init__(self, hparams):
         if hparams.length_penalty != -1:
             student.config.length_penalty = hparams.length_penalty
         super().__init__(hparams, model=student, config=student.config)
+        model_type = student.config.model_type
         self.e_layer_ids, self.d_layer_ids = e_layer_ids, d_layer_ids  # type: List[int], List[int]
-        self.different_encoder = hparams.student_encoder_layers != teacher.config.encoder_layers
-        self.different_decoder = hparams.student_decoder_layers != teacher.config.decoder_layers
+
+        if model_type == "t5":
+            teacher_encoder_layers = len(teacher.get_encoder().block)
+            teacher_decoder_layers = len(teacher.get_decoder().block)
+        else:
+            teacher_encoder_layers = teacher.config.encoder_layers
+            teacher_decoder_layers = teacher.config.decoder_layers
+
+        self.different_encoder = hparams.student_encoder_layers != teacher_encoder_layers
+        self.different_decoder = hparams.student_decoder_layers != teacher_decoder_layers
+
         self.teacher = teacher
         freeze_params(self.teacher)
 
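The new branch exists because T5 configs do not expose encoder_layers/decoder_layers, so the teacher's depth is read off its module lists instead. The same idea as a standalone helper, for illustration only (this function does not exist in the PR):

```python
# Sketch (not in this PR): derive teacher layer counts per architecture.
def teacher_layer_counts(teacher):
    if teacher.config.model_type == "t5":
        # T5 stores its layers in .block lists rather than config.encoder_layers.
        return len(teacher.get_encoder().block), len(teacher.get_decoder().block)
    # BART/Pegasus-style configs expose the counts directly.
    return teacher.config.encoder_layers, teacher.config.decoder_layers
```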
@@ -59,17 +69,17 @@ def __init__(self, hparams):
             del self.teacher.encoder
         # Intermediate supervision: Decide which layers to supervise
         if hparams.supervise_forward:
-            self.d_matches = get_layers_to_supervise(
-                n_student=len(self.d_layer_ids), n_teacher=self.teacher.config.decoder_layers
-            )
-        else:
+            self.e_matches = get_layers_to_supervise(n_student=len(self.e_layer_ids), n_teacher=teacher_encoder_layers)
+            self.d_matches = get_layers_to_supervise(n_student=len(self.d_layer_ids), n_teacher=teacher_decoder_layers)
+        else:  # student layer should emulate hidden states of the teacher layer it was copied from
+            self.e_matches = self.e_layer_ids
             self.d_matches = self.d_layer_ids
+
         self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
         self.temperature = 2.0
         self.alpha_mlm = hparams.alpha_mlm
         self.alpha_ce = hparams.alpha_ce
         self.alpha_hid = hparams.alpha_hid
-        self.alpha_encoder_loss = hparams.alpha_encoder_loss
         gc.collect()
         torch.cuda.empty_cache()
 
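With --supervise_forward, both e_matches and d_matches come from get_layers_to_supervise; otherwise each student layer is matched to the teacher layer it was copied from. One plausible even-spacing scheme is sketched below; it is illustrative only, and the repo's get_layers_to_supervise helper may pick layers differently:

```python
# Illustrative layer matching: spread n_student supervision targets across
# n_teacher layers, always ending on the teacher's last layer.
def evenly_spaced_matches(n_student: int, n_teacher: int) -> list:
    step = n_teacher / n_student
    return [round((i + 1) * step) - 1 for i in range(n_student)]

# evenly_spaced_matches(3, 12) -> [3, 7, 11]
```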
@@ -129,7 +139,7 @@ def _step(self, batch):
             output_hidden_states=True,
             output_attentions=False,
             use_cache=False,
-        )  # TODO(@sshleifer): return_dict=True cleanup
+        )
 
         # Same cross entropy vs. label smoothing logic as finetune.py
         assert lm_logits.shape[-1] == self.model.config.vocab_size
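For context on the "cross entropy vs. label smoothing" comment, the student LM loss is computed the same way as in finetune.py. Roughly as below; this is a sketch, and names such as hparams.label_smoothing and the label_smoothed_nll_loss helper are assumed from the surrounding codebase rather than shown in this diff:

```python
# Sketch of the label-smoothing switch referenced above (assumed helper names).
import torch.nn.functional as F

if self.hparams.label_smoothing == 0:
    student_lm_loss = F.cross_entropy(
        lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1), ignore_index=pad_token_id
    )
else:
    lprobs = F.log_softmax(lm_logits, dim=-1)
    student_lm_loss, _ = label_smoothed_nll_loss(
        lprobs, labels, self.hparams.label_smoothing, ignore_index=pad_token_id
    )
```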
@@ -146,30 +156,32 @@ def _step(self, batch):
         def zero_tensor():
             return torch.tensor(0.0).type_as(student_lm_loss)
 
-        loss_encoder, hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor(), zero_tensor()
-        if self.different_encoder:
+        hid_loss_enc, hid_loss_dec = zero_tensor(), zero_tensor()
+        if self.different_encoder:  # compute encoder hidden state loss
             with torch.no_grad():
-                teacher_enc_outputs, teacher_enc_hid, _ = self.teacher.get_encoder()(
-                    input_ids, attention_mask=src_mask, output_hidden_states=True
-                )
-            # DEPRECATE THIS
-            if self.hparams.alpha_encoder_loss > 0:
-                loss_encoder = self.calc_mse_loss(enc_outputs, teacher_enc_outputs, src_mask)
-
-            hid_loss_enc = self.calc_hidden_loss(src_mask, enc_hidden_state, teacher_enc_hid, self.e_layer_ids)
-
-        teacher_enc_outputs = (enc_outputs,)
-        assert isinstance(teacher_enc_outputs, tuple), type(teacher_enc_outputs)
+                teacher_enc_hid = self.teacher.get_encoder()(
+                    input_ids, attention_mask=src_mask, output_hidden_states=True, return_dict=True
+                ).hidden_states
+
+            hid_loss_enc = self.calc_hidden_loss(
+                src_mask,
+                enc_hidden_state,
+                teacher_enc_hid,
+                self.e_matches,
+                normalize_hidden=self.hparams.normalize_hidden,
+            )
 
         with torch.no_grad():
-            tloss, tlogits, tdec_hidden, _ = self.teacher(
+            outputs = self.teacher(
                 input_ids,
                 attention_mask=src_mask,
-                encoder_outputs=teacher_enc_outputs,
+                encoder_outputs=(enc_outputs,),
                 decoder_input_ids=decoder_input_ids,
                 lm_labels=labels,
                 output_hidden_states=True,
+                return_dict=True,
             )
+            tlogits, tdec_hidden = outputs.logits, outputs.decoder_hidden_states
         dec_mask = decoder_input_ids.ne(pad_token_id)
         loss_ce = self.calc_ce_loss(dec_mask, lm_logits, tlogits)
         if self.alpha_hid > 0:  # Intermediate supervision of decoder hidden states
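calc_ce_loss itself is unchanged by this diff; for reference, a masked, temperature-scaled KL distillation loss of the kind this class configures (nn.KLDivLoss(reduction="batchmean"), temperature=2.0) typically looks like the sketch below. It is illustrative, not an excerpt from the repo:

```python
# Sketch of a masked, temperature-scaled KL distillation loss (illustrative).
import torch
import torch.nn as nn
import torch.nn.functional as F

def kl_ce_loss(dec_mask, student_logits, teacher_logits, temperature=2.0):
    # Compare distributions only at non-padding decoder positions.
    sel = dec_mask.unsqueeze(-1).expand_as(student_logits).bool()
    s = torch.masked_select(student_logits, sel).view(-1, student_logits.size(-1))
    t = torch.masked_select(teacher_logits, sel).view(-1, teacher_logits.size(-1))
    loss_fct = nn.KLDivLoss(reduction="batchmean")
    return loss_fct(
        F.log_softmax(s / temperature, dim=-1),
        F.softmax(t / temperature, dim=-1),
    ) * temperature ** 2
```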
@@ -180,10 +192,9 @@ def zero_tensor():
         blended_loss = (
             self.alpha_ce * loss_ce
             + self.alpha_mlm * student_lm_loss
-            + self.hparams.alpha_encoder_loss * loss_encoder
             + self.hparams.alpha_hid * (hid_loss_enc + hid_loss_dec)
         )
-        return blended_loss, loss_ce, student_lm_loss, loss_encoder, hid_loss_enc, hid_loss_dec
+        return blended_loss, loss_ce, student_lm_loss, hid_loss_enc, hid_loss_dec
 
     @staticmethod
     def calc_hidden_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
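The calc_hidden_loss signature now takes matches and normalize_hidden, but its body is not shown in this hunk. A sketch of what such a masked, optionally normalized hidden-state matching loss can look like; this is an assumption for illustration, not the repo's exact code:

```python
# Sketch of a masked hidden-state matching loss with optional normalization.
import torch
import torch.nn.functional as F

def hidden_state_loss(attention_mask, hidden_states, hidden_states_T, matches, normalize_hidden):
    mask = attention_mask.to(hidden_states[0])                    # (batch, seq_len)
    valid_count = mask.sum() * hidden_states[0].size(-1)
    student = torch.stack([hidden_states[i] for i in range(len(matches))])
    teacher = torch.stack([hidden_states_T[j] for j in matches])
    if normalize_hidden:                                          # compare shapes, not scales
        student = F.layer_norm(student, student.shape[1:])
        teacher = F.layer_norm(teacher, teacher.shape[1:])
    mse = F.mse_loss(student, teacher, reduction="none")
    return (mse * mask.unsqueeze(0).unsqueeze(-1)).sum() / valid_count
```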
@@ -207,7 +218,6 @@ def add_distill_args(parser):
     parser.add_argument("--teacher", type=str)
     parser.add_argument("--alpha_ce", default=0.8, type=float)
    parser.add_argument("--alpha_mlm", default=0.2, type=float)
-    parser.add_argument("--alpha_encoder_loss", default=0.0, type=float)
     parser.add_argument("--alpha_hid", default=0.0, type=float, required=False)
     parser.add_argument("--student_decoder_layers", default=12, type=int, required=False)
     parser.add_argument("--student_encoder_layers", default=12, type=int, required=False)