Skip to content

Commit 08016e8

Browse files
committed
Cleanup FeatureInfo getters, add TF models sourced Xception41/65/71 weights
1 parent 7ba5a38 commit 08016e8

File tree

4 files changed

+40
-34
lines changed

4 files changed

+40
-34
lines changed

timm/models/efficientnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -441,7 +441,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bo
441441
# Register feature extraction hooks with FeatureHooks helper
442442
self.feature_hooks = None
443443
if feature_location != 'bottleneck':
444-
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
444+
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
445445
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
446446

447447
def forward(self, x) -> List[torch.Tensor]:

timm/models/features.py

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections import OrderedDict, defaultdict
99
from copy import deepcopy
1010
from functools import partial
11-
from typing import Dict, List, Tuple, Any
11+
from typing import Dict, List, Tuple
1212

1313
import torch
1414
import torch.nn as nn
@@ -30,42 +30,46 @@ def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
3030
def from_other(self, out_indices: Tuple[int]):
3131
return FeatureInfo(deepcopy(self.info), out_indices)
3232

33+
def get(self, key, idx=None):
34+
""" Get value by key at specified index (indices)
35+
if idx is None, returns value for key at each output index
36+
if idx is an integer, return value for that feature module index (ignoring output indices)
37+
if idx is a list/tuple, return value for each module index (ignoring output indices)
38+
"""
39+
if idx is None:
40+
return [self.info[i][key] for i in self.out_indices]
41+
if isinstance(idx, (tuple, list)):
42+
return [self.info[i][key] for i in idx]
43+
else:
44+
return self.info[idx][key]
45+
46+
def get_dicts(self, keys=None, idx=None):
47+
""" return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
48+
"""
49+
if idx is None:
50+
if keys is None:
51+
return [self.info[i] for i in self.out_indices]
52+
else:
53+
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
54+
if isinstance(idx, (tuple, list)):
55+
return [self.info[i] if keys is None else {k: self.info[i][k] for k in keys} for i in idx]
56+
else:
57+
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
58+
3359
def channels(self, idx=None):
3460
""" feature channels accessor
35-
if idx == None, returns feature channel count at each output index
36-
if idx is an integer, return feature channel count for that feature module index
3761
"""
38-
if isinstance(idx, int):
39-
return self.info[idx]['num_chs']
40-
return [self.info[i]['num_chs'] for i in self.out_indices]
62+
return self.get('num_chs', idx)
4163

4264
def reduction(self, idx=None):
4365
""" feature reduction (output stride) accessor
44-
if idx == None, returns feature reduction factor at each output index
45-
if idx is an integer, return feature channel count at that feature module index
4666
"""
47-
if isinstance(idx, int):
48-
return self.info[idx]['reduction']
49-
return [self.info[i]['reduction'] for i in self.out_indices]
67+
return self.get('reduction', idx)
5068

5169
def module_name(self, idx=None):
5270
""" feature module name accessor
53-
if idx == None, returns feature module name at each output index
54-
if idx is an integer, return feature module name at that feature module index
55-
"""
56-
if isinstance(idx, int):
57-
return self.info[idx]['module']
58-
return [self.info[i]['module'] for i in self.out_indices]
59-
60-
def get_by_key(self, idx=None, keys=None):
61-
""" return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
6271
"""
63-
if isinstance(idx, int):
64-
return self.info[idx] if keys is None else {k: self.info[idx][k] for k in keys}
65-
if keys is None:
66-
return [self.info[i] for i in self.out_indices]
67-
else:
68-
return [{k: self.info[i][k] for k in keys} for i in self.out_indices]
72+
return self.get('module', idx)
6973

7074
def __getitem__(self, item):
7175
return self.info[item]
@@ -253,11 +257,11 @@ def __init__(
253257
if hasattr(model, 'reset_classifier'): # make sure classifier is removed?
254258
model.reset_classifier(0)
255259
layers['body'] = model
256-
hooks.extend(self.feature_info)
260+
hooks.extend(self.feature_info.get_dicts())
257261
else:
258262
modules = _module_list(model, flatten_sequential=flatten_sequential)
259263
remaining = {f['module']: f['hook_type'] if 'hook_type' in f else default_hook_type
260-
for f in self.feature_info}
264+
for f in self.feature_info.get_dicts()}
261265
for new_name, old_name, module in modules:
262266
layers[new_name] = module
263267
for fn, fm in module.named_modules(prefix=old_name):

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def __init__(self, block_args, out_indices=(0, 1, 2, 3, 4), feature_location='bo
186186
# Register feature extraction hooks with FeatureHooks helper
187187
self.feature_hooks = None
188188
if feature_location != 'bottleneck':
189-
hooks = self.feature_info.get_by_key(keys=('module', 'hook_type'))
189+
hooks = self.feature_info.get_dicts(keys=('module', 'hook_type'))
190190
self.feature_hooks = FeatureHooks(hooks, self.named_modules())
191191

192192
def forward(self, x) -> List[torch.Tensor]:

timm/models/xception_aligned.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,12 @@ def _cfg(url='', **kwargs):
3131

3232

3333
default_cfgs = dict(
34-
xception41=_cfg(url=''),
35-
xception65=_cfg(url=''),
36-
xception71=_cfg(url=''),
34+
xception41=_cfg(
35+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_41-e6439c97.pth'),
36+
xception65=_cfg(
37+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_65-c9ae96e8.pth'),
38+
xception71=_cfg(
39+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_xception_71-8eec7df1.pth'),
3740
)
3841

3942

@@ -216,7 +219,6 @@ def xception65(pretrained=False, **kwargs):
216219
return _xception('xception65', pretrained=pretrained, **model_args)
217220

218221

219-
220222
@register_model
221223
def xception71(pretrained=False, **kwargs):
222224
""" Modified Aligned Xception-71

0 commit comments

Comments
 (0)