Skip to content

Commit cb93b7c

Browse files
committed
PoC to avoid using get_list_of_files
1 parent 3d66f3b commit cb93b7c

File tree

1 file changed

+20
-38
lines changed

1 file changed

+20
-38
lines changed

src/transformers/configuration_utils.py

Lines changed: 20 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import os
2222
import re
2323
import warnings
24-
from typing import Any, Dict, Optional, Tuple, Union
24+
from typing import Any, Dict, Tuple, Union
2525

2626
from packaging import version
2727

@@ -36,7 +36,6 @@
3636
RevisionNotFoundError,
3737
cached_path,
3838
copy_func,
39-
get_list_of_files,
4039
hf_bucket_url,
4140
is_offline_mode,
4241
is_remote_url,
@@ -46,7 +45,6 @@
4645

4746

4847
logger = logging.get_logger(__name__)
49-
FULL_CONFIGURATION_FILE = "config.json"
5048
_re_configuration_file = re.compile(r"config\.(.*)\.json")
5149

5250

@@ -509,6 +507,14 @@ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike],
509507
assert unused_kwargs == {"foo": False}
510508
```"""
511509
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
510+
if "configuration_files" in config_dict:
511+
# We may have to load another config
512+
configuration_file = get_configuration_file(config_dict["configuration_files"])
513+
if configuration_file != CONFIG_NAME:
514+
config_dict, kwargs = cls.get_config_dict(
515+
pretrained_model_name_or_path, _configuration_file=configuration_file, **kwargs
516+
)
517+
512518
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
513519
logger.warn(
514520
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
@@ -525,7 +531,12 @@ def get_config_dict(
525531
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
526532
[`PretrainedConfig`] using `from_dict`.
527533
534+
<Tip warning={true}>
528535
536+
This method is for internal use only and will only load the base configuration of the model (when several are
537+
available). You should always use the [`~PretrainedConfig.from_pretrained`] method.
538+
539+
</Tip>
529540
530541
Parameters:
531542
pretrained_model_name_or_path (`str` or `os.PathLike`):
@@ -557,13 +568,7 @@ def get_config_dict(
557568
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
558569
config_file = pretrained_model_name_or_path
559570
else:
560-
configuration_file = get_configuration_file(
561-
pretrained_model_name_or_path,
562-
revision=revision,
563-
use_auth_token=use_auth_token,
564-
local_files_only=local_files_only,
565-
)
566-
571+
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
567572
if os.path.isdir(pretrained_model_name_or_path):
568573
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
569574
else:
@@ -841,49 +846,26 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
841846
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
842847

843848

844-
def get_configuration_file(
845-
path_or_repo: Union[str, os.PathLike],
846-
revision: Optional[str] = None,
847-
use_auth_token: Optional[Union[bool, str]] = None,
848-
local_files_only: bool = False,
849-
) -> str:
849+
def get_configuration_file(configuration_files) -> str:
850850
"""
851851
Get the configuration file to use for this version of transformers.
852852
853853
Args:
854-
path_or_repo (`str` or `os.PathLike`):
855-
Can be either the id of a repo on huggingface.co or a path to a *directory*.
856-
revision(`str`, *optional*, defaults to `"main"`):
857-
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
858-
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
859-
identifier allowed by git.
860-
use_auth_token (`str` or *bool*, *optional*):
861-
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
862-
when running `transformers-cli login` (stored in `~/.huggingface`).
863-
local_files_only (`bool`, *optional*, defaults to `False`):
864-
Whether or not to only rely on local files and not to attempt to download any files.
854+
configuration_files (`List[str]`): The list of configuration files to pick from.
865855
866856
Returns:
867857
`str`: The configuration file to use.
868858
"""
869-
# Inspect all files from the repo/folder.
870-
try:
871-
all_files = get_list_of_files(
872-
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
873-
)
874-
except Exception:
875-
return FULL_CONFIGURATION_FILE
876-
877859
configuration_files_map = {}
878-
for file_name in all_files:
860+
for file_name in configuration_files:
879861
search = _re_configuration_file.search(file_name)
880862
if search is not None:
881863
v = search.groups()[0]
882864
configuration_files_map[v] = file_name
883865
available_versions = sorted(configuration_files_map.keys())
884866

885-
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
886-
configuration_file = FULL_CONFIGURATION_FILE
867+
# Default to CONFIG_NAME, then try to look at some newer versions.
868+
configuration_file = CONFIG_NAME
887869
transformers_version = version.parse(__version__)
888870
for v in available_versions:
889871
if version.parse(v) <= transformers_version:

0 commit comments

Comments
 (0)