8
8
from collections import OrderedDict , defaultdict
9
9
from copy import deepcopy
10
10
from functools import partial
11
- from typing import Dict , List , Tuple , Any
11
+ from typing import Dict , List , Tuple
12
12
13
13
import torch
14
14
import torch .nn as nn
@@ -30,42 +30,46 @@ def __init__(self, feature_info: List[Dict], out_indices: Tuple[int]):
30
30
def from_other (self , out_indices : Tuple [int ]):
31
31
return FeatureInfo (deepcopy (self .info ), out_indices )
32
32
33
+ def get (self , key , idx = None ):
34
+ """ Get value by key at specified index (indices)
35
+ if idx == None, returns value for key at each output index
36
+ if idx is an integer, return value for that feature module index (ignoring output indices)
37
+ if idx is a list/tupple, return value for each module index (ignoring output indices)
38
+ """
39
+ if idx is None :
40
+ return [self .info [i ][key ] for i in self .out_indices ]
41
+ if isinstance (idx , (tuple , list )):
42
+ return [self .info [i ][key ] for i in idx ]
43
+ else :
44
+ return self .info [idx ][key ]
45
+
46
+ def get_dicts (self , keys = None , idx = None ):
47
+ """ return info dicts for specified keys (or all if None) at specified indices (or out_indices if None)
48
+ """
49
+ if idx is None :
50
+ if keys is None :
51
+ return [self .info [i ] for i in self .out_indices ]
52
+ else :
53
+ return [{k : self .info [i ][k ] for k in keys } for i in self .out_indices ]
54
+ if isinstance (idx , (tuple , list )):
55
+ return [self .info [i ] if keys is None else {k : self .info [i ][k ] for k in keys } for i in idx ]
56
+ else :
57
+ return self .info [idx ] if keys is None else {k : self .info [idx ][k ] for k in keys }
58
+
33
59
def channels (self , idx = None ):
34
60
""" feature channels accessor
35
- if idx == None, returns feature channel count at each output index
36
- if idx is an integer, return feature channel count for that feature module index
37
61
"""
38
- if isinstance (idx , int ):
39
- return self .info [idx ]['num_chs' ]
40
- return [self .info [i ]['num_chs' ] for i in self .out_indices ]
62
+ return self .get ('num_chs' , idx )
41
63
42
64
def reduction (self , idx = None ):
43
65
""" feature reduction (output stride) accessor
44
- if idx == None, returns feature reduction factor at each output index
45
- if idx is an integer, return feature channel count at that feature module index
46
66
"""
47
- if isinstance (idx , int ):
48
- return self .info [idx ]['reduction' ]
49
- return [self .info [i ]['reduction' ] for i in self .out_indices ]
67
+ return self .get ('reduction' , idx )
50
68
51
69
def module_name (self , idx = None ):
52
70
""" feature module name accessor
53
- if idx == None, returns feature module name at each output index
54
- if idx is an integer, return feature module name at that feature module index
55
- """
56
- if isinstance (idx , int ):
57
- return self .info [idx ]['module' ]
58
- return [self .info [i ]['module' ] for i in self .out_indices ]
59
-
60
- def get_by_key (self , idx = None , keys = None ):
61
- """ return info dicts for specified keys (or all if None) at specified idx (or out_indices if None)
62
71
"""
63
- if isinstance (idx , int ):
64
- return self .info [idx ] if keys is None else {k : self .info [idx ][k ] for k in keys }
65
- if keys is None :
66
- return [self .info [i ] for i in self .out_indices ]
67
- else :
68
- return [{k : self .info [i ][k ] for k in keys } for i in self .out_indices ]
72
+ return self .get ('module' , idx )
69
73
70
74
def __getitem__ (self , item ):
71
75
return self .info [item ]
@@ -253,11 +257,11 @@ def __init__(
253
257
if hasattr (model , 'reset_classifier' ): # make sure classifier is removed?
254
258
model .reset_classifier (0 )
255
259
layers ['body' ] = model
256
- hooks .extend (self .feature_info )
260
+ hooks .extend (self .feature_info . get_dicts () )
257
261
else :
258
262
modules = _module_list (model , flatten_sequential = flatten_sequential )
259
263
remaining = {f ['module' ]: f ['hook_type' ] if 'hook_type' in f else default_hook_type
260
- for f in self .feature_info }
264
+ for f in self .feature_info . get_dicts () }
261
265
for new_name , old_name , module in modules :
262
266
layers [new_name ] = module
263
267
for fn , fm in module .named_modules (prefix = old_name ):
0 commit comments