
Commit 45b7ae8

forward_intermediates() support for byob/byoanet models
1 parent c4b8897 commit 45b7ae8
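In short, this commit wires the ByobNet/ByoaNet model families into timm's forward_intermediates() feature-extraction API. A minimal usage sketch (model name, input size, and pretrained flag are illustrative choices, not part of the commit):

    import torch
    import timm

    # 'gernet_s' is one byobnet architecture; any byob/byoanet model should behave the same way.
    model = timm.create_model('gernet_s', pretrained=False)
    x = torch.randn(1, 3, 224, 224)

    # Final features plus every intermediate feature map (stem + each stage).
    final, intermediates = model.forward_intermediates(x)

    # Only the last two intermediates, stopping the forward pass early where possible.
    feats = model.forward_intermediates(x, indices=2, stop_early=True, intermediates_only=True)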

File tree

2 files changed: +100 -7 lines

tests/test_models.py
timm/models/byobnet.py


tests/test_models.py

Lines changed: 2 additions & 1 deletion
@@ -51,7 +51,8 @@
 FEAT_INTER_FILTERS = [
     'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
     'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
-    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet'
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer'
 ]

 # transformer / hybrid models don't support full set of spatial / feature APIs and/or have spatial output.
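Adding 'regnet', 'byobnet', 'byoanet', and 'mlp_mixer' here opts those model modules into the shared forward_intermediates tests. Roughly, the filter gates tests per module along these lines (a simplified sketch, not the actual test body):

    from timm.models import is_model_in_modules

    def should_test_intermediates(model_name: str) -> bool:
        # Only exercise forward_intermediates() for opted-in model modules.
        return is_model_in_modules(model_name, FEAT_INTER_FILTERS)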

timm/models/byobnet.py

Lines changed: 98 additions & 6 deletions
@@ -40,6 +40,7 @@
 from timm.layers import ClassifierHead, ConvNormAct, BatchNormAct2d, DropPath, AvgPool2dSame, \
     create_conv2d, get_act_layer, get_norm_act_layer, get_attn, make_divisible, to_2tuple, EvoNorm2dS0a
 from ._builder import build_model_with_cfg
+from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
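The newly imported feature_take_indices() helper normalizes the indices argument into concrete feature indices plus the maximum index needed; approximately (behavior paraphrased, values illustrative for 5 features):

    # feature_take_indices(5, indices) -> (take_indices, max_index)
    #   indices=None    -> ([0, 1, 2, 3, 4], 4)   # take everything
    #   indices=2       -> ([3, 4], 4)            # take the last two
    #   indices=(0, -1) -> ([0, 4], 4)            # explicit picks, negatives normalized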

@@ -948,25 +949,37 @@ def __init__(
         stem_norm_acts = [False] * (num_rep - num_act) + [True] * num_act
         prev_chs = in_chs
         curr_stride = 1
+        last_feat_idx = -1
         for i, (ch, s, na) in enumerate(zip(stem_chs, stem_strides, stem_norm_acts)):
             layer_fn = layers.conv_norm_act if na else create_conv2d
             conv_name = f'conv{i + 1}'
             if i > 0 and s > 1:
-                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+                last_feat_idx = i - 1
+                self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module(conv_name, layer_fn(prev_chs, ch, kernel_size=kernel_size, stride=s))
             prev_chs = ch
             curr_stride *= s
             prev_feat = conv_name

         if pool and 'max' in pool.lower():
-            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+            last_feat_idx = i
+            self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
             self.add_module('pool', nn.MaxPool2d(3, 2, 1))
             curr_stride *= 2
             prev_feat = 'pool'

-        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat))
+        self.last_feat_idx = last_feat_idx if last_feat_idx >= 0 else None
+        self.feature_info.append(dict(num_chs=prev_chs, reduction=curr_stride, module=prev_feat, stage=0))
         assert curr_stride == stride

+    def forward_intermediates(self, x) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        intermediate: Optional[torch.Tensor] = None
+        for i, m in enumerate(self):
+            x = m(x)
+            if self.last_feat_idx is not None and i == self.last_feat_idx:
+                intermediate = x
+        return x, intermediate
+

 def create_byob_stem(
         in_chs: int,
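With last_feat_idx recorded, the Stem (an nn.Sequential subclass) can replay its children and hand back the feature produced just before the final downsample. A toy illustration of the resulting contract (shapes assume a stride-4 deep stem on a 224x224 input):

    out, inter = stem.forward_intermediates(torch.randn(1, 3, 224, 224))
    # out:   final stem output at stride 4, e.g. (1, C, 56, 56)
    # inter: feature captured before the last stride-2 op, e.g. (1, C', 112, 112),
    #        or None for stems with no intermediate (last_feat_idx is None)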
@@ -1008,7 +1021,7 @@ def create_byob_stem(
     if isinstance(stem, Stem):
         feature_info = [dict(f, module='.'.join([feat_prefix, f['module']])) for f in stem.feature_info]
     else:
-        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix)]
+        feature_info = [dict(num_chs=out_chs, reduction=2, module=feat_prefix, stage=0)]
     return stem, feature_info

@@ -1122,7 +1135,7 @@ def create_byob_stages(
         feat_size = reduce_feat_size(feat_size, stride)

         stages += [nn.Sequential(*blocks)]
-        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}')
+        prev_feat = dict(num_chs=prev_chs, reduction=net_stride, module=f'stages.{stage_idx}', stage=stage_idx + 1)

     feature_info.append(prev_feat)
     return nn.Sequential(*stages), feature_info
@@ -1198,6 +1211,7 @@ def __init__(
             feat_size=feat_size,
         )
         self.feature_info.extend(stage_feat[:-1])
+        reduction = stage_feat[-1]['reduction']

         prev_chs = stage_feat[-1]['num_chs']
         if cfg.num_features:
@@ -1207,7 +1221,8 @@ def __init__(
             self.num_features = prev_chs
             self.final_conv = nn.Identity()
         self.feature_info += [
-            dict(num_chs=self.num_features, reduction=stage_feat[-1]['reduction'], module='final_conv')]
+            dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
+        self.stage_ends = [f['stage'] for f in self.feature_info]

         self.head = ClassifierHead(
             self.num_features,
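The new stage= keys give every feature_info entry a stage index: stem entries are stage 0, body stages are 1..N, and final_conv reuses the last stage index since it is folded into the last feature output. stage_ends then maps feature positions to stage indices, e.g. for a typical stride-4 stem and four stages (illustrative values):

    # feature_info stages: stem -> 0, stages.0..stages.3 -> 1..4, final_conv -> 4
    # self.stage_ends == [0, 1, 2, 3, 4]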
@@ -1241,6 +1256,83 @@ def get_classifier(self):
     def reset_classifier(self, num_classes, global_pool='avg'):
         self.head.reset(num_classes, global_pool)

+    def forward_intermediates(
+            self,
+            x: torch.Tensor,
+            indices: Optional[Union[int, List[int], Tuple[int]]] = None,
+            norm: bool = False,
+            stop_early: bool = False,
+            output_fmt: str = 'NCHW',
+            intermediates_only: bool = False,
+            exclude_final_conv: bool = False,
+    ) -> Union[List[torch.Tensor], Tuple[torch.Tensor, List[torch.Tensor]]]:
+        """ Forward features that returns intermediates.
+
+        Args:
+            x: Input image tensor
+            indices: Take last n blocks if int, all if None, select matching indices if sequence
+            norm: Apply norm layer to compatible intermediates
+            stop_early: Stop iterating over blocks when last desired intermediate hit
+            output_fmt: Shape of intermediate feature outputs
+            intermediates_only: Only return intermediate features
+            exclude_final_conv: Exclude final_conv from last intermediate
+        Returns:
+            List of intermediate features, or tuple of (final features, intermediates)
+        """
+        assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
+        intermediates = []
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        take_indices = [self.stage_ends[i] for i in take_indices]
+        max_index = self.stage_ends[max_index]
+        # forward pass
+        feat_idx = 0  # stem is index 0
+        if hasattr(self.stem, 'forward_intermediates'):
+            # returns last intermediate features in stem (before final stride in stride > 2 stems)
+            x, x_inter = self.stem.forward_intermediates(x)
+        else:
+            x, x_inter = self.stem(x), None
+        if feat_idx in take_indices:
+            intermediates.append(x if x_inter is None else x_inter)
+        last_idx = self.stage_ends[-1]
+        if torch.jit.is_scripting() or not stop_early:  # can't slice blocks in torchscript
+            stages = self.stages
+        else:
+            stages = self.stages[:max_index]
+        for stage in stages:
+            feat_idx += 1
+            x = stage(x)
+            if not exclude_final_conv and feat_idx == last_idx:
+                # default feature_info for this model uses final_conv as the last feature output (if present)
+                x = self.final_conv(x)
+            if feat_idx in take_indices:
+                intermediates.append(x)
+
+        if intermediates_only:
+            return intermediates
+
+        if exclude_final_conv and feat_idx == last_idx:
+            x = self.final_conv(x)
+
+        return x, intermediates
+
+    def prune_intermediate_layers(
+            self,
+            indices: Union[int, List[int], Tuple[int]] = 1,
+            prune_norm: bool = False,
+            prune_head: bool = True,
+    ):
+        """ Prune layers not required for specified intermediates.
+        """
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
+        max_index = self.stage_ends[max_index]
+        self.stages = self.stages[:max_index]  # truncate blocks w/ stem as idx 0
+        if max_index < self.stage_ends[-1]:
+            self.final_conv = nn.Identity()
+        if prune_head:
+            self.reset_classifier(0, '')
+        return take_indices
+

     def forward_features(self, x):
         x = self.stem(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
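Together, the two new methods let a classifier be cut down to a lean feature extractor. A sketch (model name illustrative; prune_intermediate_layers() returns the feature indices that were kept):

    import torch
    import timm

    model = timm.create_model('gernet_s', pretrained=False)
    # Keep only what's needed for the first two feature outputs, drop the classifier head.
    kept = model.prune_intermediate_layers(indices=(0, 1), prune_head=True)
    feats = model.forward_intermediates(
        torch.randn(1, 3, 224, 224), indices=(0, 1), intermediates_only=True)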
