
Commit e84786a

consistent ignore keys + make private (huggingface#8737)

* consistent ignore keys + make private
* style
* - authorized_missing_keys => _keys_to_ignore_on_load_missing
  - authorized_unexpected_keys => _keys_to_ignore_on_load_unexpected
* move public doc of private attributes to private comment
1 parent 49759c0 commit e84786a
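In short, the commit renames three class attributes on PreTrainedModel / TFPreTrainedModel (and on every model that overrides them), adding a leading underscore to mark them as private API. A minimal before/after sketch on a derived class, assuming an illustrative subclass name that is not part of this diff:

from transformers import PreTrainedModel

# before this commit: public, documented attributes
class SomeModel(PreTrainedModel):  # illustrative subclass, not from the diff
    authorized_missing_keys = [r"position_ids"]
    authorized_unexpected_keys = [r"pooler"]
    keys_to_never_save = None

# after this commit: private attributes, documented by source comments instead
class SomeModel(PreTrainedModel):  # same illustrative subclass, renamed attributes
    _keys_to_ignore_on_load_missing = [r"position_ids"]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_save = None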

38 files changed, +127 -126 lines changed

src/transformers/modeling_tf_pytorch_utils.py

Lines changed: 6 additions & 6 deletions
@@ -164,9 +164,9 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a
             if allow_missing_keys:
                 missing_keys.append(name)
                 continue
-            elif tf_model.authorized_missing_keys is not None:
+            elif tf_model._keys_to_ignore_on_load_missing is not None:
                 # authorized missing keys don't have to be loaded
-                if any(re.search(pat, name) is not None for pat in tf_model.authorized_missing_keys):
+                if any(re.search(pat, name) is not None for pat in tf_model._keys_to_ignore_on_load_missing):
                     continue

             raise AttributeError("{} not found in PyTorch model".format(name))
@@ -209,11 +209,11 @@ def load_pytorch_weights_in_tf2_model(tf_model, pt_state_dict, tf_inputs=None, a

     unexpected_keys = list(all_pytorch_weights)

-    if tf_model.authorized_missing_keys is not None:
-        for pat in tf_model.authorized_missing_keys:
+    if tf_model._keys_to_ignore_on_load_missing is not None:
+        for pat in tf_model._keys_to_ignore_on_load_missing:
             missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
-    if tf_model.authorized_unexpected_keys is not None:
-        for pat in tf_model.authorized_unexpected_keys:
+    if tf_model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in tf_model._keys_to_ignore_on_load_unexpected:
             unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

     if len(unexpected_keys) > 0:
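For readers skimming the hunks above: both filters run re.search over class-level pattern lists, so a pattern matches anywhere inside a weight name. A minimal standalone sketch of the same filtering, with key names made up for illustration:

import re

# illustrative pattern lists, mirroring the renamed class attributes
_keys_to_ignore_on_load_missing = [r"position_ids"]
_keys_to_ignore_on_load_unexpected = [r"pooler"]

missing_keys = ["bert.embeddings.position_ids", "classifier.weight"]
unexpected_keys = ["bert.pooler.dense.weight", "bert.pooler.dense.bias"]

# same list comprehensions as in the diff: keep only keys matching no pattern
for pat in _keys_to_ignore_on_load_missing:
    missing_keys = [k for k in missing_keys if re.search(pat, k) is None]
for pat in _keys_to_ignore_on_load_unexpected:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

print(missing_keys)     # ['classifier.weight'] -- still reported to the user
print(unexpected_keys)  # [] -- fully silenced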

src/transformers/modeling_tf_utils.py

Lines changed: 10 additions & 10 deletions
@@ -343,15 +343,15 @@ class TFPreTrainedModel(tf.keras.Model, TFModelUtilsMixin, TFGenerationMixin):
         :class:`~transformers.PretrainedConfig` to use as configuration class for this model architecture.
         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to ignore
-          from the model when loading the model weights (and avoid unnecessary warnings).
-        - **authorized_unexpected_keys** (:obj:`List[str]`, `optional`) -- A list of re pattern of tensor names to
-          ignore from the weights when loading the model weights (and avoid unnecessary warnings).
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None

     @property
     def dummy_inputs(self) -> Dict[str, tf.Tensor]:
@@ -742,12 +742,12 @@ def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):

         model(model.dummy_inputs, training=False)  # Make sure restore ops are run

-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         if len(unexpected_keys) > 0:

src/transformers/modeling_utils.py

Lines changed: 15 additions & 14 deletions
@@ -404,17 +404,18 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin):

         - **base_model_prefix** (:obj:`str`) -- A string indicating the attribute associated to the base model in
           derived classes of the same architecture adding modules on top of the base model.
-        - **authorized_missing_keys** (:obj:`Optional[List[str]]`) -- A list of re pattern of tensor names to ignore
-          when loading the model (and avoid unnecessary warnings).
-        - **keys_to_never_save** (:obj:`Optional[List[str]]`) -- A list of tensor names to ignore when saving the
-          model (useful for keys that aren't trained, but which are deterministic)
-
     """
     config_class = None
     base_model_prefix = ""
-    authorized_missing_keys = None
-    authorized_unexpected_keys = None
-    keys_to_never_save = None
+    # a list of re pattern of tensor names to ignore from the model when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_missing = None
+    # a list of re pattern of tensor names to ignore from the weights when loading the model weights
+    # (and avoid unnecessary warnings).
+    _keys_to_ignore_on_load_unexpected = None
+    # a list of tensor names to ignore when saving the model (useful for keys that aren't
+    # trained, but which are deterministic)
+    _keys_to_ignore_on_save = None

     @property
     def dummy_inputs(self) -> Dict[str, torch.Tensor]:
@@ -719,8 +720,8 @@ def save_pretrained(self, save_directory):
         state_dict = model_to_save.state_dict()

         # Handle the case where some state_dict keys shouldn't be saved
-        if self.keys_to_never_save is not None:
-            state_dict = {k: v for k, v in state_dict.items() if k not in self.keys_to_never_save}
+        if self._keys_to_ignore_on_save is not None:
+            state_dict = {k: v for k, v in state_dict.items() if k not in self._keys_to_ignore_on_save}

         # If we save using the predefined names, we can load using `from_pretrained`
         output_model_file = os.path.join(save_directory, WEIGHTS_NAME)
@@ -1034,12 +1035,12 @@ def load(module: nn.Module, prefix=""):

         # Some models may have keys that are not in the state by design, removing them before needlessly warning
         # the user.
-        if cls.authorized_missing_keys is not None:
-            for pat in cls.authorized_missing_keys:
+        if cls._keys_to_ignore_on_load_missing is not None:
+            for pat in cls._keys_to_ignore_on_load_missing:
                 missing_keys = [k for k in missing_keys if re.search(pat, k) is None]

-        if cls.authorized_unexpected_keys is not None:
-            for pat in cls.authorized_unexpected_keys:
+        if cls._keys_to_ignore_on_load_unexpected is not None:
+            for pat in cls._keys_to_ignore_on_load_unexpected:
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         if len(unexpected_keys) > 0:
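Unlike the two load-time filters, the on-save filter above tests exact key membership (k not in ...) rather than regex search. A small sketch of that dict comprehension, with tensor values replaced by strings and key names made up for illustration:

# mirrors the save_pretrained change above
_keys_to_ignore_on_save = ["model.encoder.embed_positions.weight"]

state_dict = {
    "model.encoder.embed_positions.weight": "deterministic, rebuilt at load time",
    "model.encoder.layers.0.self_attn.k_proj.weight": "learned",
}

if _keys_to_ignore_on_save is not None:
    state_dict = {k: v for k, v in state_dict.items() if k not in _keys_to_ignore_on_save}

print(list(state_dict))  # only the learned weight survives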

src/transformers/models/albert/modeling_albert.py

Lines changed: 4 additions & 4 deletions
@@ -459,7 +459,7 @@ class AlbertPreTrainedModel(PreTrainedModel):

     config_class = AlbertConfig
     base_model_prefix = "albert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -851,7 +851,7 @@ def forward(self, pooled_output):
 )
 class AlbertForMaskedLM(AlbertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1021,7 +1021,7 @@ def forward(
 )
 class AlbertForTokenClassification(AlbertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1110,7 +1110,7 @@ def forward(
 )
 class AlbertForQuestionAnswering(AlbertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)

src/transformers/models/albert/modeling_tf_albert.py

Lines changed: 3 additions & 3 deletions
@@ -843,7 +843,7 @@ def call(self, pooled_output, training: bool):
 @add_start_docstrings("""Albert Model with a `language modeling` head on top. """, ALBERT_START_DOCSTRING)
 class TFAlbertForMaskedLM(TFAlbertPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1013,7 +1013,7 @@ def call(
 )
 class TFAlbertForTokenClassification(TFAlbertPreTrainedModel, TFTokenClassificationLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1100,7 +1100,7 @@ def call(
 )
 class TFAlbertForQuestionAnswering(TFAlbertPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)

src/transformers/models/bart/modeling_bart.py

Lines changed: 1 addition & 1 deletion
@@ -946,7 +946,7 @@ def get_output_embeddings(self):
 )
 class BartForConditionalGeneration(PretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]
+    _keys_to_ignore_on_load_missing = [r"final_logits_bias", r"encoder\.version", r"decoder\.version"]

     def __init__(self, config: BartConfig):
         super().__init__(config)

src/transformers/models/bart/modeling_tf_bart.py

Lines changed: 2 additions & 2 deletions
@@ -1020,10 +1020,10 @@ def get_output_embeddings(self):
 )
 class TFBartForConditionalGeneration(TFPretrainedBartModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         r"final_logits_bias",
     ]
-    authorized_unexpected_keys = [
+    _keys_to_ignore_on_load_unexpected = [
         r"model.encoder.embed_tokens.weight",
         r"model.decoder.embed_tokens.weight",
     ]
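A side note on those two unexpected-key patterns: they are applied with re.search, so the unescaped dots act as wildcards, yet they still match the intended literal names, since "." matches any character including a dot. A one-line check:

import re

# unescaped dots are permissive, but the literal name still matches
print(bool(re.search(r"model.encoder.embed_tokens.weight",
                     "model.encoder.embed_tokens.weight")))  # True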

src/transformers/models/bert/modeling_bert.py

Lines changed: 7 additions & 7 deletions
@@ -598,7 +598,7 @@ class BertPreTrainedModel(PreTrainedModel):
     config_class = BertConfig
     load_tf_weights = load_tf_weights_in_bert
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
@@ -969,8 +969,8 @@ def forward(
 )
 class BertLMHeadModel(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1087,8 +1087,8 @@ def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class BertForMaskedLM(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"position_ids", r"predictions.decoder.bias"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"]

     def __init__(self, config):
         super().__init__(config)
@@ -1469,7 +1469,7 @@ def forward(
 )
 class BertForTokenClassification(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)
@@ -1560,7 +1560,7 @@ def forward(
 )
 class BertForQuestionAnswering(BertPreTrainedModel):

-    authorized_unexpected_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]

     def __init__(self, config):
         super().__init__(config)

src/transformers/models/bert/modeling_tf_bert.py

Lines changed: 8 additions & 8 deletions
@@ -938,8 +938,8 @@ def call(
 @add_start_docstrings("""Bert Model with a `language modeling` head on top. """, BERT_START_DOCSTRING)
 class TFBertForMaskedLM(TFBertPreTrainedModel, TFMaskedLanguageModelingLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1023,8 +1023,8 @@ def call(

 class TFBertLMHeadModel(TFBertPreTrainedModel, TFCausalLanguageModelingLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1416,8 +1416,8 @@ def call(
 )
 class TFBertForTokenClassification(TFBertPreTrainedModel, TFTokenClassificationLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)
@@ -1502,8 +1502,8 @@ def call(
 )
 class TFBertForQuestionAnswering(TFBertPreTrainedModel, TFQuestionAnsweringLoss):

-    authorized_unexpected_keys = [r"pooler"]
-    authorized_missing_keys = [r"pooler"]
+    _keys_to_ignore_on_load_unexpected = [r"pooler"]
+    _keys_to_ignore_on_load_missing = [r"pooler"]

     def __init__(self, config, *inputs, **kwargs):
         super().__init__(config, *inputs, **kwargs)

src/transformers/models/bert_generation/modeling_bert_generation.py

Lines changed: 1 addition & 1 deletion
@@ -173,7 +173,7 @@ class BertGenerationPreTrainedModel(PreTrainedModel):

     config_class = BertGenerationConfig
     base_model_prefix = "bert"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """

src/transformers/models/deberta/modeling_deberta.py

Lines changed: 1 addition & 1 deletion
@@ -756,7 +756,7 @@ class DebertaPreTrainedModel(PreTrainedModel):

     config_class = DebertaConfig
     base_model_prefix = "deberta"
-    authorized_missing_keys = ["position_ids"]
+    _keys_to_ignore_on_load_missing = ["position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """

src/transformers/models/dpr/modeling_dpr.py

Lines changed: 3 additions & 3 deletions
@@ -279,7 +279,7 @@ class DPRPretrainedContextEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "ctx_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.ctx_encoder.init_weights()
@@ -294,7 +294,7 @@ class DPRPretrainedQuestionEncoder(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "question_encoder"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.question_encoder.init_weights()
@@ -309,7 +309,7 @@ class DPRPretrainedReader(PreTrainedModel):
     config_class = DPRConfig
     load_tf_weights = None
     base_model_prefix = "span_predictor"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def init_weights(self):
         self.span_predictor.encoder.init_weights()

src/transformers/models/electra/modeling_electra.py

Lines changed: 2 additions & 2 deletions
@@ -544,8 +544,8 @@ class ElectraPreTrainedModel(PreTrainedModel):
     config_class = ElectraConfig
     load_tf_weights = load_tf_weights_in_electra
     base_model_prefix = "electra"
-    authorized_missing_keys = [r"position_ids"]
-    authorized_unexpected_keys = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]
+    _keys_to_ignore_on_load_unexpected = [r"electra\.embeddings_project\.weight", r"electra\.embeddings_project\.bias"]

     # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights
     def _init_weights(self, module):

src/transformers/models/fsmt/modeling_fsmt.py

Lines changed: 2 additions & 2 deletions
@@ -1005,11 +1005,11 @@ def set_output_embeddings(self, value):
 )
 class FSMTForConditionalGeneration(PretrainedFSMTModel):
     base_model_prefix = "model"
-    authorized_missing_keys = [
+    _keys_to_ignore_on_load_missing = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
-    keys_to_never_save = [
+    _keys_to_ignore_on_save = [
         "model.encoder.embed_positions.weight",
         "model.decoder.embed_positions.weight",
     ]
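Note that FSMT lists the same two position-embedding weights under both _keys_to_ignore_on_load_missing and _keys_to_ignore_on_save, presumably because its position embeddings are sinusoidal and therefore deterministic: they can be dropped from the checkpoint at save time, and their absence is expected (and silenced) at load time.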

src/transformers/models/gpt2/modeling_gpt2.py

Lines changed: 2 additions & 2 deletions
@@ -780,7 +780,7 @@ def custom_forward(*inputs):
     GPT2_START_DOCSTRING,
 )
 class GPT2LMHeadModel(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
@@ -1097,7 +1097,7 @@ def forward(
     GPT2_START_DOCSTRING,
 )
 class GPT2ForSequenceClassification(GPT2PreTrainedModel):
-    authorized_missing_keys = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]
+    _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"lm_head\.weight"]

     def __init__(self, config):
         super().__init__(config)
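The GPT-2 pattern h\.\d+\.attn\.masked_bias covers a per-block attention buffer across all numbered layers. A quick check of what it matches, with illustrative key names:

import re

pat = r"h\.\d+\.attn\.masked_bias"
keys = [
    "transformer.h.0.attn.masked_bias",
    "transformer.h.11.attn.masked_bias",
    "transformer.h.0.attn.c_attn.weight",
]
print([k for k in keys if re.search(pat, k)])  # the two masked_bias buffers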

src/transformers/models/layoutlm/modeling_layoutlm.py

Lines changed: 1 addition & 1 deletion
@@ -509,7 +509,7 @@ class LayoutLMPreTrainedModel(PreTrainedModel):

     config_class = LayoutLMConfig
     base_model_prefix = "layoutlm"
-    authorized_missing_keys = [r"position_ids"]
+    _keys_to_ignore_on_load_missing = [r"position_ids"]

     def _init_weights(self, module):
         """ Initialize the weights """
