Skip to content

Commit b3e816d

Browse files
committed
Improve filtering behaviour for tag + non-tagged model wildcard consistency.
1 parent 7aba64e commit b3e816d

File tree

3 files changed

+30
-5
lines changed

3 files changed

+30
-5
lines changed

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@
5252
'*efficientnet_l2*', '*resnext101_32x48d', '*in21k', '*152x4_bitm', '*101x3_bitm', '*50x3_bitm',
5353
'*nfnet_f3*', '*nfnet_f4*', '*nfnet_f5*', '*nfnet_f6*', '*nfnet_f7*', '*efficientnetv2_xl*',
5454
'*resnetrs350*', '*resnetrs420*', 'xcit_large_24_p8*', 'vit_huge*', 'vit_gi*', 'swin*huge*',
55-
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge']
55+
'swin*giant*', 'convnextv2_huge*', 'maxvit_xlarge*', 'davit_giant', 'davit_huge', 'regnet*1280', 'regnet*2560']
5656
NON_STD_EXCLUDE_FILTERS = ['vit_huge*', 'vit_gi*', 'eva_giant*']
5757
else:
5858
EXCLUDE_FILTERS = []

timm/models/_registry.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,15 @@ def _natural_key(string_: str) -> List[Union[int, str]]:
167167
return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
168168

169169

170+
def _expand_filter(filter: str):
171+
""" expand a 'base_filter' to 'base_filter.*' if no tag portion"""
172+
filter_base, filter_tag = split_model_name_tag(filter)
173+
if not filter_tag:
174+
return ['.'.join([filter_base, '*']), filter]
175+
else:
176+
return [filter]
177+
178+
170179
def list_models(
171180
filter: Union[str, List[str]] = '',
172181
module: str = '',
@@ -193,6 +202,11 @@ def list_models(
193202
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
194203
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
195204
"""
205+
if filter:
206+
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
207+
else:
208+
include_filters = []
209+
196210
if include_tags is None:
197211
# FIXME should this be default behaviour? or default to include_tags=True?
198212
include_tags = pretrained
@@ -206,10 +220,12 @@ def list_models(
206220
for m in all_models:
207221
models_with_tags.update(_model_with_tags[m])
208222
all_models = models_with_tags
223+
# expand include and exclude filters to include a '.*' for proper match if no tags in filter
224+
include_filters = [ef for f in include_filters for ef in _expand_filter(f)]
225+
exclude_filters = [ef for f in exclude_filters for ef in _expand_filter(f)]
209226

210-
if filter:
227+
if include_filters:
211228
models: Set[str] = set()
212-
include_filters = filter if isinstance(filter, (tuple, list)) else [filter]
213229
for f in include_filters:
214230
include_models = fnmatch.filter(all_models, f) # include these models
215231
if len(include_models):

validate.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -385,6 +385,9 @@ def _try_run(args, initial_batch_size):
385385
return results
386386

387387

388+
_NON_IN1K_FILTERS = ['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae', '*seer']
389+
390+
388391
def main():
389392
setup_default_logging()
390393
args = parser.parse_args()
@@ -400,11 +403,17 @@ def main():
400403
if args.model == 'all':
401404
# validate all models in a list of names with pretrained checkpoints
402405
args.pretrained = True
403-
model_names = list_models('convnext*', pretrained=True, exclude_filters=['*_in21k', '*_in22k', '*in12k', '*_dino', '*fcmae'])
406+
model_names = list_models(
407+
pretrained=True,
408+
exclude_filters=_NON_IN1K_FILTERS,
409+
)
404410
model_cfgs = [(n, '') for n in model_names]
405411
elif not is_model(args.model):
406412
# model name doesn't exist, try as wildcard filter
407-
model_names = list_models(args.model, pretrained=True)
413+
model_names = list_models(
414+
args.model,
415+
pretrained=True,
416+
)
408417
model_cfgs = [(n, '') for n in model_names]
409418

410419
if not model_cfgs and os.path.isfile(args.model):

0 commit comments

Comments
 (0)