32
32
33
33
34
34
from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
35
- from timm .layers import DropPath , Mlp , use_fused_attn , _assert
35
+ from timm .layers import DropPath , Mlp , use_fused_attn , _assert , get_norm_layer
36
36
37
37
38
38
from ._registry import generate_default_cfgs , register_model
@@ -372,20 +372,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
372
372
return x
373
373
374
374
375
- class Head (nn .Module ):
375
+ class NormClassifierHead (nn .Module ):
376
376
def __init__ (
377
377
self ,
378
- dim : int ,
378
+ in_features : int ,
379
379
num_classes : int ,
380
+ pool_type : str = 'avg' ,
380
381
drop_rate : float = 0.0 ,
382
+ norm_layer : Union [str , Callable ] = 'layernorm' ,
381
383
):
382
384
super ().__init__ ()
383
- self .dropout = nn .Dropout (drop_rate ) if drop_rate > 0 else nn .Identity ()
384
- self .projection = nn .Linear (dim , num_classes )
385
+ norm_layer = get_norm_layer (norm_layer )
386
+ assert pool_type in ('avg' , '' )
387
+ self .in_features = self .num_features = in_features
388
+ self .pool_type = pool_type
389
+ self .norm = norm_layer (in_features )
390
+ self .drop = nn .Dropout (drop_rate ) if drop_rate else nn .Identity ()
391
+ self .fc = nn .Linear (in_features , num_classes ) if num_classes > 0 else nn .Identity ()
392
+
393
+ def reset (self , num_classes : int , pool_type : Optional [str ] = None , other : bool = False ):
394
+ if pool_type is not None :
395
+ assert pool_type in ('avg' , '' )
396
+ self .pool_type = pool_type
397
+ if other :
398
+ # reset other non-fc layers
399
+ self .norm = nn .Identity ()
400
+ self .fc = nn .Linear (self .in_features , num_classes ) if num_classes > 0 else nn .Identity ()
385
401
386
- def forward (self , x : torch .Tensor ) -> torch .Tensor :
387
- x = self .dropout (x )
388
- x = self .projection (x )
402
+ def forward (self , x : torch .Tensor , pre_logits : bool = False ) -> torch .Tensor :
403
+ if self .pool_type == 'avg' :
404
+ x = x .mean (dim = 1 )
405
+ x = self .norm (x )
406
+ x = self .drop (x )
407
+ if pre_logits :
408
+ return x
409
+ x = self .fc (x )
389
410
return x
390
411
391
412
@@ -438,6 +459,7 @@ def __init__(
438
459
embed_dim : int = 96 , # initial embed dim
439
460
num_heads : int = 1 , # initial number of heads
440
461
num_classes : int = 1000 ,
462
+ global_pool : str = 'avg' ,
441
463
stages : Tuple [int , ...] = (2 , 3 , 16 , 3 ),
442
464
q_pool : int = 3 , # number of q_pool stages
443
465
q_stride : Tuple [int , ...] = (2 , 2 ),
@@ -458,11 +480,7 @@ def __init__(
458
480
):
459
481
super ().__init__ ()
460
482
self .num_classes = num_classes
461
-
462
- # Do it this way to ensure that the init args are all PoD (for config usage)
463
- if isinstance (norm_layer , str ):
464
- norm_layer = partial (getattr (nn , norm_layer ), eps = 1e-6 )
465
-
483
+ norm_layer = get_norm_layer (norm_layer )
466
484
depth = sum (stages )
467
485
self .patch_stride = patch_stride
468
486
self .tokens_spatial_shape = [i // s for i , s in zip (img_size , patch_stride )]
@@ -552,8 +570,14 @@ def __init__(
552
570
dict (num_chs = dim_out , reduction = 2 ** (cur_stage + 2 ), module = f'blocks.{ self .stage_ends [cur_stage ]} ' )]
553
571
self .blocks .append (block )
554
572
555
- self .norm = norm_layer (embed_dim )
556
- self .head = Head (embed_dim , num_classes , drop_rate = drop_rate )
573
+ self .num_features = embed_dim
574
+ self .head = NormClassifierHead (
575
+ embed_dim ,
576
+ num_classes ,
577
+ pool_type = global_pool ,
578
+ drop_rate = drop_rate ,
579
+ norm_layer = norm_layer ,
580
+ )
557
581
558
582
# Initialize everything
559
583
if sep_pos_embed :
@@ -562,8 +586,8 @@ def __init__(
562
586
else :
563
587
nn .init .trunc_normal_ (self .pos_embed , std = 0.02 )
564
588
self .apply (partial (self ._init_weights ))
565
- self .head .projection .weight .data .mul_ (head_init_scale )
566
- self .head .projection .bias .data .mul_ (head_init_scale )
589
+ self .head .fc .weight .data .mul_ (head_init_scale )
590
+ self .head .fc .bias .data .mul_ (head_init_scale )
567
591
568
592
def _init_weights (self , m , init_bias = 0.02 ):
569
593
if isinstance (m , (nn .Linear , nn .Conv1d , nn .Conv2d , nn .Conv3d )):
@@ -678,19 +702,17 @@ def forward_intermediates(
678
702
679
703
def prune_intermediate_layers (
680
704
self ,
681
- n : Union [int , List [int ], Tuple [int ]] = 1 ,
705
+ indices : Union [int , List [int ], Tuple [int ]] = 1 ,
682
706
prune_norm : bool = False ,
683
707
prune_head : bool = True ,
684
708
):
685
709
""" Prune layers not required for specified intermediates.
686
710
"""
687
- take_indices , max_index = feature_take_indices (len (self .stage_ends ), n )
711
+ take_indices , max_index = feature_take_indices (len (self .stage_ends ), indices )
688
712
max_index = self .stage_ends [max_index ]
689
713
self .blocks = self .blocks [:max_index + 1 ] # truncate blocks
690
714
if prune_head :
691
- # norm part of head for this model, equivalent to fc_norm in other vit.
692
- self .norm = nn .Identity ()
693
- self .head = nn .Identity ()
715
+ self .head .reset (0 , other = True )
694
716
return take_indices
695
717
696
718
@@ -732,11 +754,7 @@ def forward_features(
732
754
return x
733
755
734
756
def forward_head (self , x , pre_logits : bool = False ) -> torch .Tensor :
735
- x = x .mean (dim = 1 )
736
- x = self .norm (x )
737
- if pre_logits :
738
- return x
739
- x = self .head (x )
757
+ x = self .head (x , pre_logits = pre_logits ) if pre_logits else self .head (x )
740
758
return x
741
759
742
760
def forward (
@@ -756,7 +774,7 @@ def _cfg(url='', **kwargs):
756
774
'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : None ,
757
775
'crop_pct' : .9 , 'interpolation' : 'bicubic' , 'fixed_input_size' : True ,
758
776
'mean' : IMAGENET_DEFAULT_MEAN , 'std' : IMAGENET_DEFAULT_STD ,
759
- 'first_conv' : 'patch_embed.proj' , 'classifier' : 'head' ,
777
+ 'first_conv' : 'patch_embed.proj' , 'classifier' : 'head.fc ' ,
760
778
** kwargs
761
779
}
762
780
@@ -837,6 +855,10 @@ def checkpoint_filter_fn(state_dict, model=None):
837
855
# )
838
856
#v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
839
857
pass
858
+ if 'head.projection.' in k :
859
+ k = k .replace ('head.projection.' , 'head.fc.' )
860
+ if k .startswith ('norm.' ):
861
+ k = k .replace ('norm.' , 'head.norm.' )
840
862
output [k ] = v
841
863
return output
842
864
0 commit comments