Commit cc8a03d (parent 3c9d8e5)

Add ConvStem and MobileCLIP hybrid model for B variant. Add full norm disable support to ConvNormAct layers.

4 files changed: +112 −37 lines

timm/layers/conv_bn_act.py

Lines changed: 37 additions & 22 deletions
@@ -23,6 +23,7 @@ def __init__(
             dilation: int = 1,
             groups: int = 1,
             bias: bool = False,
+            apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
             act_layer: LayerType = nn.ReLU,
@@ -48,17 +49,23 @@ def __init__(
             **conv_kwargs,
         )
 
-        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        if drop_layer:
-            norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(
-            out_channels,
-            apply_act=apply_act,
-            act_kwargs=act_kwargs,
-            **norm_kwargs,
-        )
+        if apply_norm:
+            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+            self.bn = norm_act_layer(
+                out_channels,
+                apply_act=apply_act,
+                act_kwargs=act_kwargs,
+                **norm_kwargs,
+            )
+        else:
+            self.bn = nn.Sequential()
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+                self.bn.add_module('drop', drop_layer())
 
     @property
     def in_channels(self):
@@ -88,6 +95,7 @@ def __init__(
             dilation: int = 1,
             groups: int = 1,
             bias: bool = False,
+            apply_norm: bool = True,
             apply_act: bool = True,
             norm_layer: LayerType = nn.BatchNorm2d,
             act_layer: LayerType = nn.ReLU,
@@ -113,17 +121,24 @@ def __init__(
             **conv_kwargs,
         )
 
-        # NOTE for backwards compatibility with models that use separate norm and act layer definitions
-        norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
-        # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
-        if drop_layer:
-            norm_kwargs['drop_layer'] = drop_layer
-        self.bn = norm_act_layer(
-            out_channels,
-            apply_act=apply_act,
-            act_kwargs=act_kwargs,
-            **norm_kwargs,
-        )
+        if apply_norm:
+            # NOTE for backwards compatibility with models that use separate norm and act layer definitions
+            norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
+            # NOTE for backwards (weight) compatibility, norm layer name remains `.bn`
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+            self.bn = norm_act_layer(
+                out_channels,
+                apply_act=apply_act,
+                act_kwargs=act_kwargs,
+                **norm_kwargs,
+            )
+        else:
+            self.bn = nn.Sequential()
+            if drop_layer:
+                norm_kwargs['drop_layer'] = drop_layer
+                self.bn.add_module('drop', drop_layer())
+
         self.aa = create_aa(aa_layer, out_channels, stride=stride, enable=use_aa)
 
     @property
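
With apply_norm=False the `.bn` attribute becomes an empty nn.Sequential (holding at most the optional drop module), so a ConvNormAct reduces to a bare convolution while keeping its attribute layout. A minimal usage sketch, assuming a timm build that includes this commit:

    import torch
    from timm.layers import ConvNormAct

    # Default behavior: conv -> BatchNorm2d -> ReLU (fused norm+act module named `.bn`).
    block = ConvNormAct(3, 64, kernel_size=3, stride=2, padding=1)

    # Norm and act disabled: just the conv. Bias is enabled here since there is
    # no following BatchNorm to make it redundant.
    plain = ConvNormAct(
        3, 64, kernel_size=3, stride=2, padding=1,
        bias=True, apply_norm=False, apply_act=False)

    x = torch.randn(2, 3, 224, 224)
    assert block(x).shape == plain(x).shape == (2, 64, 112, 112)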

timm/layers/norm_act.py

Lines changed: 6 additions & 10 deletions
@@ -19,21 +19,18 @@
 from torch.nn import functional as F
 from torchvision.ops.misc import FrozenBatchNorm2d
 
-from .create_act import get_act_layer
+from .create_act import create_act_layer
 from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
 from .trace_utils import _assert
 
 
 def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
-    act_layer = get_act_layer(act_layer)  # string -> nn.Module
     act_kwargs = act_kwargs or {}
-    if act_layer is not None and apply_act:
-        if inplace:
-            act_kwargs['inplace'] = inplace
-        act = act_layer(**act_kwargs)
-    else:
-        act = nn.Identity()
-    return act
+    act_kwargs.setdefault('inplace', inplace)
+    act = None
+    if apply_act:
+        act = create_act_layer(act_layer, **act_kwargs)
+    return nn.Identity() if act is None else act
 
 
 class BatchNormAct2d(nn.BatchNorm2d):
@@ -421,7 +418,6 @@ def __init__(
     ):
         super(LayerNormAct, self).__init__(normalization_shape, eps=eps, elementwise_affine=affine)
         self.drop = drop_layer() if drop_layer is not None else nn.Identity()
-        act_layer = get_act_layer(act_layer)  # string -> nn.Module
         self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
 
         self._fast_norm = is_fast_norm()
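
The helper now delegates both string-to-class resolution and instantiation to create_act_layer, which returns None when act_layer is None, so the nn.Identity fallback happens in one place. Note also that act_kwargs.setdefault('inplace', inplace) means an explicit act_kwargs['inplace'] takes precedence over the inplace argument. A rough sketch of the resolved behavior (the 'relu'/None inputs here are just illustrative):

    import torch.nn as nn
    from timm.layers import create_act_layer

    act = create_act_layer('relu', inplace=True)  # -> nn.ReLU(inplace=True)
    none = create_act_layer(None, inplace=True)   # -> None; _create_act maps this to nn.Identity()
    assert isinstance(act, nn.ReLU) and none is None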

timm/models/vision_transformer.py

Lines changed: 1 addition & 1 deletion
@@ -609,7 +609,7 @@ def reset_classifier(self, num_classes: int, global_pool = None) -> None:
 
     def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
         if self.pos_embed is None:
-            return x
+            return x.view(x.shape[0], -1, x.shape[-1])
 
         if self.dynamic_img_size:
             B, H, W, C = x.shape
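
Previously, a model without a position embedding returned its features untouched, which left them as an NHWC grid when the patch embed runs in non-flattening (e.g. dynamic_img_size) mode; the new view collapses the spatial dims into the usual (B, N, C) token sequence either way. A toy illustration of the reshape, with made-up sizes:

    import torch

    x = torch.randn(2, 14, 14, 768)               # NHWC feature grid from the patch embed
    tokens = x.view(x.shape[0], -1, x.shape[-1])  # -> (B, N, C) token sequence
    assert tokens.shape == (2, 196, 768)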

timm/models/vision_transformer_hybrid.py

Lines changed: 68 additions & 4 deletions
@@ -15,14 +15,15 @@
 """
 import math
 from functools import partial
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Tuple, Type, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
-from timm.layers import StdConv2dSame, StdConv2d, to_2tuple, Format, nchw_to
+from timm.layers import StdConv2dSame, StdConv2d, ConvNormAct, to_2tuple, to_ntuple, Format, nchw_to
+
 from ._registry import generate_default_cfgs, register_model, register_model_deprecations
 from .resnet import resnet26d, resnet50d
 from .resnetv2 import ResNetV2, create_resnetv2_stem
@@ -191,8 +192,52 @@ def forward(self, x) -> Tuple[torch.Tensor, List[int]]:
         return x.flatten(2).transpose(1, 2), x.shape[-2:]
 
 
-def _create_vision_transformer_hybrid(variant, backbone, pretrained=False, **kwargs):
-    embed_layer = partial(HybridEmbed, backbone=backbone)
+class ConvStem(nn.Sequential):
+    def __init__(
+            self,
+            in_chans: int = 3,
+            depth: int = 3,
+            channels: Union[int, Tuple[int, ...]] = 64,
+            kernel_size: Union[int, Tuple[int, ...]] = 3,
+            stride: Union[int, Tuple[int, ...]] = (2, 2, 2),
+            padding: Union[str, int, Tuple[int, ...]] = "",
+            norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+            act_layer: Type[nn.Module] = nn.ReLU,
+    ):
+        super().__init__()
+        if isinstance(channels, int):
+            if depth == 4:
+                channels = (channels // 8, channels // 4, channels // 2, channels)
+            elif depth == 3:
+                channels = (channels // 4, channels // 2, channels)
+            else:
+                channels = to_ntuple(depth)(channels)
+
+        kernel_size = to_ntuple(depth)(kernel_size)
+        padding = to_ntuple(depth)(padding)
+        assert depth == len(stride) == len(kernel_size) == len(channels)
+
+        in_chs = in_chans
+        for i in range(len(channels)):
+            last_conv = i == len(channels) - 1
+            self.add_module(f'{i}', ConvNormAct(
+                in_chs,
+                channels[i],
+                kernel_size=kernel_size[i],
+                stride=stride[i],
+                padding=padding[i],
+                bias=last_conv,
+                apply_norm=not last_conv,
+                apply_act=not last_conv,
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+            ))
+            in_chs = channels[i]
+
+
+def _create_vision_transformer_hybrid(variant, backbone, embed_args=None, pretrained=False, **kwargs):
+    embed_args = embed_args or {}
+    embed_layer = partial(HybridEmbed, backbone=backbone, **embed_args)
     kwargs.setdefault('patch_size', 1)  # default patch size for hybrid models if not set
     return _create_vision_transformer(variant, pretrained=pretrained, embed_layer=embed_layer, **kwargs)
 
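The new ConvStem chains depth ConvNormAct blocks, with the last conv acting as a norm/act-free, biased projection (this is where the apply_norm flag added above comes in). For the MobileCLIP-B style configuration registered below, three convs with strides (4, 2, 2) give an overall stride-16 patchify. A back-of-envelope shape check, assuming this commit's module layout:

    import torch
    import torch.nn as nn
    from timm.models.vision_transformer_hybrid import ConvStem

    stem = ConvStem(
        channels=(768 // 4, 768 // 4, 768),  # 192 -> 192 -> 768
        stride=(4, 2, 2),                    # 4 * 2 * 2 = overall stride 16
        kernel_size=(4, 2, 2),
        padding=0,
        act_layer=nn.GELU,
    )
    x = torch.randn(1, 3, 224, 224)
    print(stem(x).shape)  # torch.Size([1, 768, 14, 14]) -> 14 * 14 = 196 tokens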

@@ -433,6 +478,25 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ Custom ViT base hybrid w/ simple 3-stage ConvStem (MobileCLIP-B style). No pretrained weights.
+    """
+    backbone = ConvStem(
+        channels=(768//4, 768//4, 768),
+        stride=(4, 2, 2),
+        kernel_size=(4, 2, 2),
+        padding=0,
+        act_layer=nn.GELU,
+    )
+    model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
+    model = _create_vision_transformer_hybrid(
+        'vit_base_mci_224', backbone=backbone, embed_args=dict(proj=False),
+        pretrained=pretrained, **dict(model_args, **kwargs)
+    )
+    return model
+
+
 register_model_deprecations(__name__, {
     'vit_tiny_r_s16_p8_224_in21k': 'vit_tiny_r_s16_p8_224.augreg_in21k',
     'vit_small_r26_s32_224_in21k': 'vit_small_r26_s32_224.augreg_in21k',
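
Since the new variant is registered, it should be constructible through the normal factory; a quick smoke test (no pretrained weights exist for this variant, hence pretrained=False):

    import timm
    import torch

    model = timm.create_model('vit_base_mci_224', pretrained=False, num_classes=0)
    x = torch.randn(1, 3, 224, 224)
    print(model(x).shape)  # expect torch.Size([1, 768]) pooled features with num_classes=0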
