Skip to content

Commit 9c78de8

Browse files
committed
Fix huggingface#661, move hardswish out of default args for LeViT. Enable native torch support for hardswish, hardsigmoid, mish if present.
1 parent 07d952c commit 9c78de8

File tree

8 files changed

+66
-47
lines changed

8 files changed

+66
-47
lines changed

tests/test_layers.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@
88

99

1010
class MLP(nn.Module):
11-
def __init__(self, act_layer="relu"):
11+
def __init__(self, act_layer="relu", inplace=True):
1212
super(MLP, self).__init__()
1313
self.fc1 = nn.Linear(1000, 100)
14-
self.act = create_act_layer(act_layer, inplace=True)
14+
self.act = create_act_layer(act_layer, inplace=inplace)
1515
self.fc2 = nn.Linear(100, 10)
1616

1717
def forward(self, x):
@@ -21,14 +21,14 @@ def forward(self, x):
2121
return x
2222

2323

24-
def _run_act_layer_grad(act_type):
24+
def _run_act_layer_grad(act_type, inplace=True):
2525
x = torch.rand(10, 1000) * 10
26-
m = MLP(act_layer=act_type)
26+
m = MLP(act_layer=act_type, inplace=inplace)
2727

2828
def _run(x, act_layer=''):
2929
if act_layer:
3030
# replace act layer if set
31-
m.act = create_act_layer(act_layer, inplace=True)
31+
m.act = create_act_layer(act_layer, inplace=inplace)
3232
out = m(x)
3333
l = (out - 0).pow(2).sum()
3434
return l
@@ -58,7 +58,7 @@ def test_mish_grad():
5858

5959
def test_hard_sigmoid_grad():
6060
for _ in range(100):
61-
_run_act_layer_grad('hard_sigmoid')
61+
_run_act_layer_grad('hard_sigmoid', inplace=None)
6262

6363

6464
def test_hard_swish_grad():

tests/test_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test_model_backward(model_name, batch_size):
110110
assert not torch.isnan(outputs).any(), 'Output included NaNs'
111111

112112

113-
@pytest.mark.timeout(120)
113+
@pytest.mark.timeout(300)
114114
@pytest.mark.parametrize('model_name', list_models(exclude_filters=NON_STD_FILTERS))
115115
@pytest.mark.parametrize('batch_size', [1])
116116
def test_model_default_cfgs(model_name, batch_size):

timm/models/efficientnet_blocks.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.nn as nn
88
from torch.nn import functional as F
99

10-
from .layers import create_conv2d, drop_path, make_divisible
10+
from .layers import create_conv2d, drop_path, make_divisible, get_act_fn, create_act_layer
1111
from .layers.activations import sigmoid
1212

