
Commit da5ef25

Push to hub save (huggingface#15327)
* Adapt doc and push at every save
* style
1 parent 9f831bd commit da5ef25

File tree: 2 files changed, 21 additions (+) and 7 deletions (−)

src/transformers/trainer.py

Lines changed: 9 additions & 4 deletions
@@ -966,7 +966,7 @@ def _tune_save_checkpoint(self):
             return
         with tune.checkpoint_dir(step=self.state.global_step) as checkpoint_dir:
             output_dir = os.path.join(checkpoint_dir, f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}")
-            self.save_model(output_dir)
+            self.save_model(output_dir, _internal_call=True)
             if self.args.should_save:
                 self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))
                 torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

@@ -1634,7 +1634,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
         self.store_flos()

         output_dir = os.path.join(run_dir, checkpoint_folder)
-        self.save_model(output_dir)
+        self.save_model(output_dir, _internal_call=True)
         if self.deepspeed:
             # under zero3 model file itself doesn't get saved since it's bogus! Unless deepspeed
             # config `stage3_gather_fp16_weights_on_model_save` is True

@@ -2002,7 +2002,7 @@ def is_world_process_zero(self) -> bool:
         else:
             return self.args.process_index == 0

-    def save_model(self, output_dir: Optional[str] = None):
+    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
         """
         Will save the model, so you can reload it using `from_pretrained()`.

@@ -2051,6 +2051,10 @@ def save_model(self, output_dir: Optional[str] = None):
         elif self.args.should_save:
             self._save(output_dir)

+        # Push to the Hub when `save_model` is called by the user.
+        if self.args.push_to_hub and not _internal_call:
+            self.push_to_hub(commit_message="Model save")
+
     def _save_tpu(self, output_dir: Optional[str] = None):
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         logger.info(f"Saving model checkpoint to {output_dir}")

@@ -2768,9 +2772,10 @@ def push_to_hub(self, commit_message: Optional[str] = "End of training", blockin
             model_name = Path(self.args.output_dir).name
         else:
             model_name = self.args.hub_model_id.split("/")[-1]
+
         # Needs to be executed on all processes for TPU training, but will only save on the process determined by
         # self.args.should_save.
-        self.save_model()
+        self.save_model(_internal_call=True)

         # Only push from one node.
         if not self.is_world_process_zero():
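
Taken together, the change means internal save paths (checkpointing, `push_to_hub` itself) pass `_internal_call=True` and no longer trigger an upload, while a user-initiated `save_model` now pushes when `push_to_hub` is enabled. A minimal sketch of the resulting behavior, assuming you are authenticated with the Hub; the checkpoint name and output directory below are placeholders, not part of the commit:

    from transformers import AutoModelForSequenceClassification, Trainer, TrainingArguments

    # Placeholder names; assumes `huggingface-cli login` has been run.
    args = TrainingArguments(output_dir="my-model", push_to_hub=True)
    model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased")
    trainer = Trainer(model=model, args=args)

    # User-initiated save: saves locally *and* pushes to the Hub with the
    # commit message "Model save" (the new branch in `save_model` above).
    trainer.save_model()

    # Internal callers (e.g. `_save_checkpoint`) now call
    # `save_model(output_dir, _internal_call=True)` and skip the push.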

src/transformers/training_args.py

Lines changed: 12 additions & 3 deletions
@@ -365,9 +365,18 @@ class TrainingArguments:
             Whether to skip adding of memory profiler reports to metrics. This is skipped by default because it slows
             down the training and evaluation speed.
         push_to_hub (`bool`, *optional*, defaults to `False`):
-            Whether or not to upload the trained model to the hub after training. If this is activated, and
-            `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
+            Whether or not to push the model to the Hub every time the model is saved. If this is activated,
+            `output_dir` will become a git directory synced with the repo (determined by `hub_model_id`) and the
+            content will be pushed each time a save is triggered (depending on your `save_strategy`). Calling
+            [`~Trainer.save_model`] will also trigger a push.
+
+            <Tip warning={true}>
+
+            If `output_dir` exists, it needs to be a local clone of the repository to which the [`Trainer`] will be
             pushed.
+
+            </Tip>
+
         resume_from_checkpoint (`str`, *optional*):
             The path to a folder with a valid checkpoint for your model. This argument is not directly used by
             [`Trainer`], it's intended to be used by your training/evaluation scripts instead. See the [example

@@ -384,7 +393,7 @@ class TrainingArguments:
             Defines the scope of what is pushed to the Hub and when. Possible values are:

             - `"end"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and a
-              draft of a model card at the end of training.
+              draft of a model card when the [`~Trainer.save_model`] method is called.
             - `"every_save"`: push the model, its configuration, the tokenizer (if passed along to the [`Trainer`]) and
               a draft of a model card each time there is a model save. The pushes are asynchronous to not block
               training, and in case the saves are very frequent, a new push is only attempted if the previous one is
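
The behavior described in the updated docstring is driven by `save_strategy` together with `hub_strategy`. A configuration sketch matching those docs; `hub_model_id` and the step interval are illustrative placeholders, not values from this commit:

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="my-finetuned-model",
        push_to_hub=True,                            # sync `output_dir` with a Hub repo
        hub_model_id="username/my-finetuned-model",  # placeholder repo id
        save_strategy="steps",                       # a save is triggered every `save_steps`
        save_steps=500,
        hub_strategy="every_save",                   # push (asynchronously) on each save
    )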
