@@ -428,6 +428,7 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            repr_size=False,
     ) -> None:
         """
         Args:
@@ -536,6 +537,14 @@ def __init__(
             )
         else:
             self.attn_pool = None
+        if repr_size:
+            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
+            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
+            embed_dim = repr_size
+            print(self.repr)
+        else:
+            self.repr = nn.Identity()
+
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -752,6 +761,7 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
+        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
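Taken together, the three hunks above add an optional pre-logits "representation" layer: when `repr_size` is truthy, a Linear + Tanh projection is built in `__init__` (sized to `embed_dim` when `repr_size` is passed as `True`) and applied in `forward_head` between pooling and `fc_norm`/`head`, echoing the representation layer of the original ViT implementation. The standalone sketch below only illustrates that head path; `PreLogitsHead` is a hypothetical helper, not part of this commit, and it uses the projection size for the classifier input, which the commit only guarantees when `repr_size` equals the embedding dim.

# Illustrative sketch of the repr_size head path; not part of the diff.
import torch
import torch.nn as nn

class PreLogitsHead(nn.Module):  # hypothetical helper mirroring the diff's logic
    def __init__(self, embed_dim: int, num_classes: int, repr_size=True):
        super().__init__()
        if repr_size:
            repr_size = embed_dim if isinstance(repr_size, bool) else repr_size
            self.repr = nn.Sequential(nn.Linear(embed_dim, repr_size), nn.Tanh())
        else:
            repr_size = embed_dim
            self.repr = nn.Identity()
        self.fc_norm = nn.LayerNorm(repr_size)
        self.head = nn.Linear(repr_size, num_classes)

    def forward(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
        x = self.repr(x)      # pre-logits projection + tanh (identity when disabled)
        x = self.fc_norm(x)
        return x if pre_logits else self.head(x)

pooled = torch.randn(2, 384)                   # pooled token features (B, embed_dim)
print(PreLogitsHead(384, 1000)(pooled).shape)  # torch.Size([2, 1000])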
@@ -1790,23 +1800,40 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
 
-    'vit_wee_patch16_reg1_gap_256': _cfg(
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_little_patch16_reg4_gap_256': _cfg(
-        #file='',
+    'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='./vit_pwee-in1k-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_little_patch16-in1k-8a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap1-in1k-20231118-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap4-in1k-20231115-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_mp_patch16_reg4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg1_gap_256': _cfg(
-        #file='vit_medium_gap1-in1k-20231118-8.pth',
+    'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap1-in1k-20231121-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg4_gap_256': _cfg(
-        #file='vit_medium_gap4-in1k-20231115-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg1_gap_256': _cfg(
-        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap4-in1k-20231106-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256': _cfg(
-        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_betwixt_gap4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
@@ -1906,6 +1933,14 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Small (ViT-S/16)
+    """
+    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
@@ -2755,10 +2790,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
2755
2790
def vit_wee_patch16_reg1_gap_256 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
2756
2791
model_args = dict (
2757
2792
patch_size = 16 , embed_dim = 256 , depth = 14 , num_heads = 4 , init_values = 1e-5 , mlp_ratio = 5 ,
2793
+ class_token = False , no_embed_class = True , reg_tokens = 1 , global_pool = 'avg' ,
2794
+ )
2795
+ model = _create_vision_transformer (
2796
+ 'vit_wee_patch16_reg1_gap_256' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2797
+ return model
2798
+
2799
+
2800
+ @register_model
2801
+ def vit_pwee_patch16_reg1_gap_256 (pretrained : bool = False , ** kwargs ) -> VisionTransformer :
2802
+ model_args = dict (
2803
+ patch_size = 16 , embed_dim = 256 , depth = 16 , num_heads = 4 , init_values = 1e-5 , mlp_ratio = 5 ,
2758
2804
class_token = False , no_embed_class = True , reg_tokens = 1 , global_pool = 'avg' , block_fn = ParallelScalingBlock ,
2759
2805
)
2760
2806
model = _create_vision_transformer (
2761
- 'vit_medium_patch16_reg1_gap_256 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2807
+ 'vit_pwee_patch16_reg1_gap_256 ' , pretrained = pretrained , ** dict (model_args , ** kwargs ))
2762
2808
return model
2763
2809
2764
2810
@@ -2769,7 +2815,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2795,6 +2841,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
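For reference, a possible way to exercise the new registrations once this commit is applied; this is only a sketch, assuming this branch of timm is installed locally. No pretrained weights are wired up for vit_small_patch16_gap_224 here, so pretrained=False.

# Sketch only: model names come from the registrations above; shapes assume the default 224x224 input.
import torch
import timm

# The gap variant registered above sets repr_size=True, so forward_head() applies the
# Linear+Tanh pre-logits layer before fc_norm and the classifier.
model = timm.create_model('vit_small_patch16_gap_224', pretrained=False, num_classes=1000)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    feats = model.forward_features(x)                   # token features (1, num_tokens, 384)
    pre = model.forward_head(feats, pre_logits=True)    # pooled pre-logits features, (1, 384)
    logits = model.forward_head(feats)                  # classifier output, (1, 1000)
print(pre.shape, logits.shape)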