1
1
""" Activation Factory
2
2
Hacked together by / Copyright 2020 Ross Wightman
3
3
"""
4
+ from typing import Union , Callable , Type
5
+
4
6
from .activations import *
5
7
from .activations_jit import *
6
8
from .activations_me import *
7
9
from .config import is_exportable , is_scriptable , is_no_jit
8
10
9
- # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
10
- # will use native version if present. Eventually, the custom Swish layers will be removed
11
- # and only native 'silu' will be used.
11
+ # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
12
+ # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
13
+ # Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
12
14
# Feature-detect native activation ops in torch.nn.functional so the tables below
# can prefer them over the hand-rolled fallbacks. Availability depends on the
# installed PyTorch version (native 'silu' landed in 1.7; hardsigmoid/hardswish
# earlier, mish later — see module comment above).
# hasattr() is the idiomatic (and cheaper) form of `'name' in dir(module)`.
_has_silu = hasattr(torch.nn.functional, 'silu')
_has_hardswish = hasattr(torch.nn.functional, 'hardswish')
_has_hardsigmoid = hasattr(torch.nn.functional, 'hardsigmoid')
_has_mish = hasattr(torch.nn.functional, 'mish')
13
19
14
20
_ACT_FN_DEFAULT = dict (
15
21
silu = F .silu if _has_silu else swish ,
16
22
swish = F .silu if _has_silu else swish ,
17
- mish = mish ,
23
+ mish = F . mish if _has_mish else mish ,
18
24
relu = F .relu ,
19
25
relu6 = F .relu6 ,
20
26
leaky_relu = F .leaky_relu ,
24
30
gelu = gelu ,
25
31
sigmoid = sigmoid ,
26
32
tanh = tanh ,
27
- hard_sigmoid = hard_sigmoid ,
28
- hard_swish = hard_swish ,
33
+ hard_sigmoid = F . hardsigmoid if _has_hardsigmoid else hard_sigmoid ,
34
+ hard_swish = F . hardswish if _has_hardswish else hard_swish ,
29
35
hard_mish = hard_mish ,
30
36
)
31
37
32
38
# TorchScript-friendly activation fns, keyed by canonical name.
# Native torch.nn.functional ops are preferred whenever this PyTorch build has them;
# otherwise the jit-scripted fallbacks from activations_jit are used.
_ACT_FN_JIT = {
    'silu': F.silu if _has_silu else swish_jit,
    'swish': F.silu if _has_silu else swish_jit,
    'mish': F.mish if _has_mish else mish_jit,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish_jit,
    'hard_mish': hard_mish_jit,
}
40
46
41
47
# Memory-efficient activation fns (custom autograd from activations_me), keyed by
# canonical name. Native torch.nn.functional ops take precedence when available.
_ACT_FN_ME = {
    'silu': F.silu if _has_silu else swish_me,
    'swish': F.silu if _has_silu else swish_me,
    'mish': F.mish if _has_mish else mish_me,
    'hard_sigmoid': F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
    'hard_swish': F.hardswish if _has_hardswish else hard_swish_me,
    'hard_mish': hard_mish_me,
}
49
55
56
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
for a in _ACT_FNS:
    # Register squashed aliases ('hardsigmoid'/'hardswish') pointing at the same
    # callables as their underscored names, without overwriting existing entries.
    for _alias, _name in (('hardsigmoid', 'hard_sigmoid'), ('hardswish', 'hard_swish')):
        a.setdefault(_alias, a.get(_name))
