
Commit 48d4827

TF model cards (huggingface#14720)
* Initial commit for Keras model cards
* Revert accidental change
* make style
* make style
* make style
* Fix PR comments
* Move repo creation to __init__
* Fixes to README.md creation
* Partial progress for proper card creation on `push_to_hub`
* Proper card creation from `push_to_hub` plus fixes for malformed model cards
* Fixes for model card creation outside the callback
* Adding a model card creation test
* Putting the model card creation test in the right file. Good job, Matt.
* make style
* Fix model card test temp dir usage
* Fix model card creation when no optimizer present
* Fixes for when training history not present
* Fix accidental edit to test_modeling_common
1 parent 72c6e8b commit 48d4827

File tree

5 files changed: +221 −7 lines changed

src/transformers/file_utils.py

Lines changed: 9 additions & 0 deletions
@@ -2335,6 +2335,7 @@ def push_to_hub(
         organization: Optional[str] = None,
         private: Optional[bool] = None,
         use_auth_token: Optional[Union[bool, str]] = None,
+        **model_card_kwargs
     ) -> str:
         """
         Upload the {object_files} to the 🤗 Model Hub while synchronizing a local clone of the repo in
@@ -2409,6 +2410,14 @@ def push_to_hub(
         )
         # Save the files in the cloned repo
         self.save_pretrained(repo_path_or_name)
+        if hasattr(self, "history") and hasattr(self, "create_model_card"):
+            # This is a Keras model and we might be able to fish out its History and make a model card out of it
+            base_model_card_args = {
+                "output_dir": repo_path_or_name,
+                "model_name": Path(repo_path_or_name).name,
+            }
+            base_model_card_args.update(model_card_kwargs)
+            self.create_model_card(**base_model_card_args)
         # Commit and push!
         url = self._push_to_hub(repo, commit_message=commit_message)
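With this change, any extra keyword arguments passed to `push_to_hub` are swept into `**model_card_kwargs` and forwarded to the Keras model's `create_model_card`. A minimal sketch of how a caller might use this (the checkpoint and repo names are illustrative, not part of this PR):

```python
from transformers import TFBertForSequenceClassification

# Hypothetical fine-tuning setup; any TF model with a populated `history`
# attribute and a `create_model_card` method takes this code path
model = TFBertForSequenceClassification.from_pretrained("bert-base-cased")
# ... model.fit(...) would run here so that model.history exists ...

# `license` and `tags` are not named parameters of push_to_hub; they are
# captured by **model_card_kwargs and passed through to create_model_card()
model.push_to_hub(
    "my-finetuned-bert",
    license="apache-2.0",
    tags="text-classification",
)
```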

src/transformers/keras_callbacks.py

Lines changed: 31 additions & 1 deletion
@@ -10,6 +10,7 @@
 
 from . import IntervalStrategy, PreTrainedTokenizerBase
 from .file_utils import get_full_repo_name
+from .modelcard import TrainingSummary
 
 
 logger = logging.getLogger(__name__)
@@ -25,6 +26,7 @@ def __init__(
         hub_model_id: Optional[str] = None,
         hub_token: Optional[str] = None,
         checkpoint: bool = False,
+        **model_card_args
     ):
         """
         output_dir (:obj:`str`):
@@ -70,12 +72,22 @@ def __init__(
             hub_model_id = get_full_repo_name(hub_model_id, token=hub_token)
 
         self.output_dir = output_dir
+        self.hub_model_id = hub_model_id
         self.repo = Repository(
-            str(output_dir), clone_from=hub_model_id, use_auth_token=hub_token if hub_token else True
+            str(self.output_dir),
+            clone_from=self.hub_model_id,
+            use_auth_token=hub_token if hub_token else True,
         )
         self.tokenizer = tokenizer
         self.last_job = None
         self.checkpoint = checkpoint
+        self.training_history = None
+        self.model_card_args = model_card_args
+
+    def on_train_begin(self, logs=None):
+        # Although we can access model.history, we have no guarantees that the History callback will fire before this
+        # one, so we keep track of it here too
+        self.training_history = []
 
     def on_train_batch_end(self, batch, logs=None):
         if self.save_strategy == IntervalStrategy.STEPS and (batch + 1) % self.save_steps == 0:
@@ -89,6 +101,9 @@ def on_train_batch_end(self, batch, logs=None):
             )
 
     def on_epoch_end(self, epoch, logs=None):
+        if "epoch" not in logs:
+            logs["epoch"] = epoch
+        self.training_history.append(logs)
         if self.save_strategy == IntervalStrategy.EPOCH:
             if self.last_job is not None and not self.last_job.is_done:
                 return  # The last upload is still running, don't start another
@@ -98,6 +113,15 @@ def on_epoch_end(self, epoch, logs=None):
             if self.checkpoint:
                 checkpoint_dir = os.path.join(self.output_dir, "checkpoint")
                 self.model._save_checkpoint(checkpoint_dir, epoch)
+            train_summary = TrainingSummary.from_keras(
+                model=self.model,
+                model_name=self.hub_model_id,
+                keras_history=self.training_history,
+                **self.model_card_args,
+            )
+            model_card = train_summary.to_model_card()
+            with (self.output_dir / "README.md").open("w") as f:
+                f.write(model_card)
             _, self.last_job = self.repo.push_to_hub(
                 commit_message=f"Training in progress epoch {epoch}", blocking=False
             )
@@ -110,4 +134,10 @@ def on_train_end(self, logs=None):
         self.model.save_pretrained(self.output_dir)
         if self.tokenizer is not None:
             self.tokenizer.save_pretrained(self.output_dir)
+        train_summary = TrainingSummary.from_keras(
+            model=self.model, model_name=self.hub_model_id, keras_history=self.training_history, **self.model_card_args
+        )
+        model_card = train_summary.to_model_card()
+        with (self.output_dir / "README.md").open("w") as f:
+            f.write(model_card)
         self.repo.push_to_hub(commit_message="End of training", blocking=True)
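Taken together, the callback now records per-epoch logs in `on_train_begin`/`on_epoch_end` and regenerates `README.md` before every push. A sketch of driving it end to end, assuming the model accepts `compile()` without an explicit loss (names, data, and hyperparameters are placeholders):

```python
import tensorflow as tf
from transformers import AutoTokenizer, PushToHubCallback, TFAutoModelForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
model = TFAutoModelForSequenceClassification.from_pretrained("bert-base-cased")
model.compile(optimizer=tf.keras.optimizers.Adam(3e-5))  # internal loss assumed

# Tiny stand-in dataset so the sketch is self-contained
batch = tokenizer(["a tiny example", "another example"], padding=True, return_tensors="np")
dataset = tf.data.Dataset.from_tensor_slices((dict(batch), [0, 1])).batch(2)

callback = PushToHubCallback(
    output_dir="./dummy-model",  # local clone of the Hub repo
    tokenizer=tokenizer,
    license="apache-2.0",        # extra kwarg, captured by **model_card_args
)
model.fit(dataset, epochs=2, callbacks=[callback])
# After each epoch, README.md is rebuilt from the accumulated
# training_history and pushed along with the weights
```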

src/transformers/modelcard.py

Lines changed: 141 additions & 6 deletions
@@ -38,6 +38,7 @@
     is_datasets_available,
     is_offline_mode,
     is_remote_url,
+    is_tf_available,
     is_tokenizers_available,
     is_torch_available,
 )
@@ -266,11 +267,16 @@ def to_json_file(self, json_file_path):
             writer.write(self.to_json_string())
 
 
-AUTOGENERATED_COMMENT = """
+AUTOGENERATED_TRAINER_COMMENT = """
 <!-- This model card has been generated automatically according to the information the Trainer had access to. You
 should probably proofread and complete it, then remove this comment. -->
 """
 
+AUTOGENERATED_KERAS_COMMENT = """
+<!-- This model card has been generated automatically according to the information Keras had access to. You should
+probably proofread and complete it, then remove this comment. -->
+"""
+
 
 TASK_TAG_TO_NAME_MAPPING = {
     "fill-mask": "Masked Language Modeling",
@@ -377,6 +383,7 @@ class TrainingSummary:
     eval_results: Optional[Dict[str, float]] = None
     eval_lines: Optional[List[str]] = None
     hyperparameters: Optional[Dict[str, Any]] = None
+    source: Optional[str] = "trainer"
 
     def __post_init__(self):
         # Infer default license from the checkpoint used, if possible.
@@ -410,15 +417,15 @@ def create_model_index(self, metric_mapping):
             task: TASK_TAG_TO_NAME_MAPPING[task] for task in _listify(self.tasks) if task in TASK_TAG_TO_NAME_MAPPING
         }
 
+        model_index["results"] = []
+
         if len(task_mapping) == 0 and len(dataset_mapping) == 0:
-            return model_index
+            return [model_index]
         if len(task_mapping) == 0:
             task_mapping = {None: None}
         if len(dataset_mapping) == 0:
             dataset_mapping = {None: None}
 
-        model_index["results"] = []
-
         # One entry per dataset and per task
         all_possibilities = [(task_tag, ds_tag) for task_tag in task_mapping for ds_tag in dataset_mapping]
         for task_tag, ds_tag in all_possibilities:
@@ -471,7 +478,10 @@ def to_model_card(self):
         model_card = f"---\n{metadata}---\n"
 
         # Now the model card for realsies.
-        model_card += AUTOGENERATED_COMMENT
+        if self.source == "trainer":
+            model_card += AUTOGENERATED_TRAINER_COMMENT
+        else:
+            model_card += AUTOGENERATED_KERAS_COMMENT
 
         model_card += f"\n# {self.model_name}\n\n"
 
@@ -517,10 +527,15 @@
 
         model_card += "\n### Framework versions\n\n"
         model_card += f"- Transformers {__version__}\n"
-        if is_torch_available():
+
+        if self.source == "trainer" and is_torch_available():
             import torch
 
             model_card += f"- Pytorch {torch.__version__}\n"
+        elif self.source == "keras" and is_tf_available():
+            import tensorflow as tf
+
+            model_card += f"- TensorFlow {tf.__version__}\n"
         if is_datasets_available():
             import datasets
 
@@ -604,6 +619,113 @@ def from_trainer(
             hyperparameters=hyperparameters,
         )
 
+    @classmethod
+    def from_keras(
+        cls,
+        model,
+        model_name,
+        keras_history=None,
+        language=None,
+        license=None,
+        tags=None,
+        finetuned_from=None,
+        tasks=None,
+        dataset_tags=None,
+        dataset=None,
+        dataset_args=None,
+    ):
+        # Infer default from dataset
+        if dataset is not None:
+            if is_hf_dataset(dataset) and (dataset_tags is None or dataset_args is None):
+                default_tag = dataset.builder_name
+                # Those are not real datasets from the Hub so we exclude them.
+                if default_tag not in ["csv", "json", "pandas", "parquet", "text"]:
+                    if dataset_tags is None:
+                        dataset_tags = [default_tag]
+                    if dataset_args is None:
+                        dataset_args = [dataset.config_name]
+
+        if dataset is None and dataset_tags is not None:
+            dataset = dataset_tags
+
+        # Infer default finetuned_from
+        if (
+            finetuned_from is None
+            and hasattr(model.config, "_name_or_path")
+            and not os.path.isdir(model.config._name_or_path)
+        ):
+            finetuned_from = model.config._name_or_path
+
+        # Infer default task tag:
+        if tasks is None:
+            model_class_name = model.__class__.__name__
+            for task, mapping in TASK_MAPPING.items():
+                if model_class_name in _get_mapping_values(mapping):
+                    tasks = task
+
+        # Add `generated_from_keras_callback` to the tags
+        if tags is None:
+            tags = ["generated_from_keras_callback"]
+        elif isinstance(tags, str) and tags != "generated_from_keras_callback":
+            tags = [tags, "generated_from_keras_callback"]
+        elif "generated_from_keras_callback" not in tags:
+            tags.append("generated_from_keras_callback")
+
+        if keras_history is not None:
+            _, eval_lines, eval_results = parse_keras_history(keras_history)
+        else:
+            eval_lines = []
+            eval_results = dict()
+        hyperparameters = extract_hyperparameters_from_keras(model)
+
+        return cls(
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset_tags=dataset_tags,
+            dataset=dataset,
+            dataset_args=dataset_args,
+            eval_results=eval_results,
+            eval_lines=eval_lines,
+            hyperparameters=hyperparameters,
+            source="keras",
+        )
+
+
+def parse_keras_history(logs):
+    """
+    Parse the `logs` of either a `tf.keras.History` object returned by `model.fit()` or an accumulated logs `dict`
+    passed to the `PushToHubCallback`. Returns lines and logs compatible with those returned by `parse_log_history`.
+    """
+    if hasattr(logs, "history"):
+        # This looks like a `History` object
+        logs.history["epoch"] = logs.epoch
+        logs = logs.history
+    else:
+        # Training logs is a list of dicts, let's invert it to a dict of lists to match a History object
+        logs = {log_key: [single_dict[log_key] for single_dict in logs] for log_key in logs[0]}
+
+    lines = []
+    for i in range(len(logs["epoch"])):
+        epoch_dict = {log_key: log_value_list[i] for log_key, log_value_list in logs.items()}
+        values = dict()
+        for k, v in epoch_dict.items():
+            if k.startswith("val_"):
+                k = "validation_" + k[4:]
+            elif k != "epoch":
+                k = "train_" + k
+            splits = k.split("_")
+            name = " ".join([part.capitalize() for part in splits])
+            values[name] = v
+        lines.append(values)
+
+    eval_results = lines[-1]
+
+    return logs, lines, eval_results
+
 
 def parse_log_history(log_history):
     """
@@ -666,6 +788,19 @@ def parse_log_history(log_history):
         return train_log, lines, None
 
 
+def extract_hyperparameters_from_keras(model):
+    import tensorflow as tf
+
+    hyperparameters = dict()
+    if hasattr(model, "optimizer") and model.optimizer is not None:
+        hyperparameters["optimizer"] = model.optimizer.get_config()
+    else:
+        hyperparameters["optimizer"] = None
+    hyperparameters["training_precision"] = tf.keras.mixed_precision.global_policy().name
+
+    return hyperparameters
+
+
 def _maybe_round(v, decimals=4):
     if isinstance(v, float) and len(str(v).split(".")) > 1 and len(str(v).split(".")[1]) > decimals:
         return f"{v:.{decimals}f}"
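To make the new log parsing concrete, here is a small sketch of `parse_keras_history` applied to an accumulated-logs list of the kind `on_epoch_end` builds up (metric values invented):

```python
from transformers.modelcard import parse_keras_history

# One dict per epoch, with the "epoch" key injected by the callback
history = [
    {"loss": 0.82, "val_loss": 0.61, "epoch": 0},
    {"loss": 0.44, "val_loss": 0.52, "epoch": 1},
]

logs, lines, eval_results = parse_keras_history(history)
print(lines[0])      # {'Train Loss': 0.82, 'Validation Loss': 0.61, 'Epoch': 0}
print(eval_results)  # the last epoch's renamed metrics, reported as eval results
```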

src/transformers/modeling_tf_utils.py

Lines changed: 31 additions & 0 deletions
@@ -47,6 +47,7 @@
     is_remote_url,
 )
 from .generation_tf_utils import TFGenerationMixin
+from .modelcard import TrainingSummary
 from .modeling_tf_outputs import TFSeq2SeqLMOutput
 from .tokenization_utils_base import BatchEncoding
 from .utils import logging
@@ -926,6 +927,36 @@ def test_step(self, data):
             del return_metrics["loss_loss"]
         return return_metrics
 
+    def create_model_card(
+        self,
+        output_dir,
+        model_name: str,
+        language: Optional[str] = None,
+        license: Optional[str] = None,
+        tags: Optional[str] = None,
+        finetuned_from: Optional[str] = None,
+        tasks: Optional[str] = None,
+        dataset_tags: Optional[Union[str, List[str]]] = None,
+        dataset: Optional[Union[str, List[str]]] = None,
+        dataset_args: Optional[Union[str, List[str]]] = None,
+    ):
+        training_summary = TrainingSummary.from_keras(
+            self,
+            keras_history=self.history,
+            language=language,
+            license=license,
+            tags=tags,
+            model_name=model_name,
+            finetuned_from=finetuned_from,
+            tasks=tasks,
+            dataset_tags=dataset_tags,
+            dataset=dataset,
+            dataset_args=dataset_args,
+        )
+        model_card = training_summary.to_model_card()
+        with open(os.path.join(output_dir, "README.md"), "w") as f:
+            f.write(model_card)
+
     def set_input_embeddings(self, value):
         """
         Set model's input embeddings
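A sketch of calling the new method directly after training, outside of `push_to_hub` (paths and tags are illustrative):

```python
# Assumes `model` is a transformers TF model that has been through model.fit(),
# so model.history is a populated Keras History object
model.create_model_card(
    output_dir="./my-model",
    model_name="my-model",
    tags="text-classification",
    license="apache-2.0",
)
# ./my-model/README.md now contains the YAML metadata block, the autogenerated
# Keras comment, and the hyperparameters/framework-versions sections
```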

tests/test_modeling_tf_common.py

Lines changed: 9 additions & 0 deletions
@@ -1386,6 +1386,15 @@ def test_push_to_hub(self):
                     models_equal = False
             self.assertTrue(models_equal)
 
+    def test_push_to_hub_with_model_card(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        model = TFBertModel(config)
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.push_to_hub(os.path.join(tmp_dir, "test-model-card-tf"))
+            self.assertTrue(os.path.isfile(os.path.join(tmp_dir, "test-model-card-tf", "README.md")))
+
     def test_push_to_hub_in_organization(self):
         config = BertConfig(
             vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
