Skip to content

Commit 91e743f

Browse files
committed
Mambaout tweaks
1 parent 4542cf0 commit 91e743f

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

timm/models/mambaout.py

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from torch import nn
1313

1414
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
15-
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead
15+
from timm.layers import trunc_normal_, DropPath, LayerNorm, LayerScale, ClNormMlpClassifierHead, get_act_layer
1616
from ._builder import build_model_with_cfg
1717
from ._manipulate import checkpoint_seq
1818
from ._registry import register_model
@@ -318,10 +318,12 @@ def __init__(
318318
super().__init__()
319319
self.num_classes = num_classes
320320
self.drop_rate = drop_rate
321+
self.output_fmt = 'NHWC'
321322
if not isinstance(depths, (list, tuple)):
322323
depths = [depths] # it means the model has only one stage
323324
if not isinstance(dims, (list, tuple)):
324325
dims = [dims]
326+
act_layer = get_act_layer(act_layer)
325327

326328
num_stage = len(depths)
327329
self.num_stage = num_stage
@@ -456,7 +458,7 @@ def checkpoint_filter_fn(state_dict, model):
456458
def _cfg(url='', **kwargs):
457459
return {
458460
'url': url,
459-
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': None,
461+
'num_classes': 1000, 'input_size': (3, 224, 224), 'pool_size': (7, 7),
460462
'crop_pct': 1.0, 'interpolation': 'bicubic',
461463
'mean': IMAGENET_DEFAULT_MEAN, 'std': IMAGENET_DEFAULT_STD, 'classifier': 'head.fc',
462464
**kwargs
@@ -477,6 +479,7 @@ def _cfg(url='', **kwargs):
477479
'mambaout_small_rw': _cfg(),
478480
'mambaout_base_slim_rw': _cfg(),
479481
'mambaout_base_plus_rw': _cfg(),
482+
'test_mambaout': _cfg(input_size=(3, 160, 160), pool_size=(5, 5)),
480483
}
481484

482485

@@ -554,9 +557,26 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
554557
depths=(3, 4, 27, 3),
555558
dims=(128, 256, 512, 768),
556559
expansion_ratio=3.0,
560+
conv_ratio=1.5,
557561
stem_mid_norm=False,
558562
downsample='conv_nf',
559563
ls_init_value=1e-6,
564+
act_layer='silu',
560565
head_fn='norm_mlp',
561566
)
562567
return _create_mambaout('mambaout_base_plus_rw', pretrained=pretrained, **dict(model_args, **kwargs))
568+
569+
570+
@register_model
def test_mambaout(pretrained=False, **kwargs):
    """Tiny MambaOut configuration intended for lightweight testing.

    Uses minimal depths/dims so the model builds and runs quickly; any
    keyword arguments supplied by the caller override these defaults.
    """
    cfg = dict(
        depths=(1, 1, 3, 1),
        dims=(16, 32, 48, 64),
        expansion_ratio=3,
        stem_mid_norm=False,
        downsample='conv_nf',
        ls_init_value=1e-4,
        act_layer='silu',
        head_fn='norm_mlp',
    )
    # Caller-supplied kwargs take precedence over the defaults above,
    # matching the dict(model_args, **kwargs) merge of the original.
    cfg.update(kwargs)
    return _create_mambaout('test_mambaout', pretrained=pretrained, **cfg)

0 commit comments

Comments
 (0)