
Commit 622a8c5

stas00 and sgugger authored
[trainer] add Trainer methods for metrics logging and saving (huggingface#10266)
* make logging and saving trainer built-in
* Update src/transformers/trainer.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
1 parent 94d8767 · commit 622a8c5

File tree: 2 files changed (+40, -28 lines)


examples/seq2seq/run_seq2seq.py

Lines changed: 7 additions & 28 deletions
```diff
@@ -18,7 +18,6 @@
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.

-import json
 import logging
 import os
 import re
```
```diff
@@ -55,11 +54,6 @@
 logger = logging.getLogger(__name__)


-def save_json(content, path, indent=4, **json_dump_kwargs):
-    with open(path, "w") as f:
-        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
-
-
 @dataclass
 class ModelArguments:
     """
```
```diff
@@ -596,13 +590,8 @@ def compute_metrics(eval_preds):
         )
         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
         if trainer.is_world_process_zero():
-            metrics_formatted = trainer.metrics_format(metrics)
-            logger.info("***** train metrics *****")
-            k_width = max(len(str(x)) for x in metrics_formatted.keys())
-            v_width = max(len(str(x)) for x in metrics_formatted.values())
-            for key in sorted(metrics_formatted.keys()):
-                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-            save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
+            trainer.log_metrics("train", metrics)
+            trainer.save_metrics("train", metrics)
             all_metrics.update(metrics)

         # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
```
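The seven deleted lines above are the formatting boilerplate this commit centralizes. As a standalone illustration of what that alignment logic does (and therefore what the new `trainer.log_metrics` prints), here is a minimal sketch with hypothetical metric values, using `print` in place of `logger.info`:

```python
# Hypothetical output of trainer.metrics_format(metrics); the values are made up.
metrics_formatted = {"epoch": "3.0", "train_loss": "1.2345", "train_samples": "1000"}

# Left-align keys and right-align values to the widest entry in each column.
k_width = max(len(str(x)) for x in metrics_formatted.keys())
v_width = max(len(str(x)) for x in metrics_formatted.values())
for key in sorted(metrics_formatted.keys()):
    print(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
```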
```diff
@@ -620,13 +609,8 @@ def compute_metrics(eval_preds):
         metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))

         if trainer.is_world_process_zero():
-            metrics_formatted = trainer.metrics_format(metrics)
-            logger.info("***** val metrics *****")
-            k_width = max(len(str(x)) for x in metrics_formatted.keys())
-            v_width = max(len(str(x)) for x in metrics_formatted.values())
-            for key in sorted(metrics_formatted.keys()):
-                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-            save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
+            trainer.log_metrics("eval", metrics)
+            trainer.save_metrics("eval", metrics)
             all_metrics.update(metrics)

     if training_args.do_predict:
```
```diff
@@ -643,13 +627,8 @@ def compute_metrics(eval_preds):
         metrics["test_samples"] = min(max_test_samples, len(test_dataset))

         if trainer.is_world_process_zero():
-            metrics_formatted = trainer.metrics_format(metrics)
-            logger.info("***** test metrics *****")
-            k_width = max(len(str(x)) for x in metrics_formatted.keys())
-            v_width = max(len(str(x)) for x in metrics_formatted.values())
-            for key in sorted(metrics_formatted.keys()):
-                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-            save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
+            trainer.log_metrics("test", metrics)
+            trainer.save_metrics("test", metrics)
             all_metrics.update(metrics)

             if training_args.predict_with_generate:
```
```diff
@@ -662,7 +641,7 @@ def compute_metrics(eval_preds):
                 writer.write("\n".join(test_preds))

     if trainer.is_world_process_zero():
-        save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
+        trainer.save_metrics("all", metrics)

     return results
```

src/transformers/trainer.py

Lines changed: 33 additions & 0 deletions
```diff
@@ -19,6 +19,7 @@
 import collections
 import gc
 import inspect
+import json
 import math
 import os
 import re
```
```diff
@@ -1370,6 +1371,38 @@ def metrics_format(self, metrics: Dict[str, float]) -> Dict[str, float]:

         return metrics_copy

+    def log_metrics(self, split, metrics):
+        """
+        Log metrics in a specially formatted way.
+
+        Args:
+            split (:obj:`str`):
+                Mode/split name: one of ``train``, ``eval``, ``test``
+            metrics (:obj:`Dict[str, float]`):
+                The metrics returned from train/evaluate/predict
+        """
+
+        logger.info(f"***** {split} metrics *****")
+        metrics_formatted = self.metrics_format(metrics)
+        k_width = max(len(str(x)) for x in metrics_formatted.keys())
+        v_width = max(len(str(x)) for x in metrics_formatted.values())
+        for key in sorted(metrics_formatted.keys()):
+            logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
+
+    def save_metrics(self, split, metrics):
+        """
+        Save metrics into a json file for that split, e.g. ``train_results.json``.
+
+        Args:
+            split (:obj:`str`):
+                Mode/split name: one of ``train``, ``eval``, ``test``, ``all``
+            metrics (:obj:`Dict[str, float]`):
+                The metrics returned from train/evaluate/predict
+        """
+        path = os.path.join(self.args.output_dir, f"{split}_results.json")
+        with open(path, "w") as f:
+            json.dump(metrics, f, indent=4, sort_keys=True)
+
     def _prepare_inputs(self, inputs: Dict[str, Union[torch.Tensor, Any]]) -> Dict[str, Union[torch.Tensor, Any]]:
         """
         Prepare :obj:`inputs` before feeding them to the model, converting them to tensors if they are not already and
```
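Taken together, the two new methods let any script replace the hand-rolled logging and JSON dumping shown in the `run_seq2seq.py` diff above. A minimal usage sketch, assuming `trainer` is an already-configured `transformers.Trainer` and with hypothetical metric values:

```python
# Assumes `trainer` was built elsewhere, e.g. trainer = Trainer(model=..., args=...);
# the metric values below are hypothetical.
metrics = {"epoch": 3.0, "train_loss": 1.2345, "train_samples": 1000}

if trainer.is_world_process_zero():
    trainer.log_metrics("train", metrics)   # aligned "key = value" lines via logger.info
    trainer.save_metrics("train", metrics)  # writes <output_dir>/train_results.json
```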
