
Commit d6da4fb

Add forward_intermediates() to efficientnet / mobilenetv3 based models as an exercise.
1 parent c22efb9 commit d6da4fb

4 files changed: +182 −12 lines changed

tests/test_models.py

Lines changed: 5 additions & 4 deletions

@@ -49,8 +49,9 @@
 
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
-    'cait_*', 'xcit_*', 'volo_*', 'swin*', 'max*vit_*', 'coatne*t_*'
+    'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
+    'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3'
 ]
 
 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.

@@ -388,7 +389,7 @@ def test_model_forward_features(model_name, batch_size):
 
 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates_features(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""

@@ -419,7 +420,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):
 
 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""

timm/models/_registry.py

Lines changed: 11 additions & 2 deletions

@@ -184,7 +184,7 @@ def _expand_filter(filter: str):
 
 def list_models(
         filter: Union[str, List[str]] = '',
-        module: str = '',
+        module: Union[str, List[str]] = '',
         pretrained: bool = False,
         exclude_filters: Union[str, List[str]] = '',
         name_matches_cfg: bool = False,

@@ -217,7 +217,16 @@
     # FIXME should this be default behaviour? or default to include_tags=True?
     include_tags = pretrained
 
-    all_models: Set[str] = _module_to_models[module] if module else set(_model_entrypoints.keys())
+    if not module:
+        all_models: Set[str] = set(_model_entrypoints.keys())
+    else:
+        if isinstance(module, str):
+            all_models: Set[str] = _module_to_models[module]
+        else:
+            assert isinstance(module, Sequence)
+            all_models: Set[str] = set()
+            for m in module:
+                all_models.update(_module_to_models[m])
     all_models = all_models - _deprecated_models.keys()  # remove deprecated models from listings
 
     if include_tags:
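Since module now accepts a sequence as well as a single name, one call can union several architecture families. A minimal sketch of the new behaviour (assumes a timm build including this change):

from timm import list_models

# Single module name: unchanged behaviour.
effnets = list_models(module='efficientnet')

# List of module names: the union of each module's registered models.
conv_families = list_models(module=['efficientnet', 'mobilenetv3'])

assert set(effnets) <= set(conv_families)
print(f'{len(effnets)} efficientnet models, {len(conv_families)} combined')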

timm/models/efficientnet.py

Lines changed: 84 additions & 3 deletions

@@ -36,7 +36,7 @@
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import List
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -49,7 +49,7 @@
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -118,6 +118,7 @@ def __init__(
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs
 
         # Head + Pooling

@@ -158,6 +159,86 @@ def reset_classifier(self, num_classes, global_pool='avg'):
         self.global_pool, self.classifier = create_classifier(
             self.num_features, self.num_classes, pool_type=global_pool)
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            *,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        x = self.conv_head(x)
+        x = self.bn2(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if prune_norm or max_index < len(self.blocks):
+            self.conv_head = nn.Identity()
+            self.bn2 = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x):
         x = self.conv_stem(x)
         x = self.bn1(x)

@@ -272,7 +353,7 @@ def _create_effnet(variant, pretrained=False, **kwargs):
     model_cls = EfficientNet
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
             features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'global_pool')
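The indices plumbing above works in two steps: feature_take_indices() resolves indices into concrete feature positions (the last n positions for an int, the given positions for a sequence) plus the maximum index needed, then stage_ends maps those positions onto block indices, with the stem counted as index 0. A hedged usage sketch (efficientnet_b0 is a real timm model; shapes depend on the variant and input size, and this assumes a timm build with this commit):

import torch
import timm

model = timm.create_model('efficientnet_b0', pretrained=False).eval()
x = torch.randn(2, 3, 224, 224)

# Final head features plus one intermediate per stage boundary (stem included).
final, intermediates = model.forward_intermediates(x)
print([tuple(t.shape) for t in intermediates])

# Just the last two stage outputs; stop_early avoids running blocks past the
# deepest intermediate requested (intermediates_only is required for this).
feats = model.forward_intermediates(x, indices=2, intermediates_only=True, stop_early=True)
assert len(feats) == 2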

timm/models/mobilenetv3.py

Lines changed: 82 additions & 3 deletions

@@ -7,7 +7,7 @@
 Hacked together by / Copyright 2019, Ross Wightman
 """
 from functools import partial
-from typing import Callable, List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn

@@ -20,7 +20,7 @@
 from ._efficientnet_blocks import SqueezeExcite
 from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
     round_channels, resolve_bn_args, resolve_act_layer, BN_EPS_TF_DEFAULT
-from ._features import FeatureInfo, FeatureHooks
+from ._features import FeatureInfo, FeatureHooks, feature_take_indices
 from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations

@@ -109,6 +109,7 @@ def __init__(
         )
         self.blocks = nn.Sequential(*builder(stem_size, block_args))
         self.feature_info = builder.features
+        self.stage_ends = [f['stage'] for f in self.feature_info]
         head_chs = builder.in_chs
 
         # Head + Pooling

@@ -150,6 +151,84 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
         self.flatten = nn.Flatten(1) if global_pool else nn.Identity()  # don't flatten if pooling disabled
         self.classifier = Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            *,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            extra_blocks: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            extra_blocks: Include outputs of all blocks and head conv in output, does not align with feature_info
+        Returns:
+            List of intermediate features, or a tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        if stop_early:
+            assert intermediates_only, 'Must use intermediates_only for early stopping.'
+        intermediates = []
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            take_indices = [self.stage_ends[i] for i in take_indices]
+            max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        x = self.conv_stem(x)
+        x = self.bn1(x)
+        if feat_idx in take_indices:
+            intermediates.append(x)
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            blocks = self.blocks
+        else:
+            blocks = self.blocks[:max_index]
+        for blk in blocks:
+            feat_idx += 1
+            x = blk(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+            extra_blocks: bool = False,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        if extra_blocks:
+            take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
+        else:
+            take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+            max_index = self.stage_ends[max_index]
+        self.blocks = self.blocks[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < len(self.blocks):
+            self.conv_head = nn.Identity()
+        if prune_head:
+            self.conv_head = nn.Identity()
+            self.reset_classifier(0, '')
+        return take_indices
+
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         x = self.conv_stem(x)
         x = self.bn1(x)

@@ -288,7 +367,7 @@ def _create_mnv3(variant: str, pretrained: bool = False, **kwargs) -> MobileNetV3:
     model_cls = MobileNetV3
     kwargs_filter = None
     if kwargs.pop('features_only', False):
-        if 'feature_cfg' in kwargs:
+        if 'feature_cfg' in kwargs or 'feature_cls' in kwargs:
             features_mode = 'cfg'
         else:
             kwargs_filter = ('num_classes', 'num_features', 'head_conv', 'head_bias', 'global_pool')
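The pruning method complements forward_intermediates(): once only certain intermediates are needed, trailing blocks and the head can be discarded entirely. A small sketch combining the two (mobilenetv3_large_100 is a real timm model; this assumes a build with this commit):

import torch
import timm

model = timm.create_model('mobilenetv3_large_100', pretrained=False).eval()

# Drop blocks past the deepest of the last two stage outputs, plus conv_head/classifier.
kept = model.prune_intermediate_layers(indices=2, prune_head=True)
print('kept feature indices:', kept)

# The pruned model is now a pure feature extractor; request intermediates only.
x = torch.randn(1, 3, 224, 224)
feats = model.forward_intermediates(x, indices=2, intermediates_only=True)
print([tuple(t.shape) for t in feats])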
