Commit 01dd01b

forward_intermediates() for MlpMixer models and RegNet.
1 parent f8979d4 commit 01dd01b

2 files changed: +146 -4 lines


timm/models/mlp_mixer.py

Lines changed: 77 additions & 3 deletions
@@ -40,13 +40,15 @@
 """
 import math
 from functools import partial
+from typing import List, Optional, Union, Tuple
 
 import torch
 import torch.nn as nn
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
 from timm.layers import PatchEmbed, Mlp, GluMlp, GatedMlp, DropPath, lecun_normal_, to_2tuple
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
@@ -211,6 +213,7 @@ def __init__(
             embed_dim=embed_dim,
             norm_layer=norm_layer if stem_norm else None,
         )
+        reduction = self.stem.feat_ratio() if hasattr(self.stem, 'feat_ratio') else patch_size
         # FIXME drop_path (stochastic depth scaling rule or all the same?)
         self.blocks = nn.Sequential(*[
             block_layer(
@@ -224,6 +227,8 @@ def __init__(
                 drop_path=drop_path_rate,
             )
             for _ in range(num_blocks)])
+        self.feature_info = [
+            dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(num_blocks)]
         self.norm = norm_layer(embed_dim)
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
@@ -257,6 +262,76 @@ def reset_classifier(self, num_classes, global_pool=None):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to all intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW', 'NLC'), 'Output format must be one of NCHW or NLC.'
+        reshape = output_fmt == 'NCHW'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+
+        # forward pass
+        B, _, height, width = x.shape
+        x = self.stem(x)
+
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index + 1]
+        for i, blk in enumerate(blocks):
+            x = blk(x)
+            if i in take_indices:
+                # normalize intermediates with final norm layer if enabled
+                intermediates.append(self.norm(x) if norm else x)
+
+        # process intermediates
+        if reshape:
+            # reshape to BCHW output format
+            H, W = self.stem.dynamic_feat_size((height, width))
+            intermediates = [y.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous() for y in intermediates]
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.norm(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
+        self.blocks = self.blocks[:max_index + 1]  # truncate blocks
+        if prune_norm:
+            self.norm = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
@@ -330,14 +405,13 @@ def checkpoint_filter_fn(state_dict, model):
 
 
 def _create_mixer(variant, pretrained=False, **kwargs):
-    if kwargs.get('features_only', None):
-        raise RuntimeError('features_only not implemented for MLP-Mixer models.')
-
+    out_indices = kwargs.pop('out_indices', 3)
     model = build_model_with_cfg(
        MlpMixer,
        variant,
        pretrained,
        pretrained_filter_fn=checkpoint_filter_fn,
+        feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
        **kwargs,
    )
    return model
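
With the features_only restriction lifted, Mixer models can now hand back per-block feature maps. A minimal usage sketch, assuming a timm build that includes this commit; 'mixer_b16_224' is just an illustrative model name and the printed shape assumes its default 224x224 config:

import timm
import torch

model = timm.create_model('mixer_b16_224', pretrained=False)
x = torch.randn(1, 3, 224, 224)

# Final features plus the last two block outputs as NCHW maps.
final, intermediates = model.forward_intermediates(x, indices=2)
for feat in intermediates:
    print(feat.shape)  # torch.Size([1, 768, 14, 14]) for this config

# features_only construction now routes through the 'getter' feature class.
feat_model = timm.create_model('mixer_b16_224', features_only=True, out_indices=(2, 5, 8, 11))
out = feat_model(x)  # list of 4 feature maps, one per selected block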

timm/models/regnet.py

Lines changed: 69 additions & 1 deletion
@@ -26,7 +26,7 @@
 import math
 from dataclasses import dataclass, replace
 from functools import partial
-from typing import Optional, Union, Callable
+from typing import Callable, List, Optional, Union, Tuple
 
 import numpy as np
 import torch
@@ -36,6 +36,7 @@
 from timm.layers import ClassifierHead, AvgPool2dSame, ConvNormAct, SEModule, DropPath, GroupNormAct
 from timm.layers import get_act_layer, get_norm_act_layer, create_conv2d, make_divisible
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import checkpoint_seq, named_apply
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 
@@ -515,6 +516,73 @@ def get_classifier(self):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, pool_type=global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(5, indices)
+
+        # forward pass
+        feat_idx = 0
+        x = self.stem(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+
+        layer_names = ('s1', 's2', 's3', 's4')
+        if stop_early:
+            layer_names = layer_names[:max_index]
+        for n in layer_names:
+            feat_idx += 1
+            x = getattr(self, n)(x)  # won't work with torchscript, but keeps code reasonable, FML
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if feat_idx == 4:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(5, indices)
+        layer_names = ('s1', 's2', 's3', 's4')
+        layer_names = layer_names[max_index:]
+        for n in layer_names:
+            setattr(self, n, nn.Identity())
+        if max_index < 4:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.stem(x)
         x = self.s1(x)