Commit 67332fc (parent: 62516d5)

Add features_intermediate() support to coatnet, maxvit, swin* models. Refine feature interface. Start prep of new vit weights.

18 files changed: +541 −62 lines
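
For context, this change set wires the existing forward_intermediates() interface into the coatnet, maxvit and swin model families. A minimal sketch of what the new support enables (the model name 'maxvit_tiny_tf_224' is a real timm variant, but treat the exact call as illustrative):

import torch
import timm

# illustrative only: any coatnet / maxvit / swin* variant covered by this commit
model = timm.create_model('maxvit_tiny_tf_224', pretrained=False)
model.eval()

# forward_intermediates() returns the final features plus the selected stage outputs
final, intermediates = model.forward_intermediates(torch.randn(1, 3, 224, 224))
for t in intermediates:
    print(t.shape)  # NCHW feature maps at increasing reduction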

tests/test_models.py

Lines changed: 4 additions & 4 deletions
@@ -50,7 +50,7 @@
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
     'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*',
-    'cait_*', 'xcit_*', 'volo_*',
+    'cait_*', 'xcit_*', 'volo_*', 'swin*', 'max*vit_*', 'coatne*t_*'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
@@ -392,9 +392,8 @@ def test_model_forward_features(model_name, batch_size):
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates_features(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""
-    model = create_model(model_name, pretrained=False, features_only=True)
+    model = create_model(model_name, pretrained=False, features_only=True, feature_cls='getter')
     model.eval()
-    print(model.feature_info.out_indices)
     expected_channels = model.feature_info.channels()
     expected_reduction = model.feature_info.reduction()

@@ -434,13 +433,14 @@ def test_model_forward_intermediates(model_name, batch_size):
     input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
     if max(input_size) > MAX_FFEAT_SIZE:
         pytest.skip("Fixed input size model > limit.")
-    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    output_fmt = 'NCHW'  # NOTE output_fmt determined by forward_intermediates() arg, not model attribute
     feat_axis = get_channel_dim(output_fmt)
     spatial_axis = get_spatial_dim(output_fmt)
     import math

     output, intermediates = model.forward_intermediates(
         torch.randn((batch_size, *input_size)),
+        output_fmt=output_fmt,
     )
     assert len(expected_channels) == len(intermediates)
     spatial_size = input_size[-2:]
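
Outside the test harness, the same path can be exercised directly. A minimal sketch assuming 'swin_tiny_patch4_window7_224' (any model matched by the new filters should behave similarly):

import torch
import timm

# features_only + feature_cls='getter' wraps the model in FeatureGetterNet,
# which pulls features via forward_intermediates()
model = timm.create_model(
    'swin_tiny_patch4_window7_224',
    pretrained=False,
    features_only=True,
    feature_cls='getter',
)
model.eval()
print(model.feature_info.channels())   # per-stage channel counts
print(model.feature_info.reduction())  # per-stage reduction factors
outputs = model(torch.randn(1, 3, 224, 224))
print([o.shape for o in outputs])      # NCHW by FeatureGetterNet default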

timm/models/_builder.py

Lines changed: 28 additions & 14 deletions
@@ -2,7 +2,7 @@
 import logging
 import os
 from copy import deepcopy
-from typing import Optional, Dict, Callable, Any, Tuple
+from typing import Any, Callable, Dict, List, Optional, Tuple

 from torch import nn as nn
 from torch.hub import load_state_dict_from_url
@@ -359,15 +359,15 @@ def build_model_with_cfg(
     * pruning config / model adaptation

     Args:
-        model_cls (nn.Module): model class
-        variant (str): model variant name
-        pretrained (bool): load pretrained weights
-        pretrained_cfg (dict): model's pretrained weight/task config
-        model_cfg (Optional[Dict]): model's architecture config
-        feature_cfg (Optional[Dict]: feature extraction adapter config
-        pretrained_strict (bool): load pretrained weights strictly
-        pretrained_filter_fn (Optional[Callable]): filter callable for pretrained weights
-        kwargs_filter (Optional[Tuple]): kwargs to filter before passing to model
+        model_cls: model class
+        variant: model variant name
+        pretrained: load pretrained weights
+        pretrained_cfg: model's pretrained weight/task config
+        model_cfg: model's architecture config
+        feature_cfg: feature extraction adapter config
+        pretrained_strict: load pretrained weights strictly
+        pretrained_filter_fn: filter callable for pretrained weights
+        kwargs_filter: kwargs to filter before passing to model
         **kwargs: model args passed through to model __init__
     """
     pruned = kwargs.pop('pruned', False)
@@ -392,6 +392,8 @@ def build_model_with_cfg(
         feature_cfg.setdefault('out_indices', (0, 1, 2, 3, 4))
         if 'out_indices' in kwargs:
             feature_cfg['out_indices'] = kwargs.pop('out_indices')
+        if 'feature_cls' in kwargs:
+            feature_cfg['feature_cls'] = kwargs.pop('feature_cls')

     # Instantiate the model
     if model_cfg is None:
@@ -418,24 +420,36 @@ def build_model_with_cfg(

     # Wrap the model in a feature extraction module if enabled
     if features:
-        feature_cls = FeatureListNet
-        output_fmt = getattr(model, 'output_fmt', None)
-        if output_fmt is not None:
-            feature_cfg.setdefault('output_fmt', output_fmt)
+        use_getter = False
         if 'feature_cls' in feature_cfg:
             feature_cls = feature_cfg.pop('feature_cls')
             if isinstance(feature_cls, str):
                 feature_cls = feature_cls.lower()
+
+                # flatten_sequential only valid for some feature extractors
+                if feature_cls not in ('dict', 'list', 'hook'):
+                    feature_cfg.pop('flatten_sequential', None)
+
                 if 'hook' in feature_cls:
                     feature_cls = FeatureHookNet
+                elif feature_cls == 'list':
+                    feature_cls = FeatureListNet
                 elif feature_cls == 'dict':
                     feature_cls = FeatureDictNet
                 elif feature_cls == 'fx':
                     feature_cls = FeatureGraphNet
                 elif feature_cls == 'getter':
+                    use_getter = True
                     feature_cls = FeatureGetterNet
                 else:
                     assert False, f'Unknown feature class {feature_cls}'
+        else:
+            feature_cls = FeatureListNet
+
+        output_fmt = getattr(model, 'output_fmt', None)
+        if output_fmt is not None and not use_getter:  # don't set default for intermediate feat getter
+            feature_cfg.setdefault('output_fmt', output_fmt)
+
         model = feature_cls(model, **feature_cfg)
         model.pretrained_cfg = pretrained_cfg_for_features(pretrained_cfg)  # add back pretrained cfg
         model.default_cfg = model.pretrained_cfg  # alias for rename backwards compat (default_cfg -> pretrained_cfg)
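
With the kwarg plumbing above, the feature wrapper class can now be chosen per-call from create_model(). A rough sketch of the mapping (the model names are real timm variants, but treat the specific combinations as assumptions):

import timm

m_list   = timm.create_model('resnet50', features_only=True)                       # default -> FeatureListNet
m_dict   = timm.create_model('resnet50', features_only=True, feature_cls='dict')   # -> FeatureDictNet
m_getter = timm.create_model('coatnet_0_rw_224', features_only=True, feature_cls='getter')  # -> FeatureGetterNet
print(type(m_list).__name__, type(m_dict).__name__, type(m_getter).__name__)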

timm/models/_features.py

Lines changed: 4 additions & 3 deletions
@@ -363,7 +363,7 @@ def __init__(
             out_map: Optional[Sequence[Union[int, str]]] = None,
             return_dict: bool = False,
             output_fmt: str = 'NCHW',
-            no_rewrite: bool = False,
+            no_rewrite: Optional[bool] = None,
             flatten_sequential: bool = False,
             default_hook_type: str = 'forward',
     ):
@@ -385,7 +385,8 @@ def __init__(
         self.return_dict = return_dict
         self.output_fmt = Format(output_fmt)
         self.grad_checkpointing = False
-
+        if no_rewrite is None:
+            no_rewrite = not flatten_sequential
         layers = OrderedDict()
         hooks = []
         if no_rewrite:
@@ -467,7 +468,7 @@ def __init__(
         self.out_indices = out_indices
         self.out_map = out_map
         self.return_dict = return_dict
-        self.output_fmt = output_fmt
+        self.output_fmt = Format(output_fmt)
         self.norm = norm

     def forward(self, x):

timm/models/_features_fx.py

Lines changed: 3 additions & 1 deletion
@@ -15,7 +15,7 @@
     has_fx_feature_extraction = False

 # Layers we went to treat as leaf modules
-from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame
+from timm.layers import Conv2dSame, ScaledStdConv2dSame, CondConv2d, StdConv2dSame, Format
 from timm.layers.non_local_attn import BilinearAttnTransform
 from timm.layers.pool2d_same import MaxPool2dSame, AvgPool2dSame
 from timm.layers.norm_act import (
@@ -108,12 +108,14 @@ def __init__(
             model: nn.Module,
             out_indices: Tuple[int, ...],
             out_map: Optional[Dict] = None,
+            output_fmt: str = 'NCHW',
     ):
         super().__init__()
         assert has_fx_feature_extraction, 'Please update to PyTorch 1.10+, torchvision 0.11+ for FX feature extraction'
         self.feature_info = _get_feature_info(model, out_indices)
         if out_map is not None:
             assert len(out_map) == len(out_indices)
+        self.output_fmt = Format(output_fmt)
         return_nodes = _get_return_layers(self.feature_info, out_map)
         self.graph_module = create_feature_extractor(model, return_nodes)

timm/models/beit.py

Lines changed: 4 additions & 3 deletions
@@ -404,10 +404,11 @@ def reset_classifier(self, num_classes, global_pool=None):
     def forward_intermediates(
             self,
             x: torch.Tensor,
+            *,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             return_prefix_tokens: bool = False,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -469,13 +470,13 @@ def forward_intermediates(

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
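
On the model side, the intermediate-feature args are now keyword-only and prune_intermediate_layers() takes indices instead of n. A minimal usage sketch, assuming 'beit_base_patch16_224' (12 blocks):

import torch
import timm

model = timm.create_model('beit_base_patch16_224', pretrained=False)
model.eval()

# keyword-only after the `*,` change; stop_early now defaults to False
feats = model.forward_intermediates(
    torch.randn(1, 3, 224, 224),
    indices=[3, 7, 11],
    output_fmt='NCHW',
    intermediates_only=True,
)
print([f.shape for f in feats])

# drop blocks (and the head) that the selected intermediates don't need
model.prune_intermediate_layers(indices=[3, 7, 11], prune_norm=False, prune_head=True)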

timm/models/cait.py

Lines changed: 4 additions & 3 deletions
@@ -341,9 +341,10 @@ def reset_classifier(self, num_classes, global_pool=None):
     def forward_intermediates(
             self,
             x: torch.Tensor,
+            *,
             indices: Optional[Union[int, List[int], Tuple[int]]] = None,
             norm: bool = False,
-            stop_early: bool = True,
+            stop_early: bool = False,
             output_fmt: str = 'NCHW',
             intermediates_only: bool = False,
     ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
@@ -397,13 +398,13 @@ def forward_intermediates(

     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.blocks), n)
+        take_indices, max_index = feature_take_indices(len(self.blocks), indices)
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_norm:
             self.norm = nn.Identity()
