@@ -18,7 +18,6 @@
 """
 # You can also adapt this script on your own sequence to sequence task. Pointers for this are left as comments.
 
-import json
 import logging
 import os
 import re
@@ -55,11 +54,6 @@
 logger = logging.getLogger(__name__)
 
 
-def save_json(content, path, indent=4, **json_dump_kwargs):
-    with open(path, "w") as f:
-        json.dump(content, f, indent=indent, sort_keys=True, **json_dump_kwargs)
-
-
 @dataclass
 class ModelArguments:
     """
@@ -596,13 +590,8 @@ def compute_metrics(eval_preds):
         )
         metrics["train_samples"] = min(max_train_samples, len(train_dataset))
         if trainer.is_world_process_zero():
-            metrics_formatted = trainer.metrics_format(metrics)
-            logger.info("***** train metrics *****")
-            k_width = max(len(str(x)) for x in metrics_formatted.keys())
-            v_width = max(len(str(x)) for x in metrics_formatted.values())
-            for key in sorted(metrics_formatted.keys()):
-                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-            save_json(metrics, os.path.join(training_args.output_dir, "train_results.json"))
+            trainer.log_metrics("train", metrics)
+            trainer.save_metrics("train", metrics)
             all_metrics.update(metrics)
 
     # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
@@ -620,13 +609,8 @@ def compute_metrics(eval_preds):
         metrics["eval_samples"] = min(max_val_samples, len(eval_dataset))
 
         if trainer.is_world_process_zero():
-            metrics_formatted = trainer.metrics_format(metrics)
-            logger.info("***** val metrics *****")
-            k_width = max(len(str(x)) for x in metrics_formatted.keys())
-            v_width = max(len(str(x)) for x in metrics_formatted.values())
-            for key in sorted(metrics_formatted.keys()):
-                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-            save_json(metrics, os.path.join(training_args.output_dir, "eval_results.json"))
+            trainer.log_metrics("eval", metrics)
+            trainer.save_metrics("eval", metrics)
             all_metrics.update(metrics)
 
     if training_args.do_predict:
@@ -643,13 +627,8 @@ def compute_metrics(eval_preds):
         metrics["test_samples"] = min(max_test_samples, len(test_dataset))
 
         if trainer.is_world_process_zero():
-            metrics_formatted = trainer.metrics_format(metrics)
-            logger.info("***** test metrics *****")
-            k_width = max(len(str(x)) for x in metrics_formatted.keys())
-            v_width = max(len(str(x)) for x in metrics_formatted.values())
-            for key in sorted(metrics_formatted.keys()):
-                logger.info(f"  {key: <{k_width}} = {metrics_formatted[key]:>{v_width}}")
-            save_json(metrics, os.path.join(training_args.output_dir, "test_results.json"))
+            trainer.log_metrics("test", metrics)
+            trainer.save_metrics("test", metrics)
             all_metrics.update(metrics)
 
         if training_args.predict_with_generate:
@@ -662,7 +641,7 @@ def compute_metrics(eval_preds):
                 writer.write("\n".join(test_preds))
 
     if trainer.is_world_process_zero():
-        save_json(all_metrics, os.path.join(training_args.output_dir, "all_results.json"))
+        trainer.save_metrics("all", metrics)
 
     return results
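For reference, a minimal sketch (not part of the commit) of the per-split pattern the new helper calls follow, assuming the trainer instance and the metrics dict built earlier in the script; the two helpers stand in for what each removed block did by hand:

    metrics = train_result.metrics
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))
    if trainer.is_world_process_zero():
        # replaces the removed k_width/v_width loop that logged "***** train metrics *****"
        trainer.log_metrics("train", metrics)
        # replaces the removed save_json(...) call that wrote train_results.json
        trainer.save_metrics("train", metrics)
    # the same pair of calls is used for the "eval", "test", and final "all" metrics above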