
Commit 3e03b2b

Fix a few more hiera API issues
1 parent 211d18d commit 3e03b2b

File tree: 1 file changed, +32 -9 lines changed


timm/models/hiera.py

Lines changed: 32 additions & 9 deletions
@@ -24,11 +24,12 @@
 # --------------------------------------------------------
 import math
 from functools import partial
-from typing import List, Tuple, Type, Callable, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch.utils.checkpoint import checkpoint
 
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -480,14 +481,14 @@ def __init__(
     ):
         super().__init__()
         self.num_classes = num_classes
+        self.grad_checkpointing = False
         norm_layer = get_norm_layer(norm_layer)
-        depth = sum(stages)
+
         self.patch_stride = patch_stride
         self.tokens_spatial_shape = [i // s for i, s in zip(img_size, patch_stride)]
         num_tokens = math.prod(self.tokens_spatial_shape)
         flat_mu_size = math.prod(mask_unit_size)
         flat_q_stride = math.prod(q_stride)
-
         assert q_pool < len(stages)
         self.q_pool, self.q_stride = q_pool, q_stride
         self.mu_size, self.mask_unit_size = flat_mu_size, mask_unit_size
@@ -532,11 +533,10 @@ def __init__(
         # q_pool locations
         q_pool_blocks = [x + 1 for x in self.stage_ends[:q_pool]]
 
-        # stochastic depth decay rule
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
-
         # Transformer blocks
         cur_stage = 0
+        depth = sum(stages)
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
         self.blocks = nn.ModuleList()
         self.feature_info = []
         for i in range(depth):
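The relocated stochastic depth setup computes one drop-path rate per block, growing linearly from 0 to drop_path_rate over depth = sum(stages) blocks. A minimal sketch of what the rule produces, assuming drop_path_rate=0.1 and stage depths of (2, 3, 16, 3) (illustrative values, not taken from this diff):

import torch

# Illustrative values: stage depths and drop_path_rate are assumptions for this sketch.
stages = (2, 3, 16, 3)
drop_path_rate = 0.1

depth = sum(stages)  # 24 blocks in total
# Linear decay rule: the first block gets 0.0, the last gets drop_path_rate.
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

print(len(dpr))                   # 24
print(dpr[0], round(dpr[-1], 3))  # 0.0 0.1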
@@ -586,8 +586,9 @@ def __init__(
         else:
             nn.init.trunc_normal_(self.pos_embed, std=0.02)
         self.apply(partial(self._init_weights))
-        self.head.fc.weight.data.mul_(head_init_scale)
-        self.head.fc.bias.data.mul_(head_init_scale)
+        if isinstance(self.head.fc, nn.Linear):
+            self.head.fc.weight.data.mul_(head_init_scale)
+            self.head.fc.bias.data.mul_(head_init_scale)
 
     def _init_weights(self, m, init_bias=0.02):
         if isinstance(m, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
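The new isinstance guard matters because the classifier is not always an nn.Linear: when a model is built with num_classes=0 the head's fc is typically an nn.Identity, which has no weight or bias to scale. A standalone sketch of the guarded pattern (the fc variable here is a hypothetical stand-in, not the module's real structure):

import torch.nn as nn

num_classes, head_init_scale = 0, 0.001
# Hypothetical head: fc degenerates to Identity when classification is disabled.
fc = nn.Linear(768, num_classes) if num_classes > 0 else nn.Identity()

if isinstance(fc, nn.Linear):
    # Only scale when there are real parameters; Identity would raise AttributeError.
    fc.weight.data.mul_(head_init_scale)
    fc.bias.data.mul_(head_init_scale)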
@@ -605,6 +606,25 @@ def no_weight_decay(self):
         else:
             return ["pos_embed_spatial", "pos_embed_temporal"]
 
+    @torch.jit.ignore
+    def group_matcher(self, coarse: bool = False) -> Dict:
+        return dict(
+            stem=r'^pos_embed|pos_embed_spatial|pos_embed_temporal|patch_embed',  # stem and embed
+            blocks=[(r'^blocks\.(\d+)', None), (r'^norm', (99999,))]
+        )
+
+    @torch.jit.ignore
+    def set_grad_checkpointing(self, enable: bool = True) -> None:
+        self.grad_checkpointing = enable
+
+    @torch.jit.ignore
+    def get_classifier(self):
+        return self.head.fc
+
+    def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, other: bool = False):
+        self.num_classes = num_classes
+        self.head.reset(num_classes, global_pool, other=other)
+
     def get_random_mask(self, x: torch.Tensor, mask_ratio: float) -> torch.Tensor:
         """
         Generates a random mask, mask_ratio fraction are dropped.
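These additions bring Hiera in line with the rest of the timm model interface: group_matcher supplies the regex groups used for parameter grouping (for example layer-wise LR decay), set_grad_checkpointing toggles checkpointed blocks, and get_classifier / reset_classifier read and replace the head. A hedged usage sketch; the model name 'hiera_tiny_224' is an assumption, check timm.list_models('hiera*') for the names actually registered:

import timm

model = timm.create_model('hiera_tiny_224', pretrained=False, num_classes=1000)

model.set_grad_checkpointing(True)        # trade compute for activation memory
print(model.get_classifier())             # current classification head

model.reset_classifier(num_classes=10)    # swap in a 10-class head for fine-tuning
print(model.group_matcher(coarse=False))  # regex groups for optimizer param grouping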
@@ -740,7 +760,10 @@ def forward_features(
 
         intermediates = []
         for i, blk in enumerate(self.blocks):
-            x = blk(x)
+            if self.grad_checkpointing and not torch.jit.is_scripting():
+                x = checkpoint(blk, x)
+            else:
+                x = blk(x)
             if return_intermediates and i in self.stage_ends:
                 intermediates.append(self.reroll(x, i, mask=mask))
 
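With the flag enabled, each block's activations are discarded in the forward pass and recomputed during backward, lowering peak memory at the cost of extra compute; the torch.jit.is_scripting() check skips checkpointing under TorchScript, where it is unsupported. A minimal sketch of the same pattern outside the model (the blocks here are placeholder modules, not Hiera blocks):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Placeholder blocks standing in for the transformer blocks.
blocks = nn.ModuleList([nn.Sequential(nn.Linear(32, 32), nn.GELU()) for _ in range(4)])
grad_checkpointing = True

x = torch.randn(8, 32, requires_grad=True)
for blk in blocks:
    if grad_checkpointing and not torch.jit.is_scripting():
        # Recompute blk's activations in backward instead of storing them.
        x = checkpoint(blk, x, use_reentrant=False)
    else:
        x = blk(x)
x.sum().backward()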