Skip to content

Commit 5aa361f

Browse files
Author: Daniel Khashabi
finetune.py: specifying generation min_length (huggingface#8478)
1 parent 30e7f7e commit 5aa361f

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

examples/seq2seq/finetune.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -113,6 +113,10 @@ def __init__(self, hparams, **kwargs):
113113
self.eval_max_length = self.hparams.eval_max_gen_length
114114
else:
115115
self.eval_max_length = self.model.config.max_length
116+
if self.hparams.eval_min_gen_length is not None:
117+
self.eval_min_length = self.hparams.eval_min_gen_length
118+
else:
119+
self.eval_min_length = self.model.config.min_length
116120
self.val_metric = self.default_val_metric if self.hparams.val_metric is None else self.hparams.val_metric
117121

118122
def save_readable_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, List[str]]:
@@ -219,6 +223,7 @@ def _generative_step(self, batch: dict) -> dict:
219223
decoder_start_token_id=self.decoder_start_token_id,
220224
num_beams=self.eval_beams,
221225
max_length=self.eval_max_length,
226+
min_length=self.eval_min_length,
222227
)
223228
gen_time = (time.time() - t0) / batch["input_ids"].shape[0]
224229
preds: List[str] = self.ids_to_clean_text(generated_ids)
@@ -346,6 +351,7 @@ def add_model_specific_args(parser, root_dir):
346351
"--val_metric", type=str, default=None, required=False, choices=["bleu", "rouge2", "loss", None]
347352
)
348353
parser.add_argument("--eval_max_gen_length", type=int, default=None, help="never generate more than n tokens")
354+
parser.add_argument("--eval_min_gen_length", type=int, default=None, help="never generate shorter than n tokens")
349355
parser.add_argument("--save_top_k", type=int, default=1, required=False, help="How many checkpoints to save")
350356
parser.add_argument(
351357
"--early_stopping_patience",

0 commit comments

Comments (0)