Skip to content

Commit 02d29de

Browse files
committed
Add global_pool to mambaout __init__ and pass to heads
1 parent b0cfd9d commit 02d29de

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

timm/models/mambaout.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ def __init__(
300300
self,
301301
in_chans=3,
302302
num_classes=1000,
303+
global_pool='avg',
303304
depths=(3, 3, 9, 3),
304305
dims=(96, 192, 384, 576),
305306
norm_layer=LayerNorm,
@@ -369,7 +370,7 @@ def __init__(
369370
self.head = MlpHead(
370371
prev_dim,
371372
num_classes,
372-
pool_type='avg',
373+
pool_type=global_pool,
373374
drop_rate=drop_rate,
374375
norm_layer=norm_layer,
375376
)
@@ -379,7 +380,7 @@ def __init__(
379380
prev_dim,
380381
num_classes,
381382
hidden_size=int(prev_dim * 4),
382-
pool_type='avg',
383+
pool_type=global_pool,
383384
norm_layer=norm_layer,
384385
drop_rate=drop_rate,
385386
)

0 commit comments

Comments
 (0)