Skip to content

Commit 5f14bdd

Browse files
a-r-r-o-wrwightman
authored andcommitted
include typing suggestions by @rwightman
1 parent 05b0aac commit 5f14bdd

File tree

5 files changed

+25
-24
lines changed

5 files changed

+25
-24
lines changed

timm/layers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,4 +52,5 @@
5252
from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
5353
from .test_time_pool import TestTimePoolHead, apply_test_time_pool
5454
from .trace_utils import _assert, _float_to_int
55+
from .typing import LayerType, PadType
5556
from .weight_init import trunc_normal_, trunc_normal_tf_, variance_scaling_, lecun_normal_
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import functools
22
import types
3-
from typing import Any, Dict, List, Tuple, Union
3+
from typing import Tuple, Union
44

55
import torch.nn
66

77

8-
BlockArgs = List[List[Dict[str, Any]]]
98
LayerType = Union[type, str, types.FunctionType, functools.partial, torch.nn.Module]
109
PadType = Union[str, int, Tuple[int, int]]

timm/models/_efficientnet_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import re
1212
from copy import deepcopy
1313
from functools import partial
14+
from typing import Any, Dict, List
1415

1516
import torch.nn as nn
1617

@@ -34,6 +35,8 @@
3435
BN_EPS_TF_DEFAULT = 1e-3
3536
_BN_ARGS_TF = dict(momentum=BN_MOMENTUM_TF_DEFAULT, eps=BN_EPS_TF_DEFAULT)
3637

38+
BlockArgs = List[List[Dict[str, Any]]]
39+
3740

3841
def get_bn_args_tf():
3942
return _BN_ARGS_TF.copy()

timm/models/mobilenetv3.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,14 @@
1616
from torch.utils.checkpoint import checkpoint
1717

1818
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
19-
from timm.layers import SelectAdaptivePool2d, Linear, create_conv2d, get_norm_act_layer
19+
from timm.layers import SelectAdaptivePool2d, Linear, LayerType, PadType, create_conv2d, get_norm_act_layer
2020
from ._builder import build_model_with_cfg, pretrained_cfg_for_features
2121
from ._efficientnet_blocks import SqueezeExcite
22-
from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
22+
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
2323
round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
2424
from ._features import FeatureInfo, FeatureHooks
2525
from ._manipulate import checkpoint_seq
2626
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
27-
from ._typing import BlockArgs, LayerType, PadType
2827

2928
__all__ = ['MobileNetV3', 'MobileNetV3Features']
3029

timm/models/resnet.py

Lines changed: 18 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,30 +9,29 @@
99
"""
1010
import math
1111
from functools import partial
12-
from typing import Any, Dict, List, Optional, Tuple, Type
12+
from typing import Any, Dict, List, Optional, Tuple, Type, Union
1313

1414
import torch
1515
import torch.nn as nn
1616
import torch.nn.functional as F
1717
from torch import Tensor
1818

1919
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20-
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
21-
get_act_layer, get_norm_layer, create_classifier
20+
from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
21+
get_attn, get_act_layer, get_norm_layer, create_classifier
2222
from ._builder import build_model_with_cfg
2323
from ._manipulate import checkpoint_seq
2424
from ._registry import register_model, generate_default_cfgs, register_model_deprecations
25-
from ._typing import LayerType
2625

2726
__all__ = ['ResNet', 'BasicBlock', 'Bottleneck'] # model_registry will add each entrypoint fn to this
2827

2928

30-
def get_padding(kernel_size: int, stride: int, dilation: int = 1):
29+
def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
3130
padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
3231
return padding
3332

3433

35-
def create_aa(aa_layer, channels, stride=2, enable=True):
34+
def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
3635
if not aa_layer or not enable:
3736
return nn.Identity()
3837
if issubclass(aa_layer, nn.AvgPool2d):
@@ -55,11 +54,11 @@ def __init__(
5554
reduce_first: int = 1,
5655
dilation: int = 1,
5756
first_dilation: Optional[int] = None,
58-
act_layer: nn.Module = nn.ReLU,
59-
norm_layer: nn.Module = nn.BatchNorm2d,
60-
attn_layer: Optional[nn.Module] = None,
61-
aa_layer: Optional[nn.Module] = None,
62-
drop_block: Type[nn.Module] = None,
57+
act_layer: Type[nn.Module] = nn.ReLU,
58+
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
59+
attn_layer: Optional[Type[nn.Module]] = None,
60+
aa_layer: Optional[Type[nn.Module]] = None,
61+
drop_block: Optional[Type[nn.Module]] = None,
6362
drop_path: Optional[nn.Module] = None,
6463
):
6564
"""
@@ -153,11 +152,11 @@ def __init__(
153152
reduce_first: int = 1,
154153
dilation: int = 1,
155154
first_dilation: Optional[int] = None,
156-
act_layer: nn.Module = nn.ReLU,
157-
norm_layer: nn.Module = nn.BatchNorm2d,
158-
attn_layer: Optional[nn.Module] = None,
159-
aa_layer: Optional[nn.Module] = None,
160-
drop_block: Type[nn.Module] = None,
155+
act_layer: Type[nn.Module] = nn.ReLU,
156+
norm_layer: Type[nn.Module] = nn.BatchNorm2d,
157+
attn_layer: Optional[Type[nn.Module]] = None,
158+
aa_layer: Optional[Type[nn.Module]] = None,
159+
drop_block: Optional[Type[nn.Module]] = None,
161160
drop_path: Optional[nn.Module] = None,
162161
):
163162
"""
@@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.):
296295

297296

298297
def make_blocks(
299-
block_fn: nn.Module,
298+
block_fn: Union[BasicBlock, Bottleneck],
300299
channels: List[int],
301300
block_repeats: List[int],
302301
inplanes: int,
@@ -395,7 +394,7 @@ class ResNet(nn.Module):
395394

396395
def __init__(
397396
self,
398-
block: nn.Module,
397+
block: Union[BasicBlock, Bottleneck],
399398
layers: List[int],
400399
num_classes: int = 1000,
401400
in_chans: int = 3,
@@ -411,7 +410,7 @@ def __init__(
411410
avg_down: bool = False,
412411
act_layer: LayerType = nn.ReLU,
413412
norm_layer: LayerType = nn.BatchNorm2d,
414-
aa_layer: Optional[nn.Module] = None,
413+
aa_layer: Optional[Type[nn.Module]] = None,
415414
drop_rate: float = 0.0,
416415
drop_path_rate: float = 0.,
417416
drop_block_rate: float = 0.,

0 commit comments

Comments
 (0)