Commit 211d18d

Move norm & pool into Hiera ClassifierHead. Misc fixes, update features_intermediate() naming
1 parent 2ca45a4 commit 211d18d

File tree

3 files changed: +53 -31 lines changed

timm/layers/classifier.py

Lines changed: 2 additions & 2 deletions

@@ -108,7 +108,7 @@ def __init__(
         self.fc = fc
         self.flatten = nn.Flatten(1) if use_conv and pool_type else nn.Identity()
 
-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None and pool_type != self.global_pool.pool_type:
             self.global_pool, self.fc = create_classifier(
                 self.in_features,
@@ -180,7 +180,7 @@ def __init__(
         self.drop = nn.Dropout(drop_rate)
         self.fc = linear_layer(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def reset(self, num_classes, pool_type=None):
+    def reset(self, num_classes: int, pool_type: Optional[str] = None):
         if pool_type is not None:
             self.global_pool = SelectAdaptivePool2d(pool_type=pool_type)
             self.flatten = nn.Flatten(1) if pool_type else nn.Identity()
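
A quick hedged sketch of the typed reset() API touched above. The constructor arguments shown here are assumptions based on timm's ClassifierHead, not part of this diff:

import torch
from timm.layers import ClassifierHead

head = ClassifierHead(in_features=512, num_classes=1000, pool_type='avg')
head.reset(10, pool_type='max')            # reset(num_classes: int, pool_type: Optional[str])
logits = head(torch.randn(2, 512, 7, 7))   # -> shape (2, 10)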

timm/layers/create_norm.py

Lines changed: 1 addition & 1 deletion

@@ -47,7 +47,7 @@ def get_norm_layer(norm_layer):
     if isinstance(norm_layer, str):
         if not norm_layer:
             return None
-        layer_name = norm_layer.replace('_', '')
+        layer_name = norm_layer.replace('_', '').lower()
         norm_layer = _NORM_MAP[layer_name]
     else:
         norm_layer = norm_layer
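
With the added .lower(), string lookups become case-insensitive. A small hedged sketch, assuming the _NORM_MAP keys (e.g. 'layernorm', 'batchnorm2d') are registered in lowercase as elsewhere in timm:

from timm.layers import get_norm_layer

# These should now resolve to the same class regardless of casing or underscores.
assert get_norm_layer('layernorm') is get_norm_layer('LayerNorm')
assert get_norm_layer('batch_norm_2d') is get_norm_layer('BatchNorm2d')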

timm/models/hiera.py

Lines changed: 50 additions & 28 deletions

@@ -32,7 +32,7 @@
 
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import DropPath, Mlp, use_fused_attn, _assert
+from timm.layers import DropPath, Mlp, use_fused_attn, _assert, get_norm_layer
 
 
 from ._registry import generate_default_cfgs, register_model
@@ -372,20 +372,41 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class Head(nn.Module):
+class NormClassifierHead(nn.Module):
     def __init__(
             self,
-            dim: int,
+            in_features: int,
             num_classes: int,
+            pool_type: str = 'avg',
             drop_rate: float = 0.0,
+            norm_layer: Union[str, Callable] = 'layernorm',
     ):
         super().__init__()
-        self.dropout = nn.Dropout(drop_rate) if drop_rate > 0 else nn.Identity()
-        self.projection = nn.Linear(dim, num_classes)
+        norm_layer = get_norm_layer(norm_layer)
+        assert pool_type in ('avg', '')
+        self.in_features = self.num_features = in_features
+        self.pool_type = pool_type
+        self.norm = norm_layer(in_features)
+        self.drop = nn.Dropout(drop_rate) if drop_rate else nn.Identity()
+        self.fc = nn.Linear(in_features, num_classes) if num_classes > 0 else nn.Identity()
+
+    def reset(self, num_classes: int, pool_type: Optional[str] = None, other: bool = False):
+        if pool_type is not None:
+            assert pool_type in ('avg', '')
+            self.pool_type = pool_type
+        if other:
+            # reset other non-fc layers
+            self.norm = nn.Identity()
+        self.fc = nn.Linear(self.in_features, num_classes) if num_classes > 0 else nn.Identity()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = self.dropout(x)
-        x = self.projection(x)
+    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
+        if self.pool_type == 'avg':
+            x = x.mean(dim=1)
+        x = self.norm(x)
+        x = self.drop(x)
+        if pre_logits:
+            return x
+        x = self.fc(x)
         return x
 
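A minimal hedged sketch of the new head used in isolation, based on the signature above (importing NormClassifierHead from timm.models.hiera is an assumption about where it is exposed):

import torch
from timm.models.hiera import NormClassifierHead

head = NormClassifierHead(96, num_classes=10, pool_type='avg', norm_layer='layernorm')
tokens = torch.randn(2, 49, 96)          # (batch, tokens, channels)
logits = head(tokens)                    # mean-pool -> norm -> dropout -> fc, shape (2, 10)
feats = head(tokens, pre_logits=True)    # stops before fc, shape (2, 96)
head.reset(5)                            # rebuilds fc for 5 classes, keeps norm & pool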

@@ -438,6 +459,7 @@ def __init__(
             embed_dim: int = 96,  # initial embed dim
             num_heads: int = 1,  # initial number of heads
             num_classes: int = 1000,
+            global_pool: str = 'avg',
             stages: Tuple[int, ...] = (2, 3, 16, 3),
             q_pool: int = 3,  # number of q_pool stages
             q_stride: Tuple[int, ...] = (2, 2),
@@ -458,11 +480,7 @@ def __init__(
     ):
         super().__init__()
         self.num_classes = num_classes
-
-        # Do it this way to ensure that the init args are all PoD (for config usage)
-        if isinstance(norm_layer, str):
-            norm_layer = partial(getattr(nn, norm_layer), eps=1e-6)
-
+        norm_layer = get_norm_layer(norm_layer)
         depth = sum(stages)
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
@@ -552,8 +570,14 @@ def __init__(
                 dict(num_chs=dim_out, reduction=2**(cur_stage+2), module=f'blocks.{self.stage_ends[cur_stage]}')]
             self.blocks.append(block)
 
-        self.norm = norm_layer(embed_dim)
-        self.head = Head(embed_dim, num_classes, drop_rate=drop_rate)
+        self.num_features = embed_dim
+        self.head = NormClassifierHead(
+            embed_dim,
+            num_classes,
+            pool_type=global_pool,
+            drop_rate=drop_rate,
+            norm_layer=norm_layer,
+        )
 
         # Initialize everything
         if sep_pos_embed:
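
At the model level the norm now lives under model.head, so the classifier can be re-targeted without touching pooling or normalization. A hedged sketch ('hiera_small_224' is assumed to be one of the registered variants):

import timm
import torch

model = timm.create_model('hiera_small_224', pretrained=False)
print(type(model.head).__name__)              # NormClassifierHead
model.head.reset(10)                          # rebuild head.fc for 10 classes
out = model(torch.randn(1, 3, 224, 224))      # -> shape (1, 10)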
@@ -562,8 +586,8 @@ def __init__(
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.projection.weight.data.mul_(head_init_scale)
-        self.head.projection.bias.data.mul_(head_init_scale)
+        self.head.fc.weight.data.mul_(head_init_scale)
+        self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
@@ -678,19 +702,17 @@ def forward_intermediates(
 
     def prune_intermediate_layers(
             self,
-            n: Union[int, List[int], Tuple[int]] = 1,
+            indices: Union[int, List[int], Tuple[int]] = 1,
             prune_norm: bool = False,
             prune_head: bool = True,
     ):
         """ Prune layers not required for specified intermediates.
         """
-        take_indices, max_index = feature_take_indices(len(self.stage_ends), n)
+        take_indices, max_index = feature_take_indices(len(self.stage_ends), indices)
         max_index = self.stage_ends[max_index]
         self.blocks = self.blocks[:max_index + 1]  # truncate blocks
         if prune_head:
-            # norm part of head for this model, equivalent to fc_norm in other vit.
-            self.norm = nn.Identity()
-            self.head = nn.Identity()
+            self.head.reset(0, other=True)
         return take_indices
 
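A hedged sketch of the renamed argument and the new head-pruning path (the model name is again assumed to be a registered variant):

import timm

model = timm.create_model('hiera_tiny_224', pretrained=False)
kept = model.prune_intermediate_layers(indices=2, prune_head=True)
# head.reset(0, other=True) leaves both head.norm and head.fc as nn.Identity()
print(kept, type(model.head.norm).__name__, type(model.head.fc).__name__)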

@@ -732,11 +754,7 @@ def forward_features(
         return x
 
     def forward_head(self, x, pre_logits: bool = False) -> torch.Tensor:
-        x = x.mean(dim=1)
-        x = self.norm(x)
-        if pre_logits:
-            return x
-        x = self.head(x)
+        x = self.head(x, pre_logits=pre_logits) if pre_logits else self.head(x)
         return x
 
     def forward(
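
forward_head() now just delegates pooling, norm and classification to the head. A short hedged usage sketch (same assumed model name as above, on a fresh, unpruned model):

import timm
import torch

model = timm.create_model('hiera_tiny_224', pretrained=False)
feats = model.forward_features(torch.randn(1, 3, 224, 224))   # unpooled token features
pooled = model.forward_head(feats, pre_logits=True)           # pooled + normed features, no fc
logits = model.forward_head(feats)                            # full classification output
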
@@ -756,7 +774,7 @@ def _cfg(url='', **kwargs):
         'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
         'crop_pct': .9, 'interpolation': 'bicubic', 'fixed_input_size': True,
         'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD,
-        'first_conv': 'patch_embed.proj', 'classifier': 'head',
+        'first_conv': 'patch_embed.proj', 'classifier': 'head.fc',
         **kwargs
     }
 
@@ -837,6 +855,10 @@ def checkpoint_filter_fn(state_dict, model=None):
             # )
             #v = F.interpolate(v.transpose(1, 2), (model.pos_embed.shape[1],)).transpose(1, 2)
             pass
+        if 'head.projection.' in k:
+            k = k.replace('head.projection.', 'head.fc.')
+        if k.startswith('norm.'):
+            k = k.replace('norm.', 'head.norm.')
         output[k] = v
     return output
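A hedged illustration of the key remapping above, run on a toy list of old checkpoint keys:

old_keys = ['norm.weight', 'norm.bias', 'head.projection.weight', 'head.projection.bias']
new_keys = []
for k in old_keys:
    if 'head.projection.' in k:
        k = k.replace('head.projection.', 'head.fc.')
    if k.startswith('norm.'):
        k = k.replace('norm.', 'head.norm.')
    new_keys.append(k)
print(new_keys)   # ['head.norm.weight', 'head.norm.bias', 'head.fc.weight', 'head.fc.bias']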
