Skip to content

Commit cb57a96

Browse files
committed
Fix early stop for efficientnet/mobilenetv3 fwd inter. Fix indices typing for all fwd inter.
1 parent 01dd01b commit cb57a96

14 files changed

+21
-24
lines changed

timm/models/convnext.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ def reset_classifier(self, num_classes=0, global_pool=None):
411411
def forward_intermediates(
412412
self,
413413
x: torch.Tensor,
414-
indices: Union[int, List[int], Tuple[int]] = None,
414+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
415415
norm: bool = False,
416416
stop_early: bool = False,
417417
output_fmt: str = 'NCHW',

timm/models/efficientformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
1313
Modifications and timm support by / Copyright 2022, Ross Wightman
1414
"""
15-
from typing import Dict, List, Tuple, Union
15+
from typing import Dict, List, Optional, Tuple, Union
1616

1717
import torch
1818
import torch.nn as nn
@@ -463,7 +463,7 @@ def set_distilled_training(self, enable=True):
463463
def forward_intermediates(
464464
self,
465465
x: torch.Tensor,
466-
indices: Union[int, List[int], Tuple[int]] = None,
466+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
467467
norm: bool = False,
468468
stop_early: bool = False,
469469
output_fmt: str = 'NCHW',

timm/models/efficientnet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
162162
def forward_intermediates(
163163
self,
164164
x: torch.Tensor,
165-
indices: Union[int, List[int], Tuple[int]] = None,
165+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
166166
norm: bool = False,
167167
stop_early: bool = False,
168168
output_fmt: str = 'NCHW',
@@ -183,8 +183,6 @@ def forward_intermediates(
183183
184184
"""
185185
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
186-
if stop_early:
187-
assert intermediates_only, 'Must use intermediates_only for early stopping.'
188186
intermediates = []
189187
if extra_blocks:
190188
take_indices, max_index = feature_take_indices(len(self.blocks) + 1, indices)
@@ -212,8 +210,9 @@ def forward_intermediates(
212210
if intermediates_only:
213211
return intermediates
214212

215-
x = self.conv_head(x)
216-
x = self.bn2(x)
213+
if feat_idx == self.stage_ends[-1]:
214+
x = self.conv_head(x)
215+
x = self.bn2(x)
217216

218217
return x, intermediates
219218

timm/models/levit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
# Copyright 2020 Ross Wightman, Apache-2.0 License
2626
from collections import OrderedDict
2727
from functools import partial
28-
from typing import Dict, List, Tuple, Union
28+
from typing import Dict, List, Optional, Tuple, Union
2929

3030
import torch
3131
import torch.nn as nn
@@ -638,7 +638,7 @@ def reset_classifier(self, num_classes, global_pool=None, distillation=None):
638638
def forward_intermediates(
639639
self,
640640
x: torch.Tensor,
641-
indices: Union[int, List[int], Tuple[int]] = None,
641+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
642642
norm: bool = False,
643643
stop_early: bool = False,
644644
output_fmt: str = 'NCHW',

timm/models/maxxvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1255,7 +1255,7 @@ def reset_classifier(self, num_classes, global_pool=None):
12551255
def forward_intermediates(
12561256
self,
12571257
x: torch.Tensor,
1258-
indices: Union[int, List[int], Tuple[int]] = None,
1258+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
12591259
norm: bool = False,
12601260
stop_early: bool = False,
12611261
output_fmt: str = 'NCHW',

timm/models/mobilenetv3.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def reset_classifier(self, num_classes: int, global_pool: str = 'avg'):
154154
def forward_intermediates(
155155
self,
156156
x: torch.Tensor,
157-
indices: Union[int, List[int], Tuple[int]] = None,
157+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
158158
norm: bool = False,
159159
stop_early: bool = False,
160160
output_fmt: str = 'NCHW',

timm/models/mvitv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -837,7 +837,7 @@ def reset_classifier(self, num_classes, global_pool=None):
837837
def forward_intermediates(
838838
self,
839839
x: torch.Tensor,
840-
indices: Union[int, List[int], Tuple[int]] = None,
840+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
841841
norm: bool = False,
842842
stop_early: bool = False,
843843
output_fmt: str = 'NCHW',

timm/models/resnet.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,7 @@ def reset_classifier(self, num_classes, global_pool='avg'):
557557
def forward_intermediates(
558558
self,
559559
x: torch.Tensor,
560-
indices: Union[int, List[int], Tuple[int]] = None,
560+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
561561
norm: bool = False,
562562
stop_early: bool = False,
563563
output_fmt: str = 'NCHW',
@@ -576,8 +576,6 @@ def forward_intermediates(
576576
577577
"""
578578
assert output_fmt in ('NCHW',), 'Output shape must be NCHW.'
579-
if stop_early:
580-
assert intermediates_only, 'Must use intermediates_only for early stopping.'
581579
intermediates = []
582580
take_indices, max_index = feature_take_indices(5, indices)
583581

timm/models/swin_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ def reset_classifier(self, num_classes, global_pool=None):
611611
def forward_intermediates(
612612
self,
613613
x: torch.Tensor,
614-
indices: Union[int, List[int], Tuple[int]] = None,
614+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
615615
norm: bool = False,
616616
stop_early: bool = False,
617617
output_fmt: str = 'NCHW',

timm/models/swin_transformer_v2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,7 @@ def reset_classifier(self, num_classes, global_pool=None):
612612
def forward_intermediates(
613613
self,
614614
x: torch.Tensor,
615-
indices: Union[int, List[int], Tuple[int]] = None,
615+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
616616
norm: bool = False,
617617
stop_early: bool = False,
618618
output_fmt: str = 'NCHW',

timm/models/swin_transformer_v2_cr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None)
722722
def forward_intermediates(
723723
self,
724724
x: torch.Tensor,
725-
indices: Union[int, List[int], Tuple[int]] = None,
725+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
726726
norm: bool = False,
727727
stop_early: bool = False,
728728
output_fmt: str = 'NCHW',

timm/models/twins.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -407,7 +407,7 @@ def _init_weights(self, m):
407407
def forward_intermediates(
408408
self,
409409
x: torch.Tensor,
410-
indices: Union[int, List[int], Tuple[int]] = None,
410+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
411411
norm: bool = False,
412412
stop_early: bool = False,
413413
output_fmt: str = 'NCHW',

timm/models/vision_transformer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def __init__(
489489
**embed_args,
490490
)
491491
num_patches = self.patch_embed.num_patches
492-
r = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
492+
reduction = self.patch_embed.feat_ratio() if hasattr(self.patch_embed, 'feat_ratio') else patch_size
493493

494494
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
495495
self.reg_token = nn.Parameter(torch.zeros(1, reg_tokens, embed_dim)) if reg_tokens else None
@@ -523,7 +523,7 @@ def __init__(
523523
)
524524
for i in range(depth)])
525525
self.feature_info = [
526-
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=r) for i in range(depth)]
526+
dict(module=f'blocks.{i}', num_chs=embed_dim, reduction=reduction) for i in range(depth)]
527527
self.norm = norm_layer(embed_dim) if not use_fc_norm else nn.Identity()
528528

529529
# Classifier Head

timm/models/vision_transformer_sam.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -543,7 +543,7 @@ def reset_classifier(self, num_classes=0, global_pool=None):
543543
def forward_intermediates(
544544
self,
545545
x: torch.Tensor,
546-
indices: Union[int, List[int], Tuple[int]] = None,
546+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
547547
norm: bool = False,
548548
stop_early: bool = False,
549549
output_fmt: str = 'NCHW',
@@ -598,7 +598,7 @@ def forward_intermediates(
598598

599599
def prune_intermediate_layers(
600600
self,
601-
indices: Union[int, List[int], Tuple[int]] = None,
601+
indices: Optional[Union[int, List[int], Tuple[int]]] = None,
602602
prune_norm: bool = False,
603603
prune_head: bool = True,
604604
):

0 commit comments

Comments
 (0)