Commit c76de10

Add generate kwargs to Seq2SeqTrainingArguments (huggingface#13339)
* Add generate kwargs to Seq2SeqTrainingArguments
* typo
* Address review comments + doc
* Style
1 parent 702f4a4 commit c76de10

File tree

4 files changed, +41 -23 lines changed


examples/pytorch/summarization/run_summarization.py

Lines changed: 8 additions & 8 deletions
@@ -556,12 +556,15 @@ def compute_metrics(eval_preds):
 
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
-
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 
@@ -572,10 +575,7 @@ def compute_metrics(eval_preds):
         logger.info("*** Predict ***")
 
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (
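
The script now resolves the generation settings once, before the do_eval/do_predict blocks. Note the asymmetry: for max_length the new training argument takes precedence over --val_max_target_length, while for num_beams the existing --num_beams data argument still wins and the new training argument is only a fallback. A minimal standalone sketch of that resolution, with hypothetical values standing in for the parsed training_args/data_args:

# Hypothetical values in place of the parsed training_args / data_args.
generation_max_length = 200   # training_args.generation_max_length (None if not passed)
val_max_target_length = 128   # data_args.val_max_target_length
num_beams = None              # data_args.num_beams (None if not passed)
generation_num_beams = 4      # training_args.generation_num_beams

# Same ternaries as in the updated script: the new max_length argument wins,
# while the existing num_beams data argument keeps precedence over the new fallback.
max_length = generation_max_length if generation_max_length is not None else val_max_target_length
num_beams = num_beams if num_beams is not None else generation_num_beams
print(max_length, num_beams)  # 200 4

The translation example below applies the identical change.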

examples/pytorch/translation/run_translation.py

Lines changed: 8 additions & 7 deletions
@@ -549,12 +549,16 @@ def compute_metrics(eval_preds):
 
     # Evaluation
     results = {}
+    max_length = (
+        training_args.generation_max_length
+        if training_args.generation_max_length is not None
+        else data_args.val_max_target_length
+    )
+    num_beams = data_args.num_beams if data_args.num_beams is not None else training_args.generation_num_beams
     if training_args.do_eval:
         logger.info("*** Evaluate ***")
 
-        metrics = trainer.evaluate(
-            max_length=data_args.val_max_target_length, num_beams=data_args.num_beams, metric_key_prefix="eval"
-        )
+        metrics = trainer.evaluate(max_length=max_length, num_beams=num_beams, metric_key_prefix="eval")
         max_eval_samples = data_args.max_eval_samples if data_args.max_eval_samples is not None else len(eval_dataset)
         metrics["eval_samples"] = min(max_eval_samples, len(eval_dataset))
 
@@ -565,10 +569,7 @@ def compute_metrics(eval_preds):
         logger.info("*** Predict ***")
 
         predict_results = trainer.predict(
-            predict_dataset,
-            metric_key_prefix="predict",
-            max_length=data_args.val_max_target_length,
-            num_beams=data_args.num_beams,
+            predict_dataset, metric_key_prefix="predict", max_length=max_length, num_beams=num_beams
         )
         metrics = predict_results.metrics
         max_predict_samples = (

src/transformers/trainer_seq2seq.py

Lines changed: 4 additions & 8 deletions
@@ -70,10 +70,8 @@ def evaluate(
             A dictionary containing the evaluation loss and the potential metrics computed from the predictions. The
             dictionary also contains the epoch number which comes from the training state.
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
         return super().evaluate(eval_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 
     def predict(
@@ -119,10 +117,8 @@ def predict(
             - metrics (:obj:`Dict[str, float]`, `optional`): The potential dictionary of metrics (if the dataset
               contained labels).
         """
-        if max_length is not None or not hasattr(self, "_max_length"):
-            self._max_length = max_length
-        if num_beams is not None or not hasattr(self, "_num_beams"):
-            self._num_beams = num_beams
+        self._max_length = max_length if max_length is not None else self.args.generation_max_length
+        self._num_beams = num_beams if num_beams is not None else self.args.generation_num_beams
         return super().predict(test_dataset, ignore_keys=ignore_keys, metric_key_prefix=metric_key_prefix)
 
     def prediction_step(
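
With this change, Seq2SeqTrainer.evaluate and Seq2SeqTrainer.predict no longer cache the last explicitly passed values: on every call an explicit max_length/num_beams argument wins, and otherwise the values from Seq2SeqTrainingArguments are used. A minimal sketch, assuming a model, eval_dataset, and tokenizer have already been prepared (those three names are placeholders, not part of this commit):

from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

args = Seq2SeqTrainingArguments(
    output_dir="out",           # placeholder output directory
    predict_with_generate=True,
    generation_max_length=128,  # fallback max_length for evaluate()/predict()
    generation_num_beams=4,     # fallback num_beams for evaluate()/predict()
)
trainer = Seq2SeqTrainer(model=model, args=args, eval_dataset=eval_dataset, tokenizer=tokenizer)

trainer.evaluate()                            # generates with max_length=128, num_beams=4
trainer.evaluate(max_length=64, num_beams=1)  # explicit kwargs still take precedence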

src/transformers/training_args_seq2seq.py

Lines changed: 21 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 import logging
 from dataclasses import dataclass, field
+from typing import Optional
 
 from .file_utils import add_start_docstrings
 from .training_args import TrainingArguments
@@ -34,9 +35,29 @@ class Seq2SeqTrainingArguments(TrainingArguments):
             the training set.
         predict_with_generate (:obj:`bool`, `optional`, defaults to :obj:`False`):
            Whether to use generate to calculate generative metrics (ROUGE, BLEU).
+        generation_max_length (:obj:`int`, `optional`):
+            The :obj:`max_length` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to
+            the :obj:`max_length` value of the model configuration.
+        generation_num_beams (:obj:`int`, `optional`):
+            The :obj:`num_beams` to use on each evaluation loop when :obj:`predict_with_generate=True`. Will default to the
+            :obj:`num_beams` value of the model configuration.
     """
 
     sortish_sampler: bool = field(default=False, metadata={"help": "Whether to use SortishSampler or not."})
     predict_with_generate: bool = field(
         default=False, metadata={"help": "Whether to use generate to calculate generative metrics (ROUGE, BLEU)."}
     )
+    generation_max_length: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `max_length` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `max_length` value of the model configuration."
+        },
+    )
+    generation_num_beams: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "The `num_beams` to use on each evaluation loop when `predict_with_generate=True`. Will default "
+            "to the `num_beams` value of the model configuration."
+        },
+    )
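
Because the example scripts build their arguments with HfArgumentParser, the two new dataclass fields surface as --generation_max_length and --generation_num_beams command-line flags (assuming the standard field-to-flag mapping; the flags themselves are not spelled out in this commit). A quick check of the parsing:

from transformers import HfArgumentParser, Seq2SeqTrainingArguments

parser = HfArgumentParser(Seq2SeqTrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses(
    ["--output_dir", "out", "--generation_max_length", "128", "--generation_num_beams", "4"]
)
print(training_args.generation_max_length, training_args.generation_num_beams)  # 128 4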
