@@ -113,6 +113,10 @@ def __init__(self, hparams, **kwargs):
113
113
self .eval_max_length = self .hparams .eval_max_gen_length
114
114
else :
115
115
self .eval_max_length = self .model .config .max_length
116
+ if self .hparams .eval_min_gen_length is not None :
117
+ self .eval_min_length = self .hparams .eval_min_gen_length
118
+ else :
119
+ self .eval_min_length = self .model .config .min_length
116
120
self .val_metric = self .default_val_metric if self .hparams .val_metric is None else self .hparams .val_metric
117
121
118
122
def save_readable_batch (self , batch : Dict [str , torch .Tensor ]) -> Dict [str , List [str ]]:
@@ -219,6 +223,7 @@ def _generative_step(self, batch: dict) -> dict:
219
223
decoder_start_token_id = self .decoder_start_token_id ,
220
224
num_beams = self .eval_beams ,
221
225
max_length = self .eval_max_length ,
226
+ min_length = self .eval_min_length ,
222
227
)
223
228
gen_time = (time .time () - t0 ) / batch ["input_ids" ].shape [0 ]
224
229
preds : List [str ] = self .ids_to_clean_text (generated_ids )
@@ -346,6 +351,7 @@ def add_model_specific_args(parser, root_dir):
346
351
"--val_metric" , type = str , default = None , required = False , choices = ["bleu" , "rouge2" , "loss" , None ]
347
352
)
348
353
parser .add_argument ("--eval_max_gen_length" , type = int , default = None , help = "never generate more than n tokens" )
354
+ parser .add_argument ("--eval_min_gen_length" , type = int , default = None , help = "never generate shorter than n tokens" )
349
355
parser .add_argument ("--save_top_k" , type = int , default = 1 , required = False , help = "How many checkpoints to save" )
350
356
parser .add_argument (
351
357
"--early_stopping_patience" ,
0 commit comments