Skip to content

Commit e695470

Browse files
authored
Avoid using get_list_of_files (huggingface#15287)
* Avoid using get_list_of_files in config * Wip, change tokenizer file getter * Remove call in tokenizer files * Remove last call to get_list_model_files * Better tests * Unit tests for new function * Document bad API
1 parent e65bfc0 commit e695470

File tree

8 files changed

+232
-134
lines changed

8 files changed

+232
-134
lines changed

src/transformers/configuration_utils.py

Lines changed: 24 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import os
2222
import re
2323
import warnings
24-
from typing import Any, Dict, Optional, Tuple, Union
24+
from typing import Any, Dict, List, Tuple, Union
2525

2626
from packaging import version
2727

@@ -36,7 +36,6 @@
3636
RevisionNotFoundError,
3737
cached_path,
3838
copy_func,
39-
get_list_of_files,
4039
hf_bucket_url,
4140
is_offline_mode,
4241
is_remote_url,
@@ -46,7 +45,7 @@
4645

4746

4847
logger = logging.get_logger(__name__)
49-
FULL_CONFIGURATION_FILE = "config.json"
48+
5049
_re_configuration_file = re.compile(r"config\.(.*)\.json")
5150

5251

@@ -533,6 +532,23 @@ def get_config_dict(
533532
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
534533
535534
"""
535+
original_kwargs = copy.deepcopy(kwargs)
536+
# Get config dict associated with the base config file
537+
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
538+
539+
# That config file may point us toward another config file to use.
540+
if "configuration_files" in config_dict:
541+
configuration_file = get_configuration_file(config_dict["configuration_files"])
542+
config_dict, kwargs = cls._get_config_dict(
543+
pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
544+
)
545+
546+
return config_dict, kwargs
547+
548+
@classmethod
549+
def _get_config_dict(
550+
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
551+
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
536552
cache_dir = kwargs.pop("cache_dir", None)
537553
force_download = kwargs.pop("force_download", False)
538554
resume_download = kwargs.pop("resume_download", False)
@@ -555,12 +571,7 @@ def get_config_dict(
555571
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
556572
config_file = pretrained_model_name_or_path
557573
else:
558-
configuration_file = get_configuration_file(
559-
pretrained_model_name_or_path,
560-
revision=revision,
561-
use_auth_token=use_auth_token,
562-
local_files_only=local_files_only,
563-
)
574+
configuration_file = kwargs.get("_configuration_file", CONFIG_NAME)
564575

565576
if os.path.isdir(pretrained_model_name_or_path):
566577
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
@@ -840,49 +851,26 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
840851
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
841852

842853

843-
def get_configuration_file(
844-
path_or_repo: Union[str, os.PathLike],
845-
revision: Optional[str] = None,
846-
use_auth_token: Optional[Union[bool, str]] = None,
847-
local_files_only: bool = False,
848-
) -> str:
854+
def get_configuration_file(configuration_files: List[str]) -> str:
849855
"""
850856
Get the configuration file to use for this version of transformers.
851857
852858
Args:
853-
path_or_repo (`str` or `os.PathLike`):
854-
Can be either the id of a repo on huggingface.co or a path to a *directory*.
855-
revision(`str`, *optional*, defaults to `"main"`):
856-
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
857-
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
858-
identifier allowed by git.
859-
use_auth_token (`str` or *bool*, *optional*):
860-
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
861-
when running `transformers-cli login` (stored in `~/.huggingface`).
862-
local_files_only (`bool`, *optional*, defaults to `False`):
863-
Whether or not to only rely on local files and not to attempt to download any files.
859+
configuration_files (`List[str]`): The list of available configuration files.
864860
865861
Returns:
866862
`str`: The configuration file to use.
867863
"""
868-
# Inspect all files from the repo/folder.
869-
try:
870-
all_files = get_list_of_files(
871-
path_or_repo, revision=revision, use_auth_token=use_auth_token, local_files_only=local_files_only
872-
)
873-
except Exception:
874-
return FULL_CONFIGURATION_FILE
875-
876864
configuration_files_map = {}
877-
for file_name in all_files:
865+
for file_name in configuration_files:
878866
search = _re_configuration_file.search(file_name)
879867
if search is not None:
880868
v = search.groups()[0]
881869
configuration_files_map[v] = file_name
882870
available_versions = sorted(configuration_files_map.keys())
883871

884872
# Defaults to FULL_CONFIGURATION_FILE and then try to look at some newer versions.
885-
configuration_file = FULL_CONFIGURATION_FILE
873+
configuration_file = CONFIG_NAME
886874
transformers_version = version.parse(__version__)
887875
for v in available_versions:
888876
if version.parse(v) <= transformers_version:

src/transformers/file_utils.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2112,6 +2112,112 @@ def _resumable_file_manager() -> "io.BufferedWriter":
21122112
return cache_path
21132113

21142114

2115+
def get_file_from_repo(
    path_or_repo: Union[str, os.PathLike],
    filename: str,
    cache_dir: Optional[Union[str, os.PathLike]] = None,
    force_download: bool = False,
    resume_download: bool = False,
    proxies: Optional[Dict[str, str]] = None,
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    local_files_only: bool = False,
):
    """
    Tries to locate a file in a local folder and repo, downloads and caches it if necessary.

    Args:
        path_or_repo (`str` or `os.PathLike`):
            This can be either:

            - a string, the *model id* of a model repo on huggingface.co.
            - a path to a *directory* potentially containing the file.
        filename (`str`):
            The name of the file to locate in `path_or_repo`.
        cache_dir (`str` or `os.PathLike`, *optional*):
            Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
            cache should not be used.
        force_download (`bool`, *optional*, defaults to `False`):
            Whether or not to force to (re-)download the configuration files and override the cached versions if they
            exist.
        resume_download (`bool`, *optional*, defaults to `False`):
            Whether or not to delete incompletely received files. Attempts to resume the download if such a file
            exists.
        proxies (`Dict[str, str]`, *optional*):
            A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
            'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
        use_auth_token (`str` or *bool*, *optional*):
            The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
            when running `transformers-cli login` (stored in `~/.huggingface`).
        revision(`str`, *optional*, defaults to `"main"`):
            The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
            git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
            identifier allowed by git.
        local_files_only (`bool`, *optional*, defaults to `False`):
            If `True`, will only try to load the file from local files (no download attempted).

    <Tip>

    Passing `use_auth_token=True` is required when you want to use a private model.

    </Tip>

    Returns:
        `Optional[str]`: Returns the resolved file (to the cache folder if downloaded from a repo) or `None` if the
        file does not exist.

    Examples:

    ```python
    # Download a tokenizer configuration from huggingface.co and cache.
    tokenizer_config = get_file_from_repo("bert-base-uncased", "tokenizer_config.json")
    # This model does not have a tokenizer config so the result will be None.
    tokenizer_config = get_file_from_repo("xlm-roberta-base", "tokenizer_config.json")
    ```"""
    if is_offline_mode() and not local_files_only:
        logger.info("Offline mode: forcing local_files_only=True")
        local_files_only = True

    path_or_repo = str(path_or_repo)
    if os.path.isdir(path_or_repo):
        # Local directory: the file either exists on disk or it doesn't — no download involved.
        resolved_file = os.path.join(path_or_repo, filename)
        return resolved_file if os.path.isfile(resolved_file) else None
    else:
        resolved_file = hf_bucket_url(path_or_repo, filename=filename, revision=revision, mirror=None)

    try:
        # Load from URL or cache if already cached
        resolved_file = cached_path(
            resolved_file,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            local_files_only=local_files_only,
            use_auth_token=use_auth_token,
        )

    except RepositoryNotFoundError as err:
        logger.error(err)
        raise EnvironmentError(
            f"{path_or_repo} is not a local folder and is not a valid model identifier "
            "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
            "pass a token having permission to this repo with `use_auth_token` or log in with "
            "`huggingface-cli login` and pass `use_auth_token=True`."
        )
    except RevisionNotFoundError as err:
        logger.error(err)
        raise EnvironmentError(
            f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
            "for this model name. Check the model page at "
            f"'https://huggingface.co/{path_or_repo}' for available revisions."
        )
    except EnvironmentError:
        # The repo and revision exist, but the file does not or there was a connection error fetching it.
        # Callers treat a missing optional file as "not present" rather than a hard failure.
        return None

    return resolved_file
2219+
2220+
21152221
def has_file(
21162222
path_or_repo: Union[str, os.PathLike],
21172223
filename: str,
@@ -2184,6 +2290,12 @@ def get_list_of_files(
21842290
local_files_only (`bool`, *optional*, defaults to `False`):
21852291
Whether or not to only rely on local files and not to attempt to download any files.
21862292
2293+
<Tip warning={true}>
2294+
2295+
This API is not optimized, so calling it a lot may result in connection errors.
2296+
2297+
</Tip>
2298+
21872299
Returns:
21882300
`List[str]`: The list of files available in `path_or_repo`.
21892301
"""

src/transformers/models/auto/processing_auto.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,14 @@
1414
# limitations under the License.
1515
""" AutoProcessor class."""
1616
import importlib
17+
import inspect
18+
import json
1719
from collections import OrderedDict
1820

1921
# Build the list of all feature extractors
2022
from ...configuration_utils import PretrainedConfig
2123
from ...feature_extraction_utils import FeatureExtractionMixin
22-
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_list_of_files
24+
from ...file_utils import CONFIG_NAME, FEATURE_EXTRACTOR_NAME, get_file_from_repo
2325
from ...tokenization_utils import TOKENIZER_CONFIG_FILE
2426
from .auto_factory import _LazyAutoMapping
2527
from .configuration_auto import (
@@ -29,7 +31,6 @@
2931
model_type_to_module_name,
3032
replace_list_option_in_docstrings,
3133
)
32-
from .tokenization_auto import get_tokenizer_config
3334

3435

3536
PROCESSOR_MAPPING_NAMES = OrderedDict(
@@ -145,24 +146,29 @@ def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
145146
kwargs["_from_auto"] = True
146147

147148
# First, let's see if we have a preprocessor config.
148-
# get_list_of_files only takes three of the kwargs we have, so we filter them.
149-
get_list_of_files_kwargs = {
150-
key: kwargs[key] for key in ["revision", "use_auth_token", "local_files_only"] if key in kwargs
149+
# Filter the kwargs for `get_file_from_repo``.
150+
get_file_from_repo_kwargs = {
151+
key: kwargs[key] for key in inspect.signature(get_file_from_repo).parameters.keys() if key in kwargs
151152
}
152-
model_files = get_list_of_files(pretrained_model_name_or_path, **get_list_of_files_kwargs)
153-
# strip to file name
154-
model_files = [f.split("/")[-1] for f in model_files]
155-
156153
# Let's start by checking whether the processor class is saved in a feature extractor
157-
if FEATURE_EXTRACTOR_NAME in model_files:
154+
preprocessor_config_file = get_file_from_repo(
155+
pretrained_model_name_or_path, FEATURE_EXTRACTOR_NAME, **get_file_from_repo_kwargs
156+
)
157+
if preprocessor_config_file is not None:
158158
config_dict, _ = FeatureExtractionMixin.get_feature_extractor_dict(pretrained_model_name_or_path, **kwargs)
159159
if "processor_class" in config_dict:
160160
processor_class = processor_class_from_name(config_dict["processor_class"])
161161
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
162162

163163
# Next, let's check whether the processor class is saved in a tokenizer
164-
if TOKENIZER_CONFIG_FILE in model_files:
165-
config_dict = get_tokenizer_config(pretrained_model_name_or_path, **kwargs)
164+
# Check whether the processor class is saved in the tokenizer config
165+
tokenizer_config_file = get_file_from_repo(
166+
pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE, **get_file_from_repo_kwargs
167+
)
168+
if tokenizer_config_file is not None:
169+
with open(tokenizer_config_file, encoding="utf-8") as reader:
170+
config_dict = json.load(reader)
171+
166172
if "processor_class" in config_dict:
167173
processor_class = processor_class_from_name(config_dict["processor_class"])
168174
return processor_class.from_pretrained(pretrained_model_name_or_path, **kwargs)

src/transformers/models/auto/tokenization_auto.py

Lines changed: 13 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,7 @@
2121
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Union
2222

2323
from ...configuration_utils import PretrainedConfig
24-
from ...file_utils import (
25-
RepositoryNotFoundError,
26-
RevisionNotFoundError,
27-
cached_path,
28-
hf_bucket_url,
29-
is_offline_mode,
30-
is_sentencepiece_available,
31-
is_tokenizers_available,
32-
)
24+
from ...file_utils import get_file_from_repo, is_sentencepiece_available, is_tokenizers_available
3325
from ...tokenization_utils import PreTrainedTokenizer
3426
from ...tokenization_utils_base import TOKENIZER_CONFIG_FILE
3527
from ...tokenization_utils_fast import PreTrainedTokenizerFast
@@ -329,46 +321,18 @@ def get_tokenizer_config(
329321
tokenizer.save_pretrained("tokenizer-test")
330322
tokenizer_config = get_tokenizer_config("tokenizer-test")
331323
```"""
332-
if is_offline_mode() and not local_files_only:
333-
logger.info("Offline mode: forcing local_files_only=True")
334-
local_files_only = True
335-
336-
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
337-
if os.path.isdir(pretrained_model_name_or_path):
338-
config_file = os.path.join(pretrained_model_name_or_path, TOKENIZER_CONFIG_FILE)
339-
else:
340-
config_file = hf_bucket_url(
341-
pretrained_model_name_or_path, filename=TOKENIZER_CONFIG_FILE, revision=revision, mirror=None
342-
)
343-
344-
try:
345-
# Load from URL or cache if already cached
346-
resolved_config_file = cached_path(
347-
config_file,
348-
cache_dir=cache_dir,
349-
force_download=force_download,
350-
proxies=proxies,
351-
resume_download=resume_download,
352-
local_files_only=local_files_only,
353-
use_auth_token=use_auth_token,
354-
)
355-
356-
except RepositoryNotFoundError as err:
357-
logger.error(err)
358-
raise EnvironmentError(
359-
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
360-
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to "
361-
"pass a token having permission to this repo with `use_auth_token` or log in with "
362-
"`huggingface-cli login` and pass `use_auth_token=True`."
363-
)
364-
except RevisionNotFoundError as err:
365-
logger.error(err)
366-
raise EnvironmentError(
367-
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists "
368-
"for this model name. Check the model page at "
369-
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
370-
)
371-
except EnvironmentError:
324+
resolved_config_file = get_file_from_repo(
325+
pretrained_model_name_or_path,
326+
TOKENIZER_CONFIG_FILE,
327+
cache_dir=cache_dir,
328+
force_download=force_download,
329+
resume_download=resume_download,
330+
proxies=proxies,
331+
use_auth_token=use_auth_token,
332+
revision=revision,
333+
local_files_only=local_files_only,
334+
)
335+
if resolved_config_file is None:
372336
logger.info("Could not locate the tokenizer configuration file, will try to use the model config instead.")
373337
return {}
374338

0 commit comments

Comments
 (0)