 """
 import math
 from functools import partial
+from typing import List, Optional, Union, Tuple

 import torch
 import torch.nn as nn

 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -211,6 +213,7 @@ def __init__(
             embed_dim=embed_dim,
             norm_layer=norm_layer if stem_norm else None,
         )
+        reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
         # FIXME drop_path (stochastic depth scaling rule or all the same?)
         self.blocks = nn.Sequential(*[
             block_layer(
@@ -224,6 +227,8 @@ def __init__(
                 drop_path=drop_path_rate,
             )
             for _ in range(num_blocks)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
         self.norm = norm_layer(embed_dim)
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
@@ -257,6 +262,76 @@ def reset_classifier(self, num_classes, global_pool=None):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+ """ Forward features that returns intermediates.
275
+
276
+ Args:
277
+ x: Input image tensor
278
+ indices: Take last n blocks if int, all if None, select matching indices if sequence
279
+ return_prefix_tokens: Return both prefix and spatial intermediate tokens
280
+ norm: Apply norm layer to all intermediates
281
+ stop_early: Stop iterating over blocks when last desired intermediate hit
282
+ output_fmt: Shape of intermediate feature outputs
283
+ intermediates_only: Only return intermediate features
284
+ Returns:
285
+
286
+ """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.stem(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.stem.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):


 def _create_mixer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
         MlpMixer,
         variant,
         pretrained,
         pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
         **kwargs,
     )
     return model
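
A rough usage sketch of what this diff enables (not part of the commit; the model name, indices, and shapes below are illustrative assumptions):

import torch
import timm

x = torch.randn(1, 3, 224, 224)

# Direct API: final (normed) features plus selected intermediate block outputs.
# indices=4 requests the last four blocks; output_fmt defaults to NCHW feature maps.
model = timm.create_model('mixer_b16_224', pretrained=False)
final, intermediates = model.forward_intermediates(x, indices=4)

# Wrapper API: features_only now works because _create_mixer passes feature_cfg
# with feature_cls='getter'; out_indices selects which entries of feature_info to return.
feat_model = timm.create_model('mixer_b16_224', pretrained=False, features_only=True, out_indices=(2, 5, 8, 11))
feature_maps = feat_model(x)  # list of [B, embed_dim, H/16, W/16] tensors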