@@ -167,6 +167,15 @@ def _natural_key(string_: str) -> List[Union[int, str]]:
167
167
return [int (s ) if s .isdigit () else s for s in re .split (r'(\d+)' , string_ .lower ())]
168
168
169
169
170
+ def _expand_filter (filter : str ):
171
+ """ expand a 'base_filter' to 'base_filter.*' if no tag portion"""
172
+ filter_base , filter_tag = split_model_name_tag (filter )
173
+ if not filter_tag :
174
+ return ['.' .join ([filter_base , '*' ]), filter ]
175
+ else :
176
+ return [filter ]
177
+
178
+
170
179
def list_models (
171
180
filter : Union [str , List [str ]] = '' ,
172
181
module : str = '' ,
@@ -193,6 +202,11 @@ def list_models(
193
202
model_list('gluon_resnet*') -- returns all models starting with 'gluon_resnet'
194
203
model_list('*resnext*, 'resnet') -- returns all models with 'resnext' in 'resnet' module
195
204
"""
205
+ if filter :
206
+ include_filters = filter if isinstance (filter , (tuple , list )) else [filter ]
207
+ else :
208
+ include_filters = []
209
+
196
210
if include_tags is None :
197
211
# FIXME should this be default behaviour? or default to include_tags=True?
198
212
include_tags = pretrained
@@ -206,10 +220,12 @@ def list_models(
206
220
for m in all_models :
207
221
models_with_tags .update (_model_with_tags [m ])
208
222
all_models = models_with_tags
223
+ # expand include and exclude filters to include a '.*' for proper match if no tags in filter
224
+ include_filters = [ef for f in include_filters for ef in _expand_filter (f )]
225
+ exclude_filters = [ef for f in exclude_filters for ef in _expand_filter (f )]
209
226
210
- if filter :
227
+ if include_filters :
211
228
models : Set [str ] = set ()
212
- include_filters = filter if isinstance (filter , (tuple , list )) else [filter ]
213
229
for f in include_filters :
214
230
include_models = fnmatch .filter (all_models , f ) # include these models
215
231
if len (include_models ):
0 commit comments