     TrainOutput,
     default_compute_objective,
     default_hp_space,
+    get_last_checkpoint,
     set_seed,
     speed_metrics,
 )
@@ -758,17 +759,19 @@ def _wrap_model(self, model, training=True):
 
     def train(
         self,
-        resume_from_checkpoint: Optional[str] = None,
+        resume_from_checkpoint: Optional[Union[str, bool]] = None,
         trial: Union["optuna.Trial", Dict[str, Any]] = None,
         **kwargs,
     ):
         """
         Main training entry point.
 
         Args:
-            resume_from_checkpoint (:obj:`str`, `optional`):
-                Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
-                present, training will resume from the model/optimizer/scheduler states loaded here.
+            resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
+                If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
+                :class:`~transformers.Trainer`. If a :obj:`bool` and set to :obj:`True`, load the last checkpoint in
+                `args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
+                training will resume from the model/optimizer/scheduler states loaded here.
             trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                 The trial run or the hyperparameter dictionary for hyperparameter search.
             kwargs:
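Taken together, the signature and docstring changes mean `resume_from_checkpoint` now accepts either an explicit checkpoint path or `True`. A minimal usage sketch, assuming a `Trainer` has already been constructed (the `model`, `training_args`, and `train_dataset` names below are placeholders, and `output/checkpoint-500` is a hypothetical path):

```python
from transformers import Trainer

# Placeholder setup: `model`, `training_args` (a TrainingArguments instance),
# and `train_dataset` are assumed to exist.
trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

# Resume from an explicit checkpoint directory (hypothetical path)...
trainer.train(resume_from_checkpoint="output/checkpoint-500")

# ...or, with this change, let the Trainer find the last checkpoint
# saved in args.output_dir.
trainer.train(resume_from_checkpoint=True)
```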
@@ -803,6 +806,11 @@ def train(
         self.optimizer, self.lr_scheduler = None, None
 
         # Load potential model checkpoint
+        if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
+            resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
+            if resume_from_checkpoint is None:
+                raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
+
         if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
             logger.info(f"Loading model from {resume_from_checkpoint}.")
             if isinstance(self.model, PreTrainedModel):
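The newly imported `get_last_checkpoint` helper is what makes `resume_from_checkpoint=True` work in the hunk above. As a rough sketch of the behavior it relies on: the Trainer saves checkpoints as `checkpoint-<global_step>` subdirectories of the output directory, so the "last checkpoint" is the subdirectory with the highest step number, or `None` when none exists. The function below is an illustrative re-implementation under that assumption, not the library source:

```python
import os
import re

# Trainer checkpoints follow the "checkpoint-<global_step>" naming convention.
_re_checkpoint = re.compile(r"^checkpoint-(\d+)$")

def get_last_checkpoint_sketch(folder: str):
    """Illustrative stand-in for transformers' get_last_checkpoint helper."""
    checkpoints = [
        name
        for name in os.listdir(folder)
        if _re_checkpoint.match(name) and os.path.isdir(os.path.join(folder, name))
    ]
    if not checkpoints:
        return None  # train() raises ValueError in this case (see hunk above)
    # Pick the checkpoint with the highest global step.
    latest = max(checkpoints, key=lambda name: int(_re_checkpoint.match(name).group(1)))
    return os.path.join(folder, latest)
```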