1313
__all__ = [
@@ -36,9 +36,9 @@ def __init__(
3636
reduced_chs = make_divisible(reduced_chs * se_ratio, divisor)
3737
act_layer = force_act_layer or act_layer
3838
self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
39-
self.act1 = act_layer(inplace=True)
39+
self.act1 = create_act_layer(act_layer, inplace=True)
4040
self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
41-
self.gate_fn = gate_fn
41+
self.gate_fn = get_act_fn(gate_fn)
4242

4343
def forward(self, x):
4444
x_se = x.mean((2, 3), keepdim=True)

timm/models/efficientnet_builder.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,7 @@ def resolve_bn_args(kwargs):
5050

5151

5252
def resolve_act_layer(kwargs, default='relu'):
53-
act_layer = kwargs.pop('act_layer', default)
54-
if isinstance(act_layer, str):
55-
act_layer = get_act_layer(act_layer)
56-
return act_layer
53+
return get_act_layer(kwargs.pop('act_layer', default))
5754

5855

5956
def round_channels(channels, multiplier=1.0, divisor=8, channel_min=None, round_limit=0.9):

timm/models/ghostnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414

1515
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16-
from .layers import SelectAdaptivePool2d, Linear, hard_sigmoid, make_divisible
16+
from .layers import SelectAdaptivePool2d, Linear, make_divisible
1717
from .efficientnet_blocks import SqueezeExcite, ConvBnAct
1818
from .helpers import build_model_with_cfg
1919
from .registry import register_model
@@ -40,7 +40,7 @@ def _cfg(url='', **kwargs):
4040
}
4141

4242

43-
_SE_LAYER = partial(SqueezeExcite, gate_fn=hard_sigmoid, divisor=4)
43+
_SE_LAYER = partial(SqueezeExcite, gate_fn='hard_sigmoid', divisor=4)
4444

4545

4646
class GhostModule(nn.Module):

timm/models/layers/create_act.py

Lines changed: 47 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,26 @@
11
""" Activation Factory
22
Hacked together by / Copyright 2020 Ross Wightman
33
"""
4+
from typing import Union, Callable, Type
5+
46
from .activations import *
57
from .activations_jit import *
68
from .activations_me import *
79
from .config import is_exportable, is_scriptable, is_no_jit
810

9-
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7. This code
10-
# will use native version if present. Eventually, the custom Swish layers will be removed
11-
# and only native 'silu' will be used.
11+
# PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
12+
# Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
13+
# Eventually, the custom SiLU, Mish, Hard*, layers will be removed and only native variants will be used.
1214
_has_silu = 'silu' in dir(torch.nn.functional)
15+
_has_hardswish = 'hardswish' in dir(torch.nn.functional)
16+
_has_hardsigmoid = 'hardsigmoid' in dir(torch.nn.functional)
17+
_has_mish = 'mish' in dir(torch.nn.functional)
18+
1319

1420
_ACT_FN_DEFAULT = dict(
1521
silu=F.silu if _has_silu else swish,
1622
swish=F.silu if _has_silu else swish,
17-
mish=mish,
23+
mish=F.mish if _has_mish else mish,
1824
relu=F.relu,
1925
relu6=F.relu6,
2026
leaky_relu=F.leaky_relu,
@@ -24,33 +30,39 @@
2430
gelu=gelu,
2531
sigmoid=sigmoid,
2632
tanh=tanh,
27-
hard_sigmoid=hard_sigmoid,
28-
hard_swish=hard_swish,
33+
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid,
34+
hard_swish=F.hardswish if _has_hardswish else hard_swish,
2935
hard_mish=hard_mish,
3036
)
3137

3238
_ACT_FN_JIT = dict(
3339
silu=F.silu if _has_silu else swish_jit,
3440
swish=F.silu if _has_silu else swish_jit,
35-
mish=mish_jit,
36-
hard_sigmoid=hard_sigmoid_jit,
37-
hard_swish=hard_swish_jit,
41+
mish=F.mish if _has_mish else mish_jit,
42+
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_jit,
43+
hard_swish=F.hardswish if _has_hardswish else hard_swish_jit,
3844
hard_mish=hard_mish_jit
3945
)
4046

4147
_ACT_FN_ME = dict(
4248
silu=F.silu if _has_silu else swish_me,
4349
swish=F.silu if _has_silu else swish_me,
44-
mish=mish_me,
45-
hard_sigmoid=hard_sigmoid_me,
46-
hard_swish=hard_swish_me,
50+
mish=F.mish if _has_mish else mish_me,
51+
hard_sigmoid=F.hardsigmoid if _has_hardsigmoid else hard_sigmoid_me,
52+
hard_swish=F.hardswish if _has_hardswish else hard_swish_me,
4753
hard_mish=hard_mish_me,
4854
)
4955

56+
_ACT_FNS = (_ACT_FN_ME, _ACT_FN_JIT, _ACT_FN_DEFAULT)
57+
for a in _ACT_FNS:
58+
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
59+
a.setdefault('hardswish', a.get('hard_swish'))
60+
61+
5062
_ACT_LAYER_DEFAULT = dict(
5163
silu=nn.SiLU if _has_silu else Swish,
5264
swish=nn.SiLU if _has_silu else Swish,
53-
mish=Mish,
65+
mish=nn.Mish if _has_mish else Mish,
5466
relu=nn.ReLU,
5567
relu6=nn.ReLU6,
5668
leaky_relu=nn.LeakyReLU,
@@ -61,37 +73,44 @@
6173
gelu=GELU,
6274
sigmoid=Sigmoid,
6375
tanh=Tanh,
64-
hard_sigmoid=HardSigmoid,
65-
hard_swish=HardSwish,
76+
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoid,
77+
hard_swish=nn.Hardswish if _has_hardswish else HardSwish,
6678
hard_mish=HardMish,
6779
)
6880

6981
_ACT_LAYER_JIT = dict(
7082
silu=nn.SiLU if _has_silu else SwishJit,
7183
swish=nn.SiLU if _has_silu else SwishJit,
72-
mish=MishJit,
73-
hard_sigmoid=HardSigmoidJit,
74-
hard_swish=HardSwishJit,
84+
mish=nn.Mish if _has_mish else MishJit,
85+
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidJit,
86+
hard_swish=nn.Hardswish if _has_hardswish else HardSwishJit,
7587
hard_mish=HardMishJit
7688
)
7789

7890
_ACT_LAYER_ME = dict(
7991
silu=nn.SiLU if _has_silu else SwishMe,
8092
swish=nn.SiLU if _has_silu else SwishMe,
81-
mish=MishMe,
82-
hard_sigmoid=HardSigmoidMe,
83-
hard_swish=HardSwishMe,
93+
mish=nn.Mish if _has_mish else MishMe,
94+
hard_sigmoid=nn.Hardsigmoid if _has_hardsigmoid else HardSigmoidMe,
95+
hard_swish=nn.Hardswish if _has_hardswish else HardSwishMe,
8496
hard_mish=HardMishMe,
8597
)
8698

99+
_ACT_LAYERS = (_ACT_LAYER_ME, _ACT_LAYER_JIT, _ACT_LAYER_DEFAULT)
100+
for a in _ACT_LAYERS:
101+
a.setdefault('hardsigmoid', a.get('hard_sigmoid'))
102+
a.setdefault('hardswish', a.get('hard_swish'))
103+
87104

88-
def get_act_fn(name='relu'):
105+
def get_act_fn(name: Union[Callable, str] = 'relu'):
89106
""" Activation Function Factory
90107
Fetching activation fns by name with this function allows export or torch script friendly
91108
functions to be returned dynamically based on current config.
92109
"""
93110
if not name:
94111
return None
112+
if isinstance(name, Callable):
113+
return name
95114
if not (is_no_jit() or is_exportable() or is_scriptable()):
96115
# If not exporting or scripting the model, first look for a memory-efficient version with
97116
# custom autograd, then fallback
@@ -106,13 +125,15 @@ def get_act_fn(name='relu'):
106125
return _ACT_FN_DEFAULT[name]
107126

108127

109-
def get_act_layer(name='relu'):
128+
def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
110129
""" Activation Layer Factory
111130
Fetching activation layers by name with this function allows export or torch script friendly
112131
functions to be returned dynamically based on current config.
113132
"""
114133
if not name:
115134
return None
135+
if isinstance(name, type):
136+
return name
116137
if not (is_no_jit() or is_exportable() or is_scriptable()):
117138
if name in _ACT_LAYER_ME:
118139
return _ACT_LAYER_ME[name]
@@ -125,9 +146,8 @@ def get_act_layer(name='relu'):
125146
return _ACT_LAYER_DEFAULT[name]
126147

127148

128-
def create_act_layer(name, inplace=False, **kwargs):
149+
def create_act_layer(name: Union[nn.Module, str], inplace=None, **kwargs):
129150
act_layer = get_act_layer(name)
130-
if act_layer is not None:
131-
return act_layer(inplace=inplace, **kwargs)
132-
else:
151+
if act_layer is None:
133152
return None
153+
return act_layer(**kwargs) if inplace is None else act_layer(inplace=inplace, **kwargs)

timm/models/layers/se.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ class EffectiveSEModule(nn.Module):
4242
def __init__(self, channels, gate_layer='hard_sigmoid'):
4343
super(EffectiveSEModule, self).__init__()
4444
self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
45-
self.gate = create_act_layer(gate_layer, inplace=True)
45+
self.gate = create_act_layer(gate_layer)
4646

4747
def forward(self, x):
4848
x_se = x.mean((2, 3), keepdim=True)

timm/models/levit.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333

3434
from timm.data import IMAGENET_DEFAULT_STD, IMAGENET_DEFAULT_MEAN
3535
from .helpers import build_model_with_cfg, overlay_external_default_cfg
36-
from .layers import to_ntuple
36+
from .layers import to_ntuple, get_act_layer
3737
from .vision_transformer import trunc_normal_
3838
from .registry import register_model
3939

@@ -443,12 +443,14 @@ def __init__(
443443
mlp_ratio=2,
444444
hybrid_backbone=None,
445445
down_ops=None,
446-
act_layer=nn.Hardswish,
447-
attn_act_layer=nn.Hardswish,
446+
act_layer='hard_swish',
447+
attn_act_layer='hard_swish',
448448
distillation=True,
449449
use_conv=False,
450450
drop_path=0):
451451
super().__init__()
452+
act_layer = get_act_layer(act_layer)
453+
attn_act_layer = get_act_layer(attn_act_layer)
452454
if isinstance(img_size, tuple):
453455
# FIXME origin impl passes single img/res dim through whole hierarchy,
454456
# not sure this model will be used enough to spend time fixing it.

0 commit comments

Comments
 (0)