@@ -428,7 +428,6 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
-            repr_size=False,
     ) -> None:
         """
         Args:
@@ -537,14 +536,6 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        if repr_size:
-            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
-            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
-            embed_dim = repr_size
-            print(self.repr)
-        else:
-            self.repr = nn.Identity()
-
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -761,7 +752,6 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
-        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
@@ -1804,35 +1794,45 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='./vit_pwee-in1k-8.pth',
+        #file='./vit_pwee-in1k-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_little_patch16-in1k-8a.pth',
+        #file='vit_little_patch16-in1k-8a.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='vit_medium_gap1-in1k-20231118-8.pth',
+        #file='vit_medium_gap1-in1k-20231118-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_medium_gap4-in1k-20231115-8.pth',
+        #file='vit_medium_gap4-in1k-20231115-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
-        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        #file='vit_mp_patch16_reg4-in1k-5a.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
-        file='vit_mp_patch16_reg4-in12k-8.pth',
+        #file='vit_mp_patch16_reg4-in12k-8.pth',
+        hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
-        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        #file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
-        file='vit_betwixt_gap4-in12k-8.pth',
+        #file='vit_betwixt_gap4-in12k-8.pth',
+        hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
@@ -1933,14 +1933,6 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model


-@register_model
-def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
-    """ ViT-Small (ViT-S/16)
-    """
-    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
-    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
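Usage sketch (not part of the diff): with the local `file=` entries above commented out in favour of `hf_hub_id='timm/'`, the pretrained weights for these sbb ViT variants would be resolved from the Hugging Face Hub. A minimal example with `timm.create_model`, assuming a timm build that includes these model names and network access to the Hub:

```python
import timm
import torch

# Create one of the sbb ViT variants touched by this diff; with hf_hub_id set in its
# pretrained cfg, pretrained=True fetches the checkpoint from the Hugging Face Hub
# rather than a local file= path.
model = timm.create_model('vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k', pretrained=True)
model.eval()

# The cfg declares a 3x256x256 input (crop_pct=0.95); run a dummy forward pass.
x = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected torch.Size([1, 1000]) for the in1k fine-tuned head
```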