60
+
61
+
50
62
_ACT_LAYER_DEFAULT = dict (
51
63
silu = nn .SiLU if _has_silu else Swish ,
52
64
swish = nn .SiLU if _has_silu else Swish ,
53
- mish = Mish ,
65
+ mish = nn . Mish if _has_mish else Mish ,
54
66
relu = nn .ReLU ,
55
67
relu6 = nn .ReLU6 ,
56
68
leaky_relu = nn .LeakyReLU ,
61
73
gelu = GELU ,
62
74
sigmoid = Sigmoid ,
63
75
tanh = Tanh ,
64
- hard_sigmoid = HardSigmoid ,
65
- hard_swish = HardSwish ,
76
+ hard_sigmoid = nn . Hardsigmoid if _has_hardsigmoid else HardSigmoid ,
77
+ hard_swish = nn . Hardswish if _has_hardswish else HardSwish ,
66
78
hard_mish = HardMish ,
67
79
)
68
80
69
81
# TorchScript-friendly activation layer (nn.Module) classes, keyed by canonical
# name. Native torch.nn modules win over the jit-scripted fallback classes when
# this PyTorch build provides them.
_ACT_LAYER_JIT = {
    'silu': nn.SiLU if _has_silu else SwishJit,
    'swish': nn.SiLU if _has_silu else SwishJit,
    'mish': nn.Mish if _has_mish else MishJit,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwishJit,
    'hard_mish': HardMishJit,
}
77
89
78
90
# Memory-efficient activation layer (nn.Module) classes, keyed by canonical name.
# Native torch.nn modules are preferred; otherwise the custom-autograd *Me classes
# from activations_me are used.
_ACT_LAYER_ME = {
    'silu': nn.SiLU if _has_silu else SwishMe,
    'swish': nn.SiLU if _has_silu else SwishMe,
    'mish': nn.Mish if _has_mish else MishMe,
    'hard_sigmoid': nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
    'hard_swish': nn.Hardswish if _has_hardswish else HardSwishMe,
    'hard_mish': HardMishMe,
}
86
98
99
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
for a in _ACT_LAYERS:
    # Register squashed aliases ('hardsigmoid'/'hardswish') for the layer tables,
    # mirroring the alias handling done for the fn tables above.
    for _alias, _name in (('hardsigmoid', 'hard_sigmoid'), ('hardswish', 'hard_swish')):
        a.setdefault(_alias, a.get(_name))
103
+
87
104
88
- def get_act_fn (name = 'relu' ):
105
+ def get_act_fn (name : Union [ Callable , str ] = 'relu' ):
89
106
""" Activation Function Factory
90
107
Fetching activation fns by name with this function allows export or torch script friendly
91
108
functions to be returned dynamically based on current config.
92
109
"""
93
110
if not name :
94
111
return None
112
+ if isinstance (name , Callable ):
113
+ return name
95
114
if not (is_no_jit () or is_exportable () or is_scriptable ()):
96
115
# If not exporting or scripting the model, first look for a memory-efficient version with
97
116
# custom autograd, then fallback
@@ -106,13 +125,15 @@ def get_act_fn(name='relu'):
106
125
return _ACT_FN_DEFAULT [name ]
107
126
108
127
109
- def get_act_layer (name = 'relu' ):
128
+ def get_act_layer (name : Union [ Type [ nn . Module ], str ] = 'relu' ):
110
129
""" Activation Layer Factory
111
130
Fetching activation layers by name with this function allows export or torch script friendly
112
131
functions to be returned dynamically based on current config.
113
132
"""
114
133
if not name :
115
134
return None
135
+ if isinstance (name , type ):
136
+ return name
116
137
if not (is_no_jit () or is_exportable () or is_scriptable ()):
117
138
if name in _ACT_LAYER_ME :
118
139
return _ACT_LAYER_ME [name ]
@@ -125,9 +146,8 @@ def get_act_layer(name='relu'):
125
146
return _ACT_LAYER_DEFAULT [name ]
126
147
127
148
128
def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
    """ Create an activation layer instance by name.

    Resolves `name` through get_act_layer() (so config-aware ME/JIT/native
    selection applies) and instantiates the resulting class.

    Args:
        name: activation name string, or an nn.Module class to pass through.
            Note: this receives a class, not an instance, hence Type[nn.Module].
        inplace: if not None, forwarded as the `inplace` ctor kwarg; None omits
            the kwarg entirely so layers without an `inplace` arg still work.
        **kwargs: additional args forwarded to the layer constructor.

    Returns:
        Instantiated activation module, or None when `name` is falsy.
    """
    act_layer = get_act_layer(name)
    if act_layer is None:
        return None
    return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)
0 commit comments