Skip to content

Commit 0d5c5c3

Browse files
authored
Merge pull request huggingface#1628 from huggingface/focalnet_and_swin_refactor
Add FocalNet arch, refactor Swin V1/V2 for better feature extraction and HF hub multi-weight support
2 parents cd3ee78 + fafac33 commit 0d5c5c3

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

45 files changed

+57847
-1147
lines changed

tests/test_models.py

Lines changed: 166 additions & 152 deletions
Large diffs are not rendered by default.

timm/data/_info/imagenet22k_ms_synsets.txt

Lines changed: 21841 additions & 0 deletions
Large diffs are not rendered by default.

timm/data/_info/imagenet22k_ms_to_12k_indices.txt

Lines changed: 11821 additions & 0 deletions
Large diffs are not rendered by default.

timm/data/_info/imagenet22k_ms_to_22k_indices.txt

Lines changed: 21841 additions & 0 deletions
Large diffs are not rendered by default.

timm/data/dataset.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
root,
8989
reader=None,
9090
split='train',
91+
class_map=None,
9192
is_training=False,
9293
batch_size=None,
9394
seed=42,
@@ -102,6 +103,7 @@ def __init__(
102103
reader,
103104
root=root,
104105
split=split,
106+
class_map=class_map,
105107
is_training=is_training,
106108
batch_size=batch_size,
107109
seed=seed,

timm/data/dataset_factory.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ def create_dataset(
157157
root,
158158
reader=name,
159159
split=split,
160+
class_map=class_map,
160161
is_training=is_training,
161162
download=download,
162163
batch_size=batch_size,
@@ -169,6 +170,7 @@ def create_dataset(
169170
root,
170171
reader=name,
171172
split=split,
173+
class_map=class_map,
172174
is_training=is_training,
173175
batch_size=batch_size,
174176
repeats=repeats,

timm/data/imagenet_info.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,14 @@
77
from .dataset_info import DatasetInfo
88

99

10+
# NOTE no ambiguity wrt mapping from # classes to ImageNet subset so far, but likely to change
1011
_NUM_CLASSES_TO_SUBSET = {
1112
1000: 'imagenet-1k',
12-
11821: 'imagenet-12k',
13-
21841: 'imagenet-22k',
14-
21843: 'imagenet-21k-goog',
15-
11221: 'imagenet-21k-miil',
13+
11221: 'imagenet-21k-miil', # miil subset of fall11
14+
11821: 'imagenet-12k', # timm specific 12k subset of fall11
15+
21841: 'imagenet-22k', # as in fall11.tar
16+
21842: 'imagenet-22k-ms', # a Microsoft (for FocalNet) remapping of 22k that moves ImageNet-1k classes to first 1000
17+
21843: 'imagenet-21k-goog', # Google's ImageNet full has two classes not in fall11
1618
}
1719

1820
_SUBSETS = {
@@ -22,6 +24,7 @@
2224
'imagenet21k': 'imagenet21k_goog_synsets.txt',
2325
'imagenet21kgoog': 'imagenet21k_goog_synsets.txt',
2426
'imagenet21kmiil': 'imagenet21k_miil_synsets.txt',
27+
'imagenet22kms': 'imagenet22k_ms_synsets.txt',
2528
}
2629
_LEMMA_FILE = 'imagenet_synset_to_lemma.txt'
2730
_DEFINITION_FILE = 'imagenet_synset_to_definition.txt'

timm/data/readers/reader_tfds.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
print("Please install tensorflow_datasets package `pip install tensorflow-datasets`.")
3535
exit(1)
3636

37+
from .class_map import load_class_map
3738
from .reader import Reader
3839
from .shared_count import SharedCount
3940

@@ -94,6 +95,7 @@ def __init__(
9495
root,
9596
name,
9697
split='train',
98+
class_map=None,
9799
is_training=False,
98100
batch_size=None,
99101
download=False,
@@ -151,7 +153,12 @@ def __init__(
151153
# NOTE: the tfds command line app can be used to download & prepare datasets if you don't enable download flag
152154
if download:
153155
self.builder.download_and_prepare()
154-
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
156+
self.remap_class = False
157+
if class_map:
158+
self.class_to_idx = load_class_map(class_map)
159+
self.remap_class = True
160+
else:
161+
self.class_to_idx = get_class_labels(self.builder.info) if self.target_name == 'label' else {}
155162
self.split_info = self.builder.info.splits[split]
156163
self.num_samples = self.split_info.num_examples
157164

@@ -299,6 +306,8 @@ def __iter__(self):
299306
target_data = sample[self.target_name]
300307
if self.target_img_mode:
301308
target_data = Image.fromarray(target_data, mode=self.target_img_mode)
309+
elif self.remap_class:
310+
target_data = self.class_to_idx[target_data]
302311
yield input_data, target_data
303312
sample_count += 1
304313
if self.is_training and sample_count >= target_sample_count:

timm/data/readers/reader_wds.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
wds = None
3030
expand_urls = None
3131

32+
from .class_map import load_class_map
3233
from .reader import Reader
3334
from .shared_count import SharedCount
3435

@@ -42,13 +43,13 @@ def _load_info(root, basename='info'):
4243
info_yaml = os.path.join(root, basename + '.yaml')
4344
err_str = ''
4445
try:
45-
with wds.gopen.gopen(info_json) as f:
46+
with wds.gopen(info_json) as f:
4647
info_dict = json.load(f)
4748
return info_dict
4849
except Exception as e:
4950
err_str = str(e)
5051
try:
51-
with wds.gopen.gopen(info_yaml) as f:
52+
with wds.gopen(info_yaml) as f:
5253
info_dict = yaml.safe_load(f)
5354
return info_dict
5455
except Exception:
@@ -110,8 +111,8 @@ def _info_convert(dict_info):
110111
filenames=split_filenames,
111112
)
112113
else:
113-
if split not in info['splits']:
114-
raise RuntimeError(f"split {split} not found in info ({info['splits'].keys()})")
114+
if 'splits' not in info or split not in info['splits']:
115+
raise RuntimeError(f"split {split} not found in info ({info.get('splits', {}).keys()})")
115116
split = split
116117
split_info = info['splits'][split]
117118
split_info = _info_convert(split_info)
@@ -290,6 +291,7 @@ def __init__(
290291
batch_size=None,
291292
repeats=0,
292293
seed=42,
294+
class_map=None,
293295
input_name='jpg',
294296
input_image='RGB',
295297
target_name='cls',
@@ -320,6 +322,12 @@ def __init__(
320322
self.num_samples = self.split_info.num_samples
321323
if not self.num_samples:
322324
raise RuntimeError(f'Invalid split definition, no samples found.')
325+
self.remap_class = False
326+
if class_map:
327+
self.class_to_idx = load_class_map(class_map)
328+
self.remap_class = True
329+
else:
330+
self.class_to_idx = {}
323331

324332
# Distributed world state
325333
self.dist_rank = 0
@@ -431,7 +439,10 @@ def __iter__(self):
431439
i = 0
432440
# _logger.info(f'start {i}, {self.worker_id}') # FIXME temporary debug
433441
for sample in ds:
434-
yield sample[self.image_key], sample[self.target_key]
442+
target = sample[self.target_key]
443+
if self.remap_class:
444+
target = self.class_to_idx[target]
445+
yield sample[self.image_key], target
435446
i += 1
436447
# _logger.info(f'end {i}, {self.worker_id}') # FIXME temporary debug
437448

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
EvoNorm2dS0, EvoNorm2dS0a, EvoNorm2dS1, EvoNorm2dS1a, EvoNorm2dS2, EvoNorm2dS2a
2121
from .fast_norm import is_fast_norm, set_fast_norm, fast_group_norm, fast_layer_norm
2222
from .filter_response_norm import FilterResponseNormTlu2d, FilterResponseNormAct2d
23+
from .format import Format, get_channel_dim, get_spatial_dim, nchw_to, nhwc_to
2324
from .gather_excite import GatherExcite
2425
from .global_context import GlobalContext
2526
from .helpers import to_ntuple, to_2tuple, to_3tuple, to_4tuple, make_divisible, extend_tuple

0 commit comments

Comments
 (0)