 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import List, Tuple, Type, Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union

 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint


 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -480,14 +481,14 @@ def __init__(
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
-        depth = sum(stages)
+
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
         num_tokens = math.prod(self.tokens_spatial_shape)
         flat_mu_size = math.prod(mask_unit_size)
         flat_q_stride = math.prod(q_stride)
-
         assert q_pool < len(stages)
         self.q_pool, self.q_stride = q_pool, q_stride
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
@@ -532,11 +533,10 @@ def __init__(
         # q_pool locations
         q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]

-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
-
         # Transformer blocks
         cur_stage = 0
+        depth = sum(stages)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
@@ -586,8 +586,9 @@ def __init__(
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.fc.weight.data.mul_(head_init_scale)
-        self.head.fc.bias.data.mul_(head_init_scale)
+        if isinstance(self.head.fc, nn.Linear):
+            self.head.fc.weight.data.mul_(head_init_scale)
+            self.head.fc.bias.data.mul_(head_init_scale)

     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
@@ -605,6 +606,25 @@ def no_weight_decay(self):
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]

+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool, other=other)
+
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
         Generates a random mask, mask_ratio fraction are dropped.
@@ -740,7 +760,10 @@ def forward_features(
         intermediates = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if return_intermediates and i in self.stage_ends:
                 intermediates.append(self.reroll(x, i, mask=mask))
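A rough usage sketch of the methods added above, assuming this is the `Hiera` model registered in `timm` (the variant name `hiera_small_224` is illustrative and not part of this diff):

```python
import timm
import torch

# Illustrative Hiera variant name; any registered Hiera config would behave the same way.
model = timm.create_model('hiera_small_224', pretrained=False, num_classes=10)

# Enable the activation-checkpointing path added in forward_features():
# block activations are recomputed during backward to save memory.
model.set_grad_checkpointing(True)

# Swap the classification head for a new task and inspect it.
model.reset_classifier(num_classes=100)
print(model.get_classifier())  # nn.Linear(..., 100)

# Forward/backward now routes each block through torch.utils.checkpoint.
x = torch.randn(2, 3, 224, 224)
loss = model(x).sum()
loss.backward()
```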