Skip to content

Commit c838c42

Browse files
committed
Add typing to reset_classifier() on other models
1 parent 3e03b2b commit c838c42

35 files changed

+58
-61
lines changed

timm/models/beit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@ def group_matcher(self, coarse=False):
395395
def get_classifier(self):
396396
return self.head
397397

398-
def reset_classifier(self, num_classes, global_pool=None):
398+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
399399
self.num_classes = num_classes
400400
if global_pool is not None:
401401
self.global_pool = global_pool

timm/models/cait.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -331,7 +331,7 @@ def _matcher(name):
331331
def get_classifier(self):
332332
return self.head
333333

334-
def reset_classifier(self, num_classes, global_pool=None):
334+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
335335
self.num_classes = num_classes
336336
if global_pool is not None:
337337
assert global_pool in ('', 'token', 'avg')

timm/models/coat.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
88
Modified from timm/models/vision_transformer.py
99
"""
10-
from functools import partial
11-
from typing import Tuple, List, Union
10+
from typing import List, Optional, Union, Tuple
1211

1312
import torch
1413
import torch.nn as nn
@@ -560,7 +559,7 @@ def group_matcher(self, coarse=False):
560559
def get_classifier(self):
561560
return self.head
562561

563-
def reset_classifier(self, num_classes, global_pool=None):
562+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
564563
self.num_classes = num_classes
565564
if global_pool is not None:
566565
assert global_pool in ('token', 'avg')

timm/models/convit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,7 @@
2121
'''These modules are adapted from those of timm, see
2222
https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
2323
'''
24-
25-
from functools import partial
24+
from typing import Optional
2625

2726
import torch
2827
import torch.nn as nn
@@ -349,7 +348,7 @@ def set_grad_checkpointing(self, enable=True):
349348
def get_classifier(self):
350349
return self.head
351350

352-
def reset_classifier(self, num_classes, global_pool=None):
351+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
353352
self.num_classes = num_classes
354353
if global_pool is not None:
355354
assert global_pool in ('', 'token', 'avg')

timm/models/convmixer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
""" ConvMixer
22
33
"""
4+
from typing import Optional
5+
46
import torch
57
import torch.nn as nn
68

@@ -75,7 +77,7 @@ def set_grad_checkpointing(self, enable=True):
7577
def get_classifier(self):
7678
return self.head
7779

78-
def reset_classifier(self, num_classes, global_pool=None):
80+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
7981
self.num_classes = num_classes
8082
if global_pool is not None:
8183
self.pooling = SelectAdaptivePool2d(pool_type=global_pool, flatten=True)

timm/models/convnext.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@
3737
# LICENSE file in the root directory of this source tree (Attribution-NonCommercial 4.0 International (CC BY-NC 4.0))
3838
# No code was used directly from ConvNeXt-V2, however the weights are CC BY-NC 4.0 so beware if using commercially.
3939

40-
from collections import OrderedDict
4140
from functools import partial
4241
from typing import Callable, List, Optional, Tuple, Union
4342

timm/models/crossvit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,7 @@
2525
2626
"""
2727
from functools import partial
28-
from typing import List
29-
from typing import Tuple
28+
from typing import List, Optional, Tuple
3029

3130
import torch
3231
import torch.hub
@@ -419,7 +418,7 @@ def set_grad_checkpointing(self, enable=True):
419418
def get_classifier(self):
420419
return self.head
421420

422-
def reset_classifier(self, num_classes, global_pool=None):
421+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
423422
self.num_classes = num_classes
424423
if global_pool is not None:
425424
assert global_pool in ('token', 'avg')

timm/models/davit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# All rights reserved.
1313
# This source code is licensed under the MIT license
1414
from functools import partial
15-
from typing import Tuple
15+
from typing import Optional, Tuple
1616

1717
import torch
1818
import torch.nn as nn
@@ -568,7 +568,7 @@ def set_grad_checkpointing(self, enable=True):
568568
def get_classifier(self):
569569
return self.head.fc
570570

571-
def reset_classifier(self, num_classes, global_pool=None):
571+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
572572
self.head.reset(num_classes, global_pool)
573573

574574
def forward_features(self, x):

timm/models/deit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
# Copyright (c) 2015-present, Facebook, Inc.
1212
# All rights reserved.
1313
from functools import partial
14-
from typing import Sequence, Union
14+
from typing import Optional
1515

1616
import torch
1717
from torch import nn as nn
@@ -20,7 +20,6 @@
2020
from timm.layers import resample_abs_pos_embed
2121
from timm.models.vision_transformer import VisionTransformer, trunc_normal_, checkpoint_filter_fn
2222
from ._builder import build_model_with_cfg
23-
from ._manipulate import checkpoint_seq
2423
from ._registry import generate_default_cfgs, register_model, register_model_deprecations
2524

2625
__all__ = ['VisionTransformerDistilled'] # model_registry will add each entrypoint fn to this
@@ -64,7 +63,7 @@ def group_matcher(self, coarse=False):
6463
def get_classifier(self):
6564
return self.head, self.head_dist
6665

67-
def reset_classifier(self, num_classes, global_pool=None):
66+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
6867
self.num_classes = num_classes
6968
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
7069
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()

timm/models/edgenext.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
Modifications and additions for timm by / Copyright 2022, Ross Wightman
99
"""
1010
import math
11-
from collections import OrderedDict
1211
from functools import partial
1312
from typing import Tuple
1413

@@ -17,7 +16,7 @@
1716
from torch import nn
1817

1918
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
20-
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, SelectAdaptivePool2d, create_conv2d, \
19+
from timm.layers import trunc_normal_tf_, DropPath, LayerNorm2d, Mlp, create_conv2d, \
2120
use_fused_attn, NormMlpClassifierHead, ClassifierHead
2221
from ._builder import build_model_with_cfg
2322
from ._features_fx import register_notrace_module

timm/models/efficientformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,7 @@ def set_grad_checkpointing(self, enable=True):
449449
def get_classifier(self):
450450
return self.head, self.head_dist
451451

452-
def reset_classifier(self, num_classes, global_pool=None):
452+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
453453
self.num_classes = num_classes
454454
if global_pool is not None:
455455
self.global_pool = global_pool

timm/models/efficientformer_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717
import math
1818
from functools import partial
19-
from typing import Dict
19+
from typing import Dict, Optional
2020

2121
import torch
2222
import torch.nn as nn
@@ -612,7 +612,7 @@ def set_grad_checkpointing(self, enable=True):
612612
def get_classifier(self):
613613
return self.head, self.head_dist
614614

615-
def reset_classifier(self, num_classes, global_pool=None):
615+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
616616
self.num_classes = num_classes
617617
if global_pool is not None:
618618
self.global_pool = global_pool

timm/models/efficientvit_mit.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import torch
1414
import torch.nn as nn
1515
import torch.nn.functional as F
16-
from torch.nn.modules.batchnorm import _BatchNorm
1716

1817
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
1918
from timm.layers import SelectAdaptivePool2d, create_conv2d, GELUTanh
@@ -740,7 +739,7 @@ def set_grad_checkpointing(self, enable=True):
740739
def get_classifier(self):
741740
return self.head.classifier[-1]
742741

743-
def reset_classifier(self, num_classes, global_pool=None):
742+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
744743
self.num_classes = num_classes
745744
if global_pool is not None:
746745
self.global_pool = global_pool
@@ -858,7 +857,7 @@ def set_grad_checkpointing(self, enable=True):
858857
def get_classifier(self):
859858
return self.head.classifier[-1]
860859

861-
def reset_classifier(self, num_classes, global_pool=None):
860+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
862861
self.num_classes = num_classes
863862
if global_pool is not None:
864863
self.global_pool = global_pool

timm/models/efficientvit_msra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
__all__ = ['EfficientVitMsra']
1010
import itertools
1111
from collections import OrderedDict
12-
from typing import Dict
12+
from typing import Dict, Optional
1313

1414
import torch
1515
import torch.nn as nn
@@ -464,7 +464,7 @@ def set_grad_checkpointing(self, enable=True):
464464
def get_classifier(self):
465465
return self.head.linear
466466

467-
def reset_classifier(self, num_classes, global_pool=None):
467+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
468468
self.num_classes = num_classes
469469
if global_pool is not None:
470470
if global_pool == 'avg':

timm/models/eva.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -539,7 +539,7 @@ def group_matcher(self, coarse=False):
539539
def get_classifier(self):
540540
return self.head
541541

542-
def reset_classifier(self, num_classes, global_pool=None):
542+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
543543
self.num_classes = num_classes
544544
if global_pool is not None:
545545
self.global_pool = global_pool

timm/models/fastvit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,7 @@ def reparameterize(self) -> None:
396396

397397
@staticmethod
398398
def _fuse_bn(
399-
conv: torch.Tensor, bn: nn.BatchNorm2d
399+
conv: nn.Conv2d, bn: nn.BatchNorm2d
400400
) -> Tuple[torch.Tensor, torch.Tensor]:
401401
"""Method to fuse batchnorm layer with conv layer.
402402
@@ -1232,7 +1232,7 @@ def set_grad_checkpointing(self, enable=True):
12321232
def get_classifier(self):
12331233
return self.head.fc
12341234

1235-
def reset_classifier(self, num_classes, global_pool=None):
1235+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
12361236
self.num_classes = num_classes
12371237
self.head.reset(num_classes, global_pool)
12381238

timm/models/focalnet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ def set_grad_checkpointing(self, enable=True):
454454
def get_classifier(self):
455455
return self.head.fc
456456

457-
def reset_classifier(self, num_classes, global_pool=None):
457+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
458458
self.head.reset(num_classes, pool_type=global_pool)
459459

460460
def forward_features(self, x):

timm/models/gcvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@ def set_grad_checkpointing(self, enable=True):
489489
def get_classifier(self):
490490
return self.head.fc
491491

492-
def reset_classifier(self, num_classes, global_pool=None):
492+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
493493
self.num_classes = num_classes
494494
if global_pool is None:
495495
global_pool = self.head.global_pool.pool_type

timm/models/levit.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -628,7 +628,7 @@ def set_grad_checkpointing(self, enable=True):
628628
def get_classifier(self):
629629
return self.head
630630

631-
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
631+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
632632
self.num_classes = num_classes
633633
if global_pool is not None:
634634
self.global_pool = global_pool
@@ -730,7 +730,7 @@ def __init__(self, *args, **kwargs):
730730
def get_classifier(self):
731731
return self.head, self.head_dist
732732

733-
def reset_classifier(self, num_classes, global_pool=None, distillation=None):
733+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None, distillation=None):
734734
self.num_classes = num_classes
735735
if global_pool is not None:
736736
self.global_pool = global_pool

timm/models/maxxvit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1248,7 +1248,7 @@ def set_grad_checkpointing(self, enable=True):
12481248
def get_classifier(self):
12491249
return self.head.fc
12501250

1251-
def reset_classifier(self, num_classes, global_pool=None):
1251+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
12521252
self.num_classes = num_classes
12531253
self.head.reset(num_classes, global_pool)
12541254

timm/models/mlp_mixer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def set_grad_checkpointing(self, enable=True):
255255
def get_classifier(self):
256256
return self.head
257257

258-
def reset_classifier(self, num_classes, global_pool=None):
258+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
259259
self.num_classes = num_classes
260260
if global_pool is not None:
261261
assert global_pool in ('', 'avg')

timm/models/mvitv2.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def set_grad_checkpointing(self, enable=True):
825825
def get_classifier(self):
826826
return self.head.fc
827827

828-
def reset_classifier(self, num_classes, global_pool=None):
828+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
829829
self.num_classes = num_classes
830830
if global_pool is not None:
831831
self.global_pool = global_pool

timm/models/nextvit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
"""
77
# Copyright (c) ByteDance Inc. All rights reserved.
88
from functools import partial
9+
from typing import Optional
910

1011
import torch
1112
import torch.nn.functional as F
@@ -553,7 +554,7 @@ def set_grad_checkpointing(self, enable=True):
553554
def get_classifier(self):
554555
return self.head.fc
555556

556-
def reset_classifier(self, num_classes, global_pool=None):
557+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
557558
self.head.reset(num_classes, pool_type=global_pool)
558559

559560
def forward_features(self, x):

timm/models/pit.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,13 @@
1414
import math
1515
import re
1616
from functools import partial
17-
from typing import Sequence, Tuple
17+
from typing import Optional, Sequence, Tuple
1818

1919
import torch
2020
from torch import nn
2121

2222
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
23-
from timm.layers import trunc_normal_, to_2tuple, LayerNorm
23+
from timm.layers import trunc_normal_, to_2tuple
2424
from ._builder import build_model_with_cfg
2525
from ._registry import register_model, generate_default_cfgs
2626
from .vision_transformer import Block
@@ -246,7 +246,7 @@ def get_classifier(self):
246246
else:
247247
return self.head
248248

249-
def reset_classifier(self, num_classes, global_pool=None):
249+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
250250
self.num_classes = num_classes
251251
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
252252
if self.head_dist is not None:

timm/models/pvt_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
import math
19-
from typing import Tuple, List, Callable, Union
19+
from typing import Callable, List, Optional, Union
2020

2121
import torch
2222
import torch.nn as nn
@@ -379,7 +379,7 @@ def set_grad_checkpointing(self, enable=True):
379379
def get_classifier(self):
380380
return self.head
381381

382-
def reset_classifier(self, num_classes, global_pool=None):
382+
def reset_classifier(self, num_classes: int, global_pool: Optional[str] = None):
383383
self.num_classes = num_classes
384384
if global_pool is not None:
385385
assert global_pool in ('avg', '')

0 commit comments

Comments
 (0)