from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
    create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
from ._manipulate import named_apply, checkpoint_seq
from ._registry import generate_default_cfgs, register_model

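For orientation, the newly imported feature_take_indices helper resolves an indices argument into concrete feature positions, which the new methods below then map through stage_ends. A minimal sketch of the semantics assumed here (simplified; not the timm implementation, and the name _take_indices_sketch is purely illustrative):

from typing import List, Optional, Tuple, Union

def _take_indices_sketch(num_features: int, indices: Optional[Union[int, List[int]]]) -> Tuple[List[int], int]:
    # None -> all feature points, int n -> last n, sequence -> explicit indices (negatives allowed)
    if indices is None:
        take = list(range(num_features))
    elif isinstance(indices, int):
        take = list(range(num_features - indices, num_features))
    else:
        take = [num_features + i if i < 0 else i for i in indices]
    return take, max(take)

# e.g. with 5 feature points (stem + 4 stages): indices=2 -> ([3, 4], 4); indices=(0, -1) -> ([0, 4], 4)
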
@@ -948,25 +949,37 @@ def __init__(
        stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
        prev_chs = in_chs
        curr_stride = 1
+        last_feat_idx = -1
        for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
            layer_fn = layers.conv_norm_act if na else create_conv2d
            conv_name = f'conv{i + 1}'
            if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
            self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
            prev_chs = ch
            curr_stride *= s
            prev_feat = conv_name

        if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
            self.add_module('pool', nn.MaxPool2d(3, 2, 1))
            curr_stride *= 2
            prev_feat = 'pool'

-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
        assert curr_stride == stride

+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate
+

def create_byob_stem(
        in_chs: int,
@@ -1008,7 +1021,7 @@ def create_byob_stem(
    if isinstance(stem, Stem):
        feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
    else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
    return stem, feature_info

@@ -1122,7 +1135,7 @@ def create_byob_stages(
                feat_size = reduce_feat_size(feat_size, stride)

        stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)

    feature_info.append(prev_feat)
    return nn.Sequential(*stages), feature_info
@@ -1198,6 +1211,7 @@ def __init__(
            feat_size=feat_size,
        )
        self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']

        prev_chs = stage_feat[-1]['num_chs']
        if cfg.num_features:
@@ -1207,7 +1221,8 @@ def __init__(
            self.num_features = prev_chs
            self.final_conv = nn.Identity()
        self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]

        self.head = ClassifierHead(
            self.num_features,
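With the stage= keys added above, each feature_info entry now records which stage produced it, and self.stage_ends collects those stage numbers in order (stem features are stage 0, stages are 1..N, and the final_conv entry reuses len(self.stages)). For a model with a single stem feature and four stages, the resulting layout would look roughly like this (channel and reduction values are illustrative only):

feature_info = [
    dict(num_chs=64,   reduction=2,  module='stem.conv1', stage=0),
    dict(num_chs=256,  reduction=4,  module='stages.0',   stage=1),
    dict(num_chs=512,  reduction=8,  module='stages.1',   stage=2),
    dict(num_chs=1024, reduction=16, module='stages.2',   stage=3),
    dict(num_chs=2048, reduction=32, module='final_conv', stage=4),
]
stage_ends = [f['stage'] for f in feature_info]  # -> [0, 1, 2, 3, 4]
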
@@ -1241,6 +1256,83 @@ def get_classifier(self):
    def reset_classifier(self, num_classes, global_pool='avg'):
        self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            exclude_final_conv: Exclude final_conv from last intermediate
+        Returns:
+
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
+
    def forward_features(self, x):
        x = self.stem(x)
        if self.grad_checkpointing and not torch.jit.is_scripting():
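Taken together, the new methods give byobnet models the same intermediate-feature API as other timm backbones. A hedged usage sketch (model name, input size, and index choices are illustrative; assumes five feature points, i.e. stem plus four stages):

import torch
import timm

model = timm.create_model('resnet51q', pretrained=False)  # any byobnet-family model
x = torch.randn(1, 3, 224, 224)

# final features plus the last two intermediate feature points
final, intermediates = model.forward_intermediates(x, indices=2)

# intermediates only, e.g. to feed a dense prediction head
feats = model.forward_intermediates(x, indices=[0, 2, 4], intermediates_only=True)

# keep only what is needed for the first three feature points (stem + first two stages)
kept = model.prune_intermediate_layers(indices=[0, 1, 2], prune_head=True)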