Skip to content

Fix consistency, testing for forward_head w/ pre_logits, reset_classifier, models with pre_logits size != unpooled feature size #2195

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 39 additions & 18 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,18 @@ def test_model_backward(model_name, batch_size):
assert not torch.isnan(outputs).any(), 'Output included NaNs'


# models with extra conv/linear layers after pooling
EARLY_POOL_MODELS = (
timm.models.EfficientVit,
timm.models.EfficientVitLarge,
timm.models.HighPerfGpuNet,
timm.models.GhostNet,
timm.models.MetaNeXt, # InceptionNeXt
timm.models.MobileNetV3,
timm.models.RepGhostNet,
timm.models.VGG,
)

@pytest.mark.cfg
@pytest.mark.timeout(timeout300)
@pytest.mark.parametrize('model_name', list_models(
Expand All @@ -179,6 +191,9 @@ def test_model_default_cfgs(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
model.to(torch_device)
assert getattr(model, 'num_classes') >= 0
assert getattr(model, 'num_features') > 0
assert getattr(model, 'head_hidden_size') > 0
state_dict = model.state_dict()
cfg = model.default_cfg

Expand All @@ -195,37 +210,37 @@ def test_model_default_cfgs(model_name, batch_size):
input_size = tuple([min(x, MAX_FWD_OUT_SIZE) for x in input_size])
input_tensor = torch.randn((batch_size, *input_size), device=torch_device)

# test forward_features (always unpooled)
# test forward_features (always unpooled) & forward_head w/ pre_logits
outputs = model.forward_features(input_tensor)
assert outputs.shape[spatial_axis[0]] == pool_size[0], 'unpooled feature shape != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], 'unpooled feature shape != config'
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
assert outputs.shape[feat_axis] == model.num_features
outputs_pre = model.forward_head(outputs, pre_logits=True)
assert outputs.shape[spatial_axis[0]] == pool_size[0], f'unpooled feature shape {outputs.shape} != config'
assert outputs.shape[spatial_axis[1]] == pool_size[1], f'unpooled feature shape {outputs.shape} != config'
assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
assert outputs_pre.shape[1] == model.head_hidden_size, f'pre_logits feature dim {outputs_pre.shape[1]} != model.head_hidden_size {model.head_hidden_size}'

# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 2
assert outputs.shape[1] == model.num_features
assert outputs.shape[1] == model.head_hidden_size, f'feature dim w/ removed classifier {outputs.shape[1]} != model.head_hidden_size {model.head_hidden_size}'
assert outputs.shape == outputs_pre.shape, f'output shape of pre_logits {outputs_pre.shape} does not match reset_head(0) {outputs.shape}'

# test model forward without pooling and classifier
model.reset_classifier(0, '') # reset classifier and set global pooling to pass-through
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
# mobilenetv3/ghostnet/repghostnet/vgg forward_features vs removed pooling differ due to location or lack of GAP
# test model forward after removing pooling and classifier
if not isinstance(model, EARLY_POOL_MODELS):
model.reset_classifier(0, '') # reset classifier and disable global pooling
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

if 'pruned' not in model_name: # FIXME better pruned model handling
# test classifier + global pool deletion via __init__
# test classifier + global pool deletion via __init__
if 'pruned' not in model_name and not isinstance(model, EARLY_POOL_MODELS):
model = create_model(model_name, pretrained=False, num_classes=0, global_pool='').eval()
model.to(torch_device)
outputs = model.forward(input_tensor)
assert len(outputs.shape) == 4
if not isinstance(model, (timm.models.MobileNetV3, timm.models.GhostNet, timm.models.RepGhostNet, timm.models.VGG)):
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]
assert outputs.shape[spatial_axis[0]] == pool_size[0] and outputs.shape[spatial_axis[1]] == pool_size[1]

# check classifier name matches default_cfg
if cfg.get('num_classes', None):
Expand Down Expand Up @@ -253,6 +268,9 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
model = create_model(model_name, pretrained=False)
model.eval()
model.to(torch_device)
assert getattr(model, 'num_classes') >= 0
assert getattr(model, 'num_features') > 0
assert getattr(model, 'head_hidden_size') > 0
state_dict = model.state_dict()
cfg = model.default_cfg

Expand All @@ -264,13 +282,15 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
feat_dim = getattr(model, 'feature_dim', None)

outputs = model.forward_features(input_tensor)
outputs_pre = model.forward_head(outputs, pre_logits=True)
if isinstance(outputs, (tuple, list)):
# cannot currently verify multi-tensor output.
pass
else:
if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features
assert outputs_pre.shape[1] == model.head_hidden_size

# test forward after deleting the classifier, output should be poooled, size(-1) == model.num_features
model.reset_classifier(0)
Expand All @@ -280,7 +300,8 @@ def test_model_default_cfgs_non_std(model_name, batch_size):
outputs = outputs[0]
if feat_dim is None:
feat_dim = -1 if outputs.ndim == 3 else 1
assert outputs.shape[feat_dim] == model.num_features, 'pooled num_features != config'
assert outputs.shape[feat_dim] == model.head_hidden_size, 'pooled num_features != config'
assert outputs.shape == outputs_pre.shape

model = create_model(model_name, pretrained=False, num_classes=0).eval()
model.to(torch_device)
Expand Down
3 changes: 3 additions & 0 deletions timm/models/_prune.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,10 @@ def adapt_model_from_string(parent_module, model_string):
in_features=num_features, out_features=old_module.out_features, bias=old_module.bias is not None)
set_layer(new_module, n, new_fc)
if hasattr(new_module, 'num_features'):
if getattr(new_module, 'head_hidden_size', 0) == new_module.num_features:
new_module.head_hidden_size = num_features
new_module.num_features = num_features

new_module.eval()
parent_module.eval()

Expand Down
4 changes: 2 additions & 2 deletions timm/models/beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def __init__(
super().__init__()
self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
self.num_prefix_tokens = 1
self.grad_checkpointing = False

Expand Down Expand Up @@ -392,7 +392,7 @@ def group_matcher(self, coarse=False):
return matcher

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
6 changes: 4 additions & 2 deletions timm/models/byobnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,7 @@ def __init__(
dict(num_chs=self.num_features, reduction=reduction, module='final_conv', stage=len(self.stages))]
self.stage_ends = [f['stage'] for f in self.feature_info]

self.head_hidden_size = self.num_features
self.head = ClassifierHead(
self.num_features,
num_classes,
Expand All @@ -1250,10 +1251,11 @@ def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
Expand Down
4 changes: 2 additions & 2 deletions timm/models/cait.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ def __init__(

self.num_classes = num_classes
self.global_pool = global_pool
self.num_features = self.embed_dim = embed_dim
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
self.grad_checkpointing = False

self.patch_embed = patch_layer(
Expand Down Expand Up @@ -328,7 +328,7 @@ def _matcher(name):
return _matcher

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
6 changes: 3 additions & 3 deletions timm/models/coat.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

Modified from timm/models/vision_transformer.py
"""
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
Expand Down Expand Up @@ -380,7 +380,7 @@ def __init__(
self.return_interm_layers = return_interm_layers
self.out_features = out_features
self.embed_dims = embed_dims
self.num_features = embed_dims[-1]
self.num_features = self.head_hidden_size = embed_dims[-1]
self.num_classes = num_classes
self.global_pool = global_pool

Expand Down Expand Up @@ -556,7 +556,7 @@ def group_matcher(self, coarse=False):
return matcher

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
4 changes: 2 additions & 2 deletions timm/models/convit.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def __init__(
self.num_classes = num_classes
self.global_pool = global_pool
self.local_up_to_layer = local_up_to_layer
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_features = self.head_hidden_size = self.embed_dim = embed_dim # for consistency with other models
self.locality_strength = locality_strength
self.use_pos_embed = use_pos_embed

Expand Down Expand Up @@ -345,7 +345,7 @@ def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
4 changes: 2 additions & 2 deletions timm/models/convmixer.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(
):
super().__init__()
self.num_classes = num_classes
self.num_features = dim
self.num_features = self.head_hidden_size = dim
self.grad_checkpointing = False

self.stem = nn.Sequential(
Expand Down Expand Up @@ -74,7 +74,7 @@ def set_grad_checkpointing(self, enable=True):
self.grad_checkpointing = enable

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
8 changes: 5 additions & 3 deletions timm/models/convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def __init__(
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
self.feature_info += [dict(num_chs=prev_chs, reduction=curr_stride, module=f'stages.{i}')]
self.stages = nn.Sequential(*stages)
self.num_features = prev_chs
self.num_features = self.head_hidden_size = prev_chs

# if head_norm_first == true, norm -> global pool -> fc ordering, like most other nets
# otherwise pool -> norm -> fc, the default ConvNeXt ordering (pretrained FB weights)
Expand All @@ -382,6 +382,7 @@ def __init__(
norm_layer=norm_layer,
act_layer='gelu',
)
self.head_hidden_size = self.head.num_features
named_apply(partial(_init_weights, head_init_scale=head_init_scale), self)

@torch.jit.ignore
Expand All @@ -401,10 +402,11 @@ def set_grad_checkpointing(self, enable=True):
s.grad_checkpointing = enable

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes=0, global_pool=None):
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_intermediates(
Expand Down
11 changes: 6 additions & 5 deletions timm/models/crossvit.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def __init__(
num_patches = _compute_num_patches(self.img_size_scaled, patch_size)
self.num_branches = len(patch_size)
self.embed_dim = embed_dim
self.num_features = sum(embed_dim)
self.num_features = self.head_hidden_size = sum(embed_dim)
self.patch_embed = nn.ModuleList()

# hard-coded for torch jit script
Expand Down Expand Up @@ -415,17 +415,18 @@ def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('token', 'avg')
self.global_pool = global_pool
self.head = nn.ModuleList(
[nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity() for i in
range(self.num_branches)])
self.head = nn.ModuleList([
nn.Linear(self.embed_dim[i], num_classes) if num_classes > 0 else nn.Identity()
for i in range(self.num_branches)
])

def forward_features(self, x) -> List[torch.Tensor]:
B = x.shape[0]
Expand Down
15 changes: 10 additions & 5 deletions timm/models/cspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,13 @@ def __init__(
self.feature_info.extend(stage_feat_info)

# Construct the head
self.num_features = prev_chs
self.num_features = self.head_hidden_size = prev_chs
self.head = ClassifierHead(
in_features=prev_chs, num_classes=num_classes, pool_type=global_pool, drop_rate=drop_rate)
in_features=prev_chs,
num_classes=num_classes,
pool_type=global_pool,
drop_rate=drop_rate,
)

named_apply(partial(_init_weights, zero_init_last=zero_init_last), self)

Expand All @@ -698,11 +702,12 @@ def set_grad_checkpointing(self, enable=True):
assert not enable, 'gradient checkpointing not supported'

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes, global_pool='avg'):
self.head = ClassifierHead(self.num_features, num_classes, pool_type=global_pool, drop_rate=self.drop_rate)
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
self.num_classes = num_classes
self.head.reset(num_classes, global_pool)

def forward_features(self, x):
x = self.stem(x)
Expand Down
4 changes: 2 additions & 2 deletions timm/models/davit.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def __init__(
norm_layer = partial(get_norm_layer(norm_layer), eps=norm_eps)
norm_layer_cl = partial(get_norm_layer(norm_layer_cl), eps=norm_eps)
self.num_classes = num_classes
self.num_features = embed_dims[-1]
self.num_features = self.head_hidden_size = embed_dims[-1]
self.drop_rate = drop_rate
self.grad_checkpointing = False
self.feature_info = []
Expand Down Expand Up @@ -565,7 +565,7 @@ def set_grad_checkpointing(self, enable=True):
stage.set_grad_checkpointing(enable=enable)

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head.fc

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
2 changes: 1 addition & 1 deletion timm/models/deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def group_matcher(self, coarse=False):
)

@torch.jit.ignore
def get_classifier(self):
def get_classifier(self) -> nn.Module:
return self.head, self.head_dist

def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
Expand Down
Loading
Loading