21
21
import os
22
22
import re
23
23
import warnings
24
- from typing import Any , Dict , Optional , Tuple , Union
24
+ from typing import Any , Dict , Tuple , Union
25
25
26
26
from packaging import version
27
27
36
36
RevisionNotFoundError ,
37
37
cached_path ,
38
38
copy_func ,
39
- get_list_of_files ,
40
39
hf_bucket_url ,
41
40
is_offline_mode ,
42
41
is_remote_url ,
46
45
47
46
48
47
logger = logging .get_logger (__name__ )
49
- FULL_CONFIGURATION_FILE = "config.json"
50
48
_re_configuration_file = re .compile (r"config\.(.*)\.json" )
51
49
52
50
@@ -509,6 +507,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
509
507
assert unused_kwargs == {"foo": False}
510
508
```"""
511
509
config_dict , kwargs = cls .get_config_dict (pretrained_model_name_or_path , ** kwargs )
510
+ if "configuration_files" in config_dict :
511
+ # We may have to load another config
512
+ configuration_file = get_configuration_file (config_dict ["configuration_files" ])
513
+ if configuration_file != CONFIG_NAME :
514
+ config_dict , kwargs = cls .get_config_dict (
515
+ pretrained_model_name_or_path , _configuration_file = configuration_file , ** kwargs
516
+ )
517
+
512
518
if "model_type" in config_dict and hasattr (cls , "model_type" ) and config_dict ["model_type" ] != cls .model_type :
513
519
logger .warn (
514
520
f"You are using a model of type { config_dict ['model_type' ]} to instantiate a model of type "
@@ -525,7 +531,12 @@ def get_config_dict(
525
531
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
526
532
[`PretrainedConfig`] using `from_dict`.
527
533
534
+ <Tip warning={true}>
528
535
536
+ This method is for internal use only and will only load the base configuration of the model (when several are
537
+ available). You should always use the [~PretrainedConfig.from_pretrained`] method.
538
+
539
+ </Tip>
529
540
530
541
Parameters:
531
542
pretrained_model_name_or_path (`str` or `os.PathLike`):
@@ -557,13 +568,7 @@ def get_config_dict(
557
568
if os .path .isfile (pretrained_model_name_or_path ) or is_remote_url (pretrained_model_name_or_path ):
558
569
config_file = pretrained_model_name_or_path
559
570
else :
560
- configuration_file = get_configuration_file (
561
- pretrained_model_name_or_path ,
562
- revision = revision ,
563
- use_auth_token = use_auth_token ,
564
- local_files_only = local_files_only ,
565
- )
566
-
571
+ configuration_file = kwargs .pop ("_configuration_file" , CONFIG_NAME )
567
572
if os .path .isdir (pretrained_model_name_or_path ):
568
573
config_file = os .path .join (pretrained_model_name_or_path , configuration_file )
569
574
else :
@@ -841,49 +846,26 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
841
846
d ["torch_dtype" ] = str (d ["torch_dtype" ]).split ("." )[1 ]
842
847
843
848
844
- def get_configuration_file (
845
- path_or_repo : Union [str , os .PathLike ],
846
- revision : Optional [str ] = None ,
847
- use_auth_token : Optional [Union [bool , str ]] = None ,
848
- local_files_only : bool = False ,
849
- ) -> str :
849
+ def get_configuration_file (configuration_files ) -> str :
850
850
"""
851
851
Get the configuration file to use for this version of transformers.
852
852
853
853
Args:
854
- path_or_repo (`str` or `os.PathLike`):
855
- Can be either the id of a repo on huggingface.co or a path to a *directory*.
856
- revision(`str`, *optional*, defaults to `"main"`):
857
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
858
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
859
- identifier allowed by git.
860
- use_auth_token (`str` or *bool*, *optional*):
861
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
862
- when running `transformers-cli login` (stored in `~/.huggingface`).
863
- local_files_only (`bool`, *optional*, defaults to `False`):
864
- Whether or not to only rely on local files and not to attempt to download any files.
854
+ configuration_files (`List[str]`): The list of configuration files to pick from.
865
855
866
856
Returns:
867
857
`str`: The configuration file to use.
868
858
"""
869
- # Inspect all files from the repo/folder.
870
- try :
871
- all_files = get_list_of_files (
872
- path_or_repo , revision = revision , use_auth_token = use_auth_token , local_files_only = local_files_only
873
- )
874
- except Exception :
875
- return FULL_CONFIGURATION_FILE
876
-
877
859
configuration_files_map = {}
878
- for file_name in all_files :
860
+ for file_name in configuration_files :
879
861
search = _re_configuration_file .search (file_name )
880
862
if search is not None :
881
863
v = search .groups ()[0 ]
882
864
configuration_files_map [v ] = file_name
883
865
available_versions = sorted (configuration_files_map .keys ())
884
866
885
- # Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
886
- configuration_file = FULL_CONFIGURATION_FILE
867
+ # Defaults to CONFIG_NAME and then try to look at some newer versions.
868
+ configuration_file = CONFIG_NAME
887
869
transformers_version = version .parse (__version__ )
888
870
for v in available_versions :
889
871
if version .parse (v ) <= transformers_version :
0 commit comments