12
12
from torch import nn
13
13
14
14
from timm .data import IMAGENET_DEFAULT_MEAN , IMAGENET_DEFAULT_STD
15
- from timm .layers import trunc_normal_ , DropPath , LayerNorm , LayerScale , ClNormMlpClassifierHead
15
+ from timm .layers import trunc_normal_ , DropPath , LayerNorm , LayerScale , ClNormMlpClassifierHead , get_act_layer
16
16
from ._builder import build_model_with_cfg
17
17
from ._manipulate import checkpoint_seq
18
18
from ._registry import register_model
@@ -318,10 +318,12 @@ def __init__(
318
318
super ().__init__ ()
319
319
self .num_classes = num_classes
320
320
self .drop_rate = drop_rate
321
+ self .output_fmt = 'NHWC'
321
322
if not isinstance (depths , (list , tuple )):
322
323
depths = [depths ] # it means the model has only one stage
323
324
if not isinstance (dims , (list , tuple )):
324
325
dims = [dims ]
326
+ act_layer = get_act_layer (act_layer )
325
327
326
328
num_stage = len (depths )
327
329
self .num_stage = num_stage
@@ -456,7 +458,7 @@ def checkpoint_filter_fn(state_dict, model):
456
458
def _cfg (url = '' , ** kwargs ):
457
459
return {
458
460
'url' : url ,
459
- 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : None ,
461
+ 'num_classes' : 1000 , 'input_size' : (3 , 224 , 224 ), 'pool_size' : ( 7 , 7 ) ,
460
462
'crop_pct' : 1.0 , 'interpolation' : 'bicubic' ,
461
463
'mean' : IMAGENET_DEFAULT_MEAN , 'std' : IMAGENET_DEFAULT_STD , 'classifier' : 'head.fc' ,
462
464
** kwargs
@@ -477,6 +479,7 @@ def _cfg(url='', **kwargs):
477
479
'mambaout_small_rw' : _cfg (),
478
480
'mambaout_base_slim_rw' : _cfg (),
479
481
'mambaout_base_plus_rw' : _cfg (),
482
+ 'test_mambaout' : _cfg (input_size = (3 , 160 , 160 ), pool_size = (5 , 5 )),
480
483
}
481
484
482
485
@@ -554,9 +557,26 @@ def mambaout_base_plus_rw(pretrained=False, **kwargs):
554
557
depths = (3 , 4 , 27 , 3 ),
555
558
dims = (128 , 256 , 512 , 768 ),
556
559
expansion_ratio = 3.0 ,
560
+ conv_ratio = 1.5 ,
557
561
stem_mid_norm = False ,
558
562
downsample = 'conv_nf' ,
559
563
ls_init_value = 1e-6 ,
564
+ act_layer = 'silu' ,
560
565
head_fn = 'norm_mlp' ,
561
566
)
562
567
return _create_mambaout ('mambaout_base_plus_rw' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
568
+
569
+
570
@register_model
def test_mambaout(pretrained=False, **kwargs):
    """Minimal MambaOut variant used for fast unit tests.

    Mirrors the `mambaout_base_plus_rw` recipe (conv_nf downsample, SiLU
    activation, norm_mlp head, layer-scale) at a tiny width/depth so the
    full code path can be exercised cheaply.

    Args:
        pretrained: Load pretrained weights if available.
        **kwargs: Overrides merged on top of the default model arguments.
    """
    cfg = {
        'depths': (1, 1, 3, 1),
        'dims': (16, 32, 48, 64),
        'expansion_ratio': 3,
        'stem_mid_norm': False,
        'downsample': 'conv_nf',
        'ls_init_value': 1e-4,
        'act_layer': 'silu',
        'head_fn': 'norm_mlp',
    }
    # Caller-supplied kwargs take precedence over the defaults above.
    cfg.update(kwargs)
    return _create_mambaout('test_mambaout', pretrained=pretrained, **cfg)
0 commit comments