Commit a5a2ad2

Fix consistency, testing for forward_head w/ pre_logits, reset_classifier, models with pre_logits size != unpooled feature size
* add test that model supports forward_head(x, pre_logits=True)
* add head_hidden_size attr to all models and set differently from num_features attr when head has hidden layers
* test forward_features() feat dim == model.num_features and pre_logits feat dim == self.head_hidden_size
* more consistency in reset_classifier signature, add typing
* asserts in some heads where pooling cannot be disabled

Fix #2194
1 parent 4535a54 commit a5a2ad2

73 files changed: +404 / -395 lines
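
Taken together, the changes establish a simple contract that the updated tests enforce: forward_features() stays unpooled with channel dim model.num_features, forward_head(..., pre_logits=True) returns pooled features of width model.head_hidden_size, and reset_classifier(0) makes forward() return exactly those pre-logit features. A minimal sketch of that contract (model name and input size are illustrative, not part of the commit):

import torch
import timm

model = timm.create_model('resnet50', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

feats = model.forward_features(x)                   # unpooled, channel dim == model.num_features
pre = model.forward_head(feats, pre_logits=True)    # pooled pre-logits, dim == model.head_hidden_size
assert feats.shape[1] == model.num_features
assert pre.shape[1] == model.head_hidden_size

model.reset_classifier(0)                           # drop classifier; forward() now matches the pre_logits output
assert model(x).shape == pre.shape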

tests/test_models.py

Lines changed: 39 additions & 18 deletions
@@ -169,6 +169,18 @@ def test_model_backward(model_name, batch_size):
     assert not torch.isnan(outputs).any(), 'Output included NaNs'
 

+# models with extra conv/linear layers after pooling
+EARLY_POOL_MODELS = (
+    timm.models.EfficientVit,
+    timm.models.EfficientVitLarge,
+    timm.models.HighPerfGpuNet,
+    timm.models.GhostNet,
+    timm.models.MetaNeXt,  # InceptionNeXt
+    timm.models.MobileNetV3,
+    timm.models.RepGhostNet,
+    timm.models.VGG,
+)
+
 @pytest.mark.cfg
 @pytest.mark.timeout(timeout300)
 @pytest.mark.parametrize('model_name', list_models(
@@ -179,6 +191,9 @@ def test_model_default_cfgs(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
     model.to(torch_device)
+    assert getattr(model, 'num_classes') >= 0
+    assert getattr(model, 'num_features') > 0
+    assert getattr(model, 'head_hidden_size') > 0
     state_dict = model.state_dict()
     cfg = model.default_cfg

@@ -195,37 +210,37 @@ def test_model_default_cfgs(model_name, batch_size):
         input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
         input_tensor = torch.randn((batch_size, *input_size), device=torch_device)

-        # test forward_features (always unpooled)
+        # test forward_features (always unpooled) & forward_head w/ pre_logits
         outputs = model.forward_features(input_tensor)
-        assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
-        assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
-        if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
-            assert outputs.shape[feat_axis] == model.num_features
+        outputs_pre = model.forward_head(outputs, pre_logits=True)
+        assert outputs.shape[spatial_axis[0]] == pool_size[0], f'unpooled feature shape {outputs.shape} != config'
+        assert outputs.shape[spatial_axis[1]] == pool_size[1], f'unpooled feature shape {outputs.shape} != config'
+        assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
+        assert outputs_pre.shape[1] == model.head_hidden_size, f'pre_logits feature dim {outputs_pre.shape[1]} != model.head_hidden_size {model.head_hidden_size}'

         # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
         model.reset_classifier(0)
         model.to(torch_device)
         outputs = model.forward(input_tensor)
         assert len(outputs.shape) == 2
-        assert outputs.shape[1] == model.num_features
+        assert outputs.shape[1] == model.head_hidden_size, f'feature dim w/ removed classifier {outputs.shape[1]} != model.head_hidden_size {model.head_hidden_size}'
+        assert outputs.shape == outputs_pre.shape, f'output shape of pre_logits {outputs_pre.shape} does not match reset_head(0) {outputs.shape}'

-        # test model forward without pooling and classifier
-        model.reset_classifier(0, '')  # reset classifier and set global pooling to pass-through
-        model.to(torch_device)
-        outputs = model.forward(input_tensor)
-        assert len(outputs.shape) == 4
-        if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
-            # mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
+        # test model forward after removing pooling and classifier
+        if not isinstance(model, EARLY_POOL_MODELS):
+            model.reset_classifier(0, '')  # reset classifier and disable global pooling
+            model.to(torch_device)
+            outputs = model.forward(input_tensor)
+            assert len(outputs.shape) == 4
             assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

-        if 'pruned' not in model_name:  # FIXME better pruned model handling
-            # test classifier + global pool deletion via __init__
+        # test classifier + global pool deletion via __init__
+        if 'pruned' not in model_name and not isinstance(model, EARLY_POOL_MODELS):
             model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
             model.to(torch_device)
             outputs = model.forward(input_tensor)
             assert len(outputs.shape) == 4
-            if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
-                assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
+            assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

     # check classifier name matches default_cfg
     if cfg.get('num_classes', None):
@@ -253,6 +268,9 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     model = create_model(model_name, pretrained=False)
     model.eval()
     model.to(torch_device)
+    assert getattr(model, 'num_classes') >= 0
+    assert getattr(model, 'num_features') > 0
+    assert getattr(model, 'head_hidden_size') > 0
     state_dict = model.state_dict()
     cfg = model.default_cfg

@@ -264,13 +282,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
     feat_dim = getattr(model, 'feature_dim', None)

     outputs = model.forward_features(input_tensor)
+    outputs_pre = model.forward_head(outputs, pre_logits=True)
     if isinstance(outputs, (tuple, list)):
         # cannot currently verify multi-tensor output.
         pass
     else:
         if feat_dim is None:
             feat_dim = -1 if outputs.ndim == 3 else 1
         assert outputs.shape[feat_dim] == model.num_features
+        assert outputs_pre.shape[1] == model.head_hidden_size

     # test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
     model.reset_classifier(0)
@@ -280,7 +300,8 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
         outputs = outputs[0]
     if feat_dim is None:
         feat_dim = -1 if outputs.ndim == 3 else 1
-    assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
+    assert outputs.shape[feat_dim] == model.head_hidden_size, 'pooled num_features != config'
+    assert outputs.shape == outputs_pre.shape

     model = create_model(model_name, pretrained=False, num_classes=0).eval()
     model.to(torch_device)
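
EARLY_POOL_MODELS groups the architectures whose heads add conv/linear layers after global pooling, which is why their pre-logit width differs from the unpooled feature width and why the unpooled-forward checks above skip them. A hedged illustration (the 960/1280 widths are the expected values for this MobileNetV3 variant, not something this diff asserts):

import torch
import timm

m = timm.create_model('mobilenetv3_large_100', pretrained=False).eval()
x = torch.randn(1, 3, 224, 224)

feats = m.forward_features(x)                  # approx (1, 960, 7, 7) -> m.num_features == 960
pre = m.forward_head(feats, pre_logits=True)   # approx (1, 1280)      -> m.head_hidden_size == 1280

# Pooling happens before conv_head in these models, so reset_classifier(0, '') cannot yield a
# 4D map of width head_hidden_size; hence the isinstance(model, EARLY_POOL_MODELS) exclusions.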

timm/models/_prune.py

Lines changed: 3 additions & 0 deletions
@@ -101,7 +101,10 @@ def adapt_model_from_string(parent_module, model_string):
                 in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
             set_layer(new_module, n, new_fc)
             if hasattr(new_module, 'num_features'):
+                if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features:
+                    new_module.head_hidden_size = num_features
                 new_module.num_features = num_features
+
     new_module.eval()
     parent_module.eval()

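The _prune.py change keeps head_hidden_size in step with num_features when a pruned architecture string shrinks the final feature width, but only if the head was tracking num_features to begin with (i.e. no extra hidden layer). A rough check of the intent, assuming one of timm's existing pruned variants:

import timm

# 'efficientnet_b1_pruned' is adapted via adapt_model_from_string(); exact widths depend on the prune spec
m = timm.create_model('efficientnet_b1_pruned', pretrained=False)
assert m.head_hidden_size == m.num_features   # the two attributes should remain in sync after adaptation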

timm/models/beit.py

Lines changed: 2 additions & 2 deletions
@@ -291,7 +291,7 @@ def __init__(
         super().__init__()
         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
         self.num_prefix_tokens = 1
         self.grad_checkpointing = False

@@ -392,7 +392,7 @@ def group_matcher(self, coarse=False):
         return matcher

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):

timm/models/byobnet.py

Lines changed: 4 additions & 2 deletions
@@ -1224,6 +1224,7 @@ def __init__(
             dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
         self.stage_ends = [f['stage'] for f in self.feature_info]

+        self.head_hidden_size = self.num_features
         self.head = ClassifierHead(
             self.num_features,
             num_classes,
@@ -1250,10 +1251,11 @@ def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

     def forward_intermediates(
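
byobnet (like convnext and cspnet further down) now follows one reset_classifier pattern: store num_classes on the model and delegate to the head's reset(). From the caller's side the standardized, typed signature looks roughly like this (model name illustrative):

import timm

m = timm.create_model('repvgg_a2', pretrained=False)    # a byobnet.py (ByobNet) model
m.reset_classifier(100)                                  # (num_classes: int, global_pool: Optional[str] = None)
assert m.num_classes == 100
assert m.get_classifier().out_features == 100            # get_classifier() is now annotated -> nn.Module

m.reset_classifier(0, 'avg')                             # num_classes=0 swaps the final fc for an identity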

timm/models/cait.py

Lines changed: 2 additions & 2 deletions
@@ -239,7 +239,7 @@ def __init__(

         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.num_features = self.embed_dim = embed_dim
+        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
         self.grad_checkpointing = False

         self.patch_embed = patch_layer(
@@ -328,7 +328,7 @@ def _matcher(name):
         return _matcher

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):

timm/models/coat.py

Lines changed: 3 additions & 3 deletions
@@ -7,7 +7,7 @@

 Modified from timm/models/vision_transformer.py
 """
-from typing import List, Optional, Union, Tuple
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -380,7 +380,7 @@ def __init__(
         self.return_interm_layers = return_interm_layers
         self.out_features = out_features
         self.embed_dims = embed_dims
-        self.num_features = embed_dims[-1]
+        self.num_features = self.head_hidden_size = embed_dims[-1]
         self.num_classes = num_classes
         self.global_pool = global_pool

@@ -556,7 +556,7 @@ def group_matcher(self, coarse=False):
         return matcher

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):

timm/models/convit.py

Lines changed: 2 additions & 2 deletions
@@ -269,7 +269,7 @@ def __init__(
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.local_up_to_layer = local_up_to_layer
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
         self.locality_strength = locality_strength
         self.use_pos_embed = use_pos_embed

@@ -345,7 +345,7 @@ def set_grad_checkpointing(self, enable=True):
         assert not enable, 'gradient checkpointing not supported'

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):

timm/models/convmixer.py

Lines changed: 2 additions & 2 deletions
@@ -40,7 +40,7 @@ def __init__(
     ):
         super().__init__()
         self.num_classes = num_classes
-        self.num_features = dim
+        self.num_features = self.head_hidden_size = dim
         self.grad_checkpointing = False

         self.stem = nn.Sequential(
@@ -74,7 +74,7 @@ def set_grad_checkpointing(self, enable=True):
         self.grad_checkpointing = enable

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):

timm/models/convnext.py

Lines changed: 5 additions & 3 deletions
@@ -358,7 +358,7 @@ def __init__(
             # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
             self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
         self.stages = nn.Sequential(*stages)
-        self.num_features = prev_chs
+        self.num_features = self.head_hidden_size = prev_chs

         # if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
         # otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
@@ -382,6 +382,7 @@ def __init__(
                 norm_layer=norm_layer,
                 act_layer='gelu',
             )
+        self.head_hidden_size = self.head.num_features
         named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

     @torch.jit.ignore
@@ -401,10 +402,11 @@ def set_grad_checkpointing(self, enable=True):
             s.grad_checkpointing = enable

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes=0, global_pool=None):
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
         self.head.reset(num_classes, global_pool)

     def forward_intermediates(
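
ConvNeXt is one of the models the commit title refers to with a pre-logits size that can differ from the unpooled feature size: when the head is built with a hidden MLP layer, the model copies the head's reported width into head_hidden_size. A hedged sketch of that relationship, assuming the behavior of timm.layers.NormMlpClassifierHead:

from timm.layers import NormMlpClassifierHead

# With a hidden pre-logits layer the head reports that width as its num_features ...
head = NormMlpClassifierHead(in_features=768, num_classes=1000, hidden_size=1024)
print(head.num_features)   # expected: 1024

# ... and the model mirrors it via self.head_hidden_size = self.head.num_features,
# so for such MLP-head variants head_hidden_size (1024) != num_features (768).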

timm/models/crossvit.py

Lines changed: 6 additions & 5 deletions
@@ -330,7 +330,7 @@ def __init__(
         num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
         self.num_branches = len(patch_size)
         self.embed_dim = embed_dim
-        self.num_features = sum(embed_dim)
+        self.num_features = self.head_hidden_size = sum(embed_dim)
         self.patch_embed = nn.ModuleList()

         # hard-coded for torch jit script
@@ -415,17 +415,18 @@ def set_grad_checkpointing(self, enable=True):
         assert not enable, 'gradient checkpointing not supported'

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
         self.num_classes = num_classes
         if global_pool is not None:
             assert global_pool in ('token', 'avg')
             self.global_pool = global_pool
-        self.head = nn.ModuleList(
-            [nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
-             range(self.num_branches)])
+        self.head = nn.ModuleList([
+            nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
+            for i in range(self.num_branches)
+        ])

     def forward_features(self, x) -> List[torch.Tensor]:
         B = x.shape[0]

timm/models/cspnet.py

Lines changed: 10 additions & 5 deletions
@@ -675,9 +675,13 @@ def __init__(
         self.feature_info.extend(stage_feat_info)

         # Construct the head
-        self.num_features = prev_chs
+        self.num_features = self.head_hidden_size = prev_chs
         self.head = ClassifierHead(
-            in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
+            in_features=prev_chs,
+            num_classes=num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+        )

         named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

@@ -698,11 +702,12 @@ def set_grad_checkpointing(self, enable=True):
         assert not enable, 'gradient checkpointing not supported'

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head.fc

-    def reset_classifier(self, num_classes, global_pool='avg'):
-        self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool)

     def forward_features(self, x):
         x = self.stem(x)

timm/models/davit.py

Lines changed: 2 additions & 2 deletions
@@ -485,7 +485,7 @@ def __init__(
         norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
         norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
         self.num_classes = num_classes
-        self.num_features = embed_dims[-1]
+        self.num_features = self.head_hidden_size = embed_dims[-1]
         self.drop_rate = drop_rate
         self.grad_checkpointing = False
         self.feature_info = []
@@ -565,7 +565,7 @@ def set_grad_checkpointing(self, enable=True):
             stage.set_grad_checkpointing(enable=enable)

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head.fc

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):

timm/models/deit.py

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ def group_matcher(self, coarse=False):
         )

     @torch.jit.ignore
-    def get_classifier(self):
+    def get_classifier(self) -> nn.Module:
         return self.head, self.head_dist

     def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
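
The distilled DeiT models are the one case shown here where get_classifier() returns two modules; a quick hedged usage (model name is an existing distilled variant):

import timm

m = timm.create_model('deit_base_distilled_patch16_224', pretrained=False)
head, head_dist = m.get_classifier()   # distilled DeiT exposes both the class and distillation heads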
