@@ -400,7 +400,7 @@ def __init__(
400
400
patch_size : Union [int , Tuple [int , int ]] = 16 ,
401
401
in_chans : int = 3 ,
402
402
num_classes : int = 1000 ,
403
- global_pool : Literal ['' , 'avg' , 'token' , 'map' ] = 'token' ,
403
+ global_pool : Literal ['' , 'avg' , 'max' , 'token' , 'map' ] = 'token' ,
404
404
embed_dim : int = 768 ,
405
405
depth : int = 12 ,
406
406
num_heads : int = 12 ,
@@ -459,10 +459,10 @@ def __init__(
459
459
block_fn: Transformer block layer.
460
460
"""
461
461
super ().__init__ ()
462
- assert global_pool in ('' , 'avg' , 'token' , 'map' )
462
+ assert global_pool in ('' , 'avg' , 'max' , 'token' , 'map' )
463
463
assert class_token or global_pool != 'token'
464
464
assert pos_embed in ('' , 'none' , 'learn' )
465
- use_fc_norm = global_pool == 'avg' if fc_norm is None else fc_norm
465
+ use_fc_norm = global_pool in [ 'avg' , 'max' ] if fc_norm is None else fc_norm
466
466
norm_layer = get_norm_layer (norm_layer ) or partial (nn .LayerNorm , eps = 1e-6 )
467
467
act_layer = get_act_layer (act_layer ) or nn .GELU
468
468
@@ -761,6 +761,8 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tenso
761
761
x = self .attn_pool (x )
762
762
elif self .global_pool == 'avg' :
763
763
x = x [:, self .num_prefix_tokens :].mean (dim = 1 )
764
+ elif self .global_pool == 'max' :
765
+ x , _ = torch .max (x [:, self .num_prefix_tokens :], dim = 1 )
764
766
elif self .global_pool :
765
767
x = x [:, 0 ] # class token
766
768
x = self .fc_norm (x )
0 commit comments