@@ -9,30 +9,29 @@
 """
 import math
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Type
+from typing import Any, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, create_attn, get_attn, \
-    get_act_layer, get_norm_layer, create_classifier
+from timm.layers import DropBlock2d, DropPath, AvgPool2dSame, BlurPool2d, GroupNorm, LayerType, create_attn, \
+    get_attn, get_act_layer, get_norm_layer, create_classifier
 from ._builder import build_model_with_cfg
 from ._manipulate import checkpoint_seq
 from ._registry import register_model, generate_default_cfgs, register_model_deprecations
-from ._typing import LayerType
 
 __all__ = ['ResNet', 'BasicBlock', 'Bottleneck']  # model_registry will add each entrypoint fn to this
 
 
-def get_padding(kernel_size: int, stride: int, dilation: int = 1):
+def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
     padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
     return padding
 
 
-def create_aa(aa_layer, channels, stride=2, enable=True):
+def create_aa(aa_layer: Type[nn.Module], channels: int, stride: int = 2, enable: bool = True) -> nn.Module:
     if not aa_layer or not enable:
         return nn.Identity()
     if issubclass(aa_layer, nn.AvgPool2d):
@@ -55,11 +54,11 @@ def __init__(
             reduce_first: int = 1,
             dilation: int = 1,
             first_dilation: Optional[int] = None,
-            act_layer: nn.Module = nn.ReLU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
-            attn_layer: Optional[nn.Module] = None,
-            aa_layer: Optional[nn.Module] = None,
-            drop_block: Type[nn.Module] = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            attn_layer: Optional[Type[nn.Module]] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            drop_block: Optional[Type[nn.Module]] = None,
             drop_path: Optional[nn.Module] = None,
     ):
         """
@@ -153,11 +152,11 @@ def __init__(
             reduce_first: int = 1,
             dilation: int = 1,
             first_dilation: Optional[int] = None,
-            act_layer: nn.Module = nn.ReLU,
-            norm_layer: nn.Module = nn.BatchNorm2d,
-            attn_layer: Optional[nn.Module] = None,
-            aa_layer: Optional[nn.Module] = None,
-            drop_block: Type[nn.Module] = None,
+            act_layer: Type[nn.Module] = nn.ReLU,
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            attn_layer: Optional[Type[nn.Module]] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
+            drop_block: Optional[Type[nn.Module]] = None,
             drop_path: Optional[nn.Module] = None,
     ):
         """
@@ -296,7 +295,7 @@ def drop_blocks(drop_prob: float = 0.):
 
 
 def make_blocks(
-        block_fn: nn.Module,
+        block_fn: Union[BasicBlock, Bottleneck],
         channels: List[int],
         block_repeats: List[int],
         inplanes: int,
@@ -395,7 +394,7 @@ class ResNet(nn.Module):
 
     def __init__(
             self,
-            block: nn.Module,
+            block: Union[BasicBlock, Bottleneck],
             layers: List[int],
             num_classes: int = 1000,
             in_chans: int = 3,
@@ -411,7 +410,7 @@ def __init__(
             avg_down: bool = False,
             act_layer: LayerType = nn.ReLU,
             norm_layer: LayerType = nn.BatchNorm2d,
-            aa_layer: Optional[nn.Module] = None,
+            aa_layer: Optional[Type[nn.Module]] = None,
             drop_rate: float = 0.0,
             drop_path_rate: float = 0.,
             drop_block_rate: float = 0.,
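
For reference, the formula in get_padding() above is the usual "same"-style padding for strided/dilated convolutions. A quick standalone sanity check (a hypothetical snippet, not part of the commit):

def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
    # Same formula as in the diff: with stride 1 this preserves spatial size
    # for odd effective kernels, since pad = dilation * (kernel_size - 1) // 2.
    return ((stride - 1) + dilation * (kernel_size - 1)) // 2

assert get_padding(3, 1) == 1     # plain 3x3 conv -> pad 1
assert get_padding(3, 1, 2) == 2  # dilated 3x3 spans 5 pixels -> pad 2
assert get_padding(7, 2) == 3     # 7x7 stride-2 ResNet stem conv -> pad 3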
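A minimal sketch of the distinction driving these annotation changes: the layer arguments carry a class, not an instance, so Type[nn.Module] is the accurate hint. The helper name below is hypothetical; it only mirrors how the blocks instantiate norm_layer and how create_aa calls issubclass() on aa_layer.

from typing import Type

import torch.nn as nn

def make_norm(norm_layer: Type[nn.Module] = nn.BatchNorm2d, channels: int = 64) -> nn.Module:
    # The argument is a layer class; the callee constructs the instance,
    # which is why plain `nn.Module` (an instance annotation) was inaccurate.
    return norm_layer(channels)

bn = make_norm()                       # BatchNorm2d(64)
gn = make_norm(nn.InstanceNorm2d, 32)  # swap the class, same call site
assert issubclass(nn.AvgPool2d, nn.Module)  # class-level check, like create_aa's issubclass()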