8
8
Written by Bin Xiao (Bin.Xiao@microsoft.com)
9
9
Modified by Ke Sun (sunk@mail.ustc.edu.cn)
10
10
"""
11
-
12
- from __future__ import absolute_import
13
- from __future__ import division
14
- from __future__ import print_function
15
-
16
11
import logging
12
+ from typing import List
17
13
14
+ import torch
18
15
import torch .nn as nn
19
16
import torch .nn .functional as F
20
17
21
18
from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
19
+ from .features import FeatureInfo
22
20
from .helpers import build_model_with_cfg
23
21
from .layers import SelectAdaptivePool2d
24
22
from .registry import register_model
@@ -403,32 +401,23 @@ def __init__(self, num_branches, blocks, num_blocks, num_inchannels,
403
401
self .branches = self ._make_branches (
404
402
num_branches , blocks , num_blocks , num_channels )
405
403
self .fuse_layers = self ._make_fuse_layers ()
406
- self .relu = nn .ReLU (False )
404
+ self .fuse_act = nn .ReLU (False )
407
405
408
406
def _check_branches (self , num_branches , blocks , num_blocks , num_inchannels , num_channels ):
407
+ error_msg = ''
409
408
if num_branches != len (num_blocks ):
410
- error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})' .format (
411
- num_branches , len (num_blocks ))
412
- logger .error (error_msg )
413
- raise ValueError (error_msg )
414
-
415
- if num_branches != len (num_channels ):
416
- error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})' .format (
417
- num_branches , len (num_channels ))
409
+ error_msg = 'NUM_BRANCHES({}) <> NUM_BLOCKS({})' .format (num_branches , len (num_blocks ))
410
+ elif num_branches != len (num_channels ):
411
+ error_msg = 'NUM_BRANCHES({}) <> NUM_CHANNELS({})' .format (num_branches , len (num_channels ))
412
+ elif num_branches != len (num_inchannels ):
413
+ error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})' .format (num_branches , len (num_inchannels ))
414
+ if error_msg :
418
415
logger .error (error_msg )
419
416
raise ValueError (error_msg )
420
417
421
- if num_branches != len (num_inchannels ):
422
- error_msg = 'NUM_BRANCHES({}) <> NUM_INCHANNELS({})' .format (
423
- num_branches , len (num_inchannels ))
424
- logger .error (error_msg )
425
- raise ValueError (error_msg )
426
-
427
- def _make_one_branch (self , branch_index , block , num_blocks , num_channels ,
428
- stride = 1 ):
418
+ def _make_one_branch (self , branch_index , block , num_blocks , num_channels , stride = 1 ):
429
419
downsample = None
430
- if stride != 1 or \
431
- self .num_inchannels [branch_index ] != num_channels [branch_index ] * block .expansion :
420
+ if stride != 1 or self .num_inchannels [branch_index ] != num_channels [branch_index ] * block .expansion :
432
421
downsample = nn .Sequential (
433
422
nn .Conv2d (
434
423
self .num_inchannels [branch_index ], num_channels [branch_index ] * block .expansion ,
@@ -489,22 +478,22 @@ def _make_fuse_layers(self):
489
478
def get_num_inchannels (self ):
490
479
return self .num_inchannels
491
480
492
- def forward (self , x ):
481
+ def forward (self , x : List [ torch . Tensor ] ):
493
482
if self .num_branches == 1 :
494
483
return [self .branches [0 ](x [0 ])]
495
484
496
- for i in range (self .num_branches ):
497
- x [i ] = self . branches [ i ] (x [i ])
485
+ for i , branch in enumerate (self .branches ):
486
+ x [i ] = branch (x [i ])
498
487
499
488
x_fuse = []
500
- for i in range ( len ( self .fuse_layers ) ):
501
- y = x [0 ] if i == 0 else self . fuse_layers [ i ] [0 ](x [0 ])
489
+ for i , fuse_outer in enumerate ( self .fuse_layers ):
490
+ y = x [0 ] if i == 0 else fuse_outer [0 ](x [0 ])
502
491
for j in range (1 , self .num_branches ):
503
492
if i == j :
504
493
y = y + x [j ]
505
494
else :
506
- y = y + self . fuse_layers [ i ] [j ](x [j ])
507
- x_fuse .append (self .relu (y ))
495
+ y = y + fuse_outer [j ](x [j ])
496
+ x_fuse .append (self .fuse_act (y ))
508
497
509
498
return x_fuse
510
499
@@ -517,17 +506,18 @@ def forward(self, x):
517
506
518
507
class HighResolutionNet (nn .Module ):
519
508
520
- def __init__ (self , cfg , in_chans = 3 , num_classes = 1000 , global_pool = 'avg' , drop_rate = 0.0 ):
509
+ def __init__ (self , cfg , in_chans = 3 , num_classes = 1000 , global_pool = 'avg' , drop_rate = 0.0 , head = 'classification' ):
521
510
super (HighResolutionNet , self ).__init__ ()
522
511
self .num_classes = num_classes
523
512
self .drop_rate = drop_rate
524
513
525
514
stem_width = cfg ['STEM_WIDTH' ]
526
515
self .conv1 = nn .Conv2d (in_chans , stem_width , kernel_size = 3 , stride = 2 , padding = 1 , bias = False )
527
516
self .bn1 = nn .BatchNorm2d (stem_width , momentum = _BN_MOMENTUM )
517
+ self .act1 = nn .ReLU (inplace = True )
528
518
self .conv2 = nn .Conv2d (stem_width , 64 , kernel_size = 3 , stride = 2 , padding = 1 , bias = False )
529
519
self .bn2 = nn .BatchNorm2d (64 , momentum = _BN_MOMENTUM )
530
- self .relu = nn .ReLU (inplace = True )
520
+ self .act2 = nn .ReLU (inplace = True )
531
521
532
522
self .stage1_cfg = cfg ['STAGE1' ]
533
523
num_channels = self .stage1_cfg ['NUM_CHANNELS' ][0 ]
@@ -557,31 +547,49 @@ def __init__(self, cfg, in_chans=3, num_classes=1000, global_pool='avg', drop_ra
557
547
self .transition3 = self ._make_transition_layer (pre_stage_channels , num_channels )
558
548
self .stage4 , pre_stage_channels = self ._make_stage (self .stage4_cfg , num_channels , multi_scale_output = True )
559
549
560
- # Classification Head
561
- self .num_features = 2048
562
- self .incre_modules , self .downsamp_modules , self .final_layer = self ._make_head (pre_stage_channels )
563
- self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
564
- self .classifier = nn .Linear (self .num_features * self .global_pool .feat_mult (), num_classes )
550
+ self .head = head
551
+ self .head_channels = None # set if _make_head called
552
+ if head == 'classification' :
553
+ # Classification Head
554
+ self .num_features = 2048
555
+ self .incre_modules , self .downsamp_modules , self .final_layer = self ._make_head (pre_stage_channels )
556
+ self .global_pool = SelectAdaptivePool2d (pool_type = global_pool )
557
+ self .classifier = nn .Linear (self .num_features * self .global_pool .feat_mult (), num_classes )
558
+ elif head == 'incre' :
559
+ self .num_features = 2048
560
+ self .incre_modules , _ , _ = self ._make_head (pre_stage_channels , True )
561
+ else :
562
+ self .incre_modules = None
563
+ self .num_features = 256
564
+
565
+ curr_stride = 2
566
+ # module names aren't actually valid here, hook or FeatureNet based extraction would not work
567
+ self .feature_info = [dict (num_chs = 64 , reduction = curr_stride , module = 'stem' )]
568
+ for i , c in enumerate (self .head_channels if self .head_channels else num_channels ):
569
+ curr_stride *= 2
570
+ c = c * 4 if self .head_channels else c # head block expansion factor of 4
571
+ self .feature_info += [dict (num_chs = c , reduction = curr_stride , module = f'stage{ i + 1 } ' )]
565
572
566
573
self .init_weights ()
567
574
568
- def _make_head (self , pre_stage_channels ):
575
+ def _make_head (self , pre_stage_channels , incre_only = False ):
569
576
head_block = Bottleneck
570
- head_channels = [32 , 64 , 128 , 256 ]
577
+ self . head_channels = [32 , 64 , 128 , 256 ]
571
578
572
579
# Increasing the #channels on each resolution
573
580
# from C, 2C, 4C, 8C to 128, 256, 512, 1024
574
581
incre_modules = []
575
582
for i , channels in enumerate (pre_stage_channels ):
576
- incre_modules .append (
577
- self ._make_layer (head_block , channels , head_channels [i ], 1 , stride = 1 ))
583
+ incre_modules .append (self ._make_layer (head_block , channels , self .head_channels [i ], 1 , stride = 1 ))
578
584
incre_modules = nn .ModuleList (incre_modules )
585
+ if incre_only :
586
+ return incre_modules , None , None
579
587
580
588
# downsampling modules
581
589
downsamp_modules = []
582
590
for i in range (len (pre_stage_channels ) - 1 ):
583
- in_channels = head_channels [i ] * head_block .expansion
584
- out_channels = head_channels [i + 1 ] * head_block .expansion
591
+ in_channels = self . head_channels [i ] * head_block .expansion
592
+ out_channels = self . head_channels [i + 1 ] * head_block .expansion
585
593
downsamp_module = nn .Sequential (
586
594
nn .Conv2d (
587
595
in_channels = in_channels , out_channels = out_channels , kernel_size = 3 , stride = 2 , padding = 1 ),
@@ -593,7 +601,7 @@ def _make_head(self, pre_stage_channels):
593
601
594
602
final_layer = nn .Sequential (
595
603
nn .Conv2d (
596
- in_channels = head_channels [3 ] * head_block .expansion ,
604
+ in_channels = self . head_channels [3 ] * head_block .expansion ,
597
605
out_channels = self .num_features , kernel_size = 1 , stride = 1 , padding = 0
598
606
),
599
607
nn .BatchNorm2d (self .num_features , momentum = _BN_MOMENTUM ),
@@ -655,11 +663,7 @@ def _make_stage(self, layer_config, num_inchannels, multi_scale_output=True):
655
663
modules = []
656
664
for i in range (num_modules ):
657
665
# multi_scale_output is only used last module
658
- if not multi_scale_output and i == num_modules - 1 :
659
- reset_multi_scale_output = False
660
- else :
661
- reset_multi_scale_output = True
662
-
666
+ reset_multi_scale_output = multi_scale_output or i < num_modules - 1
663
667
modules .append (HighResolutionModule (
664
668
num_branches , block , num_blocks , num_inchannels , num_channels , fuse_method , reset_multi_scale_output )
665
669
)
@@ -688,40 +692,35 @@ def reset_classifier(self, num_classes, global_pool='avg'):
688
692
else :
689
693
self .classifier = nn .Identity ()
690
694
695
+ def stages (self , x ) -> List [torch .Tensor ]:
696
+ x = self .layer1 (x )
697
+
698
+ xl = [t (x ) for i , t in enumerate (self .transition1 )]
699
+ yl = self .stage2 (xl )
700
+
701
+ xl = [t (yl [- 1 ]) if not isinstance (t , nn .Identity ) else yl [i ] for i , t in enumerate (self .transition2 )]
702
+ yl = self .stage3 (xl )
703
+
704
+ xl = [t (yl [- 1 ]) if not isinstance (t , nn .Identity ) else yl [i ] for i , t in enumerate (self .transition3 )]
705
+ yl = self .stage4 (xl )
706
+ return yl
707
+
691
708
def forward_features (self , x ):
709
+ # Stem
692
710
x = self .conv1 (x )
693
711
x = self .bn1 (x )
694
- x = self .relu (x )
712
+ x = self .act1 (x )
695
713
x = self .conv2 (x )
696
714
x = self .bn2 (x )
697
- x = self .relu (x )
698
- x = self .layer1 (x )
699
-
700
- x_list = []
701
- for i in range (len (self .transition1 )):
702
- x_list .append (self .transition1 [i ](x ))
703
- y_list = self .stage2 (x_list )
704
-
705
- x_list = []
706
- for i in range (len (self .transition2 )):
707
- if not isinstance (self .transition2 [i ], nn .Identity ):
708
- x_list .append (self .transition2 [i ](y_list [- 1 ]))
709
- else :
710
- x_list .append (y_list [i ])
711
- y_list = self .stage3 (x_list )
715
+ x = self .act2 (x )
712
716
713
- x_list = []
714
- for i in range (len (self .transition3 )):
715
- if not isinstance (self .transition3 [i ], nn .Identity ):
716
- x_list .append (self .transition3 [i ](y_list [- 1 ]))
717
- else :
718
- x_list .append (y_list [i ])
719
- y_list = self .stage4 (x_list )
717
+ # Stages
718
+ yl = self .stages (x )
720
719
721
720
# Classification Head
722
- y = self .incre_modules [0 ](y_list [0 ])
723
- for i in range ( len ( self .downsamp_modules ) ):
724
- y = self .incre_modules [i + 1 ](y_list [i + 1 ]) + self . downsamp_modules [ i ] (y )
721
+ y = self .incre_modules [0 ](yl [0 ])
722
+ for i , down in enumerate ( self .downsamp_modules ):
723
+ y = self .incre_modules [i + 1 ](yl [i + 1 ]) + down (y )
725
724
y = self .final_layer (y )
726
725
return y
727
726
@@ -734,10 +733,55 @@ def forward(self, x):
734
733
return x
735
734
736
735
736
+ class HighResolutionNetFeatures (HighResolutionNet ):
737
+ """HighResolutionNet feature extraction
738
+
739
+ The design of HRNet makes it easy to grab feature maps, this class provides a simple wrapper to do so.
740
+ It would be more complicated to use the FeatureNet helpers.
741
+
742
+ The `feature_location=incre` allows grabbing increased channel count features using part of the
743
+ classification head. If `feature_location=''` the default HRNet features are returned. First stem
744
+ conv is used for stride 2 features.
745
+ """
746
+
747
+ def __init__ (self , cfg , in_chans = 3 , num_classes = 1000 , global_pool = 'avg' , drop_rate = 0.0 ,
748
+ feature_location = 'incre' , out_indices = (0 , 1 , 2 , 3 , 4 )):
749
+ assert feature_location in ('incre' , '' )
750
+ super (HighResolutionNetFeatures , self ).__init__ (
751
+ cfg , in_chans = in_chans , num_classes = num_classes , global_pool = global_pool ,
752
+ drop_rate = drop_rate , head = feature_location )
753
+ self .feature_info = FeatureInfo (self .feature_info , out_indices )
754
+ self ._out_idx = {i for i in out_indices }
755
+
756
+ def forward_features (self , x ):
757
+ assert False , 'Not supported'
758
+
759
+ def forward (self , x ) -> List [torch .tensor ]:
760
+ out = []
761
+ x = self .conv1 (x )
762
+ x = self .bn1 (x )
763
+ x = self .act1 (x )
764
+ if 0 in self ._out_idx :
765
+ out .append (x )
766
+ x = self .conv2 (x )
767
+ x = self .bn2 (x )
768
+ x = self .act2 (x )
769
+ x = self .stages (x )
770
+ if self .incre_modules is not None :
771
+ x = [incre (f ) for f , incre in zip (x , self .incre_modules )]
772
+ for i , f in enumerate (x ):
773
+ if i + 1 in self ._out_idx :
774
+ out .append (f )
775
+ return out
776
+
777
+
737
778
def _create_hrnet (variant , pretrained , ** model_kwargs ):
738
- assert not model_kwargs .pop ('features_only' , False ) # feature extraction not figured out yet
779
+ model_cls = HighResolutionNet
780
+ if model_kwargs .pop ('features_only' , False ):
781
+ model_cls = HighResolutionNetFeatures
782
+
739
783
return build_model_with_cfg (
740
- HighResolutionNet , variant , pretrained , default_cfg = default_cfgs [variant ],
784
+ model_cls , variant , pretrained , default_cfg = default_cfgs [variant ],
741
785
model_cfg = cfg_cls [variant ], ** model_kwargs )
742
786
743
787
0 commit comments