Skip to content

Commit 9567cf6

Browse files
authored
Feature: add option global_pool='max' to VisionTransformer
Most of the CNNs have a max global pooling option. I would like to extend ViT to have this option.
1 parent 22de845 commit 9567cf6

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

timm/models/vision_transformer.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ def __init__(
400400
patch_size: Union[int, Tuple[int, int]] = 16,
401401
in_chans: int = 3,
402402
num_classes: int = 1000,
403-
global_pool: Literal['', 'avg', 'token', 'map'] = 'token',
403+
global_pool: Literal['', 'avg', 'max', 'token', 'map'] = 'token',
404404
embed_dim: int = 768,
405405
depth: int = 12,
406406
num_heads: int = 12,
@@ -459,10 +459,10 @@ def __init__(
459459
block_fn: Transformer block layer.
460460
"""
461461
super().__init__()
462-
assert global_pool in ('', 'avg', 'token', 'map')
462+
assert global_pool in ('', 'avg', 'max', 'token', 'map')
463463
assert class_token or global_pool != 'token'
464464
assert pos_embed in ('', 'none', 'learn')
465-
use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
465+
use_fc_norm = global_pool in ['avg', 'max'] if fc_norm is None else fc_norm
466466
norm_layer = get_norm_layer(norm_layer) or partial(nn.LayerNorm, eps=1e-6)
467467
act_layer = get_act_layer(act_layer) or nn.GELU
468468

@@ -761,6 +761,8 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso
761761
x = self.attn_pool(x)
762762
elif self.global_pool == 'avg':
763763
x = x[:, self.num_prefix_tokens:].mean(dim=1)
764+
elif self.global_pool == 'max':
765+
x, _ = torch.max(x[:, self.num_prefix_tokens:], dim=1)
764766
elif self.global_pool:
765767
x = x[:, 0] # class token
766768
x = self.fc_norm(x)

0 commit comments

Comments
 (0)