Commit 94d8767

Loading from last checkpoint functionality in Trainer.train (huggingface#10334)
Enhance the resume_from_checkpoint argument of Trainer.train to also accept a bool. If True is given, the last saved checkpoint in self.args.output_dir is loaded. (huggingface#10280)
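For orientation, a minimal usage sketch of the new argument (the output_dir value, model, and dataset below are illustrative placeholders, not part of this commit):

from transformers import Trainer, TrainingArguments

# Illustrative setup: `model` and `train_dataset` are assumed to be defined
# elsewhere (any PreTrainedModel and torch Dataset will do).
training_args = TrainingArguments(output_dir="my_output_dir", save_steps=500)
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

# Old behaviour (still supported): resume from an explicit checkpoint path.
trainer.train(resume_from_checkpoint="my_output_dir/checkpoint-500")

# New behaviour from this commit: True means "resume from the last checkpoint
# saved in args.output_dir"; a ValueError is raised if none is found.
trainer.train(resume_from_checkpoint=True)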
1 parent eab0afc commit 94d8767

File tree

1 file changed: +12 -4 lines changed

src/transformers/trainer.py

Lines changed: 12 additions & 4 deletions
@@ -97,6 +97,7 @@
     TrainOutput,
     default_compute_objective,
     default_hp_space,
+    get_last_checkpoint,
     set_seed,
     speed_metrics,
 )
@@ -758,17 +759,19 @@ def _wrap_model(self, model, training=True):
 
     def train(
         self,
-        resume_from_checkpoint: Optional[str] = None,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
         trial: Union["optuna.Trial", Dict[str, Any]] = None,
         **kwargs,
     ):
         """
         Main training entry point.
 
         Args:
-            resume_from_checkpoint (:obj:`str`, `optional`):
-                Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
-                present, training will resume from the model/optimizer/scheduler states loaded here.
+            resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
+                If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
+                :class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
+                `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
+                training will resume from the model/optimizer/scheduler states loaded here.
             trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                 The trial run or the hyperparameter dictionary for hyperparameter search.
             kwargs:
@@ -803,6 +806,11 @@ def train(
         self.optimizer, self.lr_scheduler = None, None
 
         # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
+
         if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
             logger.info(f"Loading model from {resume_from_checkpoint}).")
             if isinstance(self.model, PreTrainedModel):
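The new code path relies on get_last_checkpoint, imported above from the trainer utilities. As a rough, standalone sketch of what such a helper does (an approximation for illustration, not the library's actual implementation): scan output_dir for checkpoint-<step> subfolders and return the one with the highest step, or None if there are none.

import os
import re
from typing import Optional

_CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

def last_checkpoint(output_dir: str) -> Optional[str]:
    # Return the path of the highest-numbered checkpoint-<step> folder, or None.
    if not os.path.isdir(output_dir):
        return None
    candidates = [
        name for name in os.listdir(output_dir)
        if _CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
    ]
    if not candidates:
        return None
    latest = max(candidates, key=lambda name: int(_CHECKPOINT_RE.match(name).group(1)))
    return os.path.join(output_dir, latest)

# e.g. last_checkpoint("my_output_dir") might return "my_output_dir/checkpoint-1500",
# which is the path Trainer.train resumes from when resume_from_checkpoint=True.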
