
Commit 3582ca4

Prepping weight push, benchmarking.
1 parent: 2bfa5e5

3 files changed, +83 −26 lines changed

timm/models/eva.py

Lines changed: 12 additions & 12 deletions
@@ -718,10 +718,10 @@ def checkpoint_filter_fn(
             continue
 
         # FIXME here while import new weights, to remove
-        # if k == 'cls_token':
-        #     print('DEBUG: cls token -> reg')
-        #     k = 'reg_token'
-        #     #v = v + state_dict['pos_embed'][0, :]
+        if k == 'cls_token':
+            print('DEBUG: cls token -> reg')
+            k = 'reg_token'
+            #v = v + state_dict['pos_embed'][0, :]
 
         if 'patch_embed.proj.weight' in k:
             _, _, H, W = model.patch_embed.proj.weight.shape
@@ -951,26 +951,26 @@ def _cfg(url='', **kwargs):
         num_classes=0,
     ),
 
-    'vit_medium_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+        file='vit_medium_gap1_rope-in1k-20230920-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_mediumd_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+        file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
-    'vit_betwixt_patch16_rope_reg4_gap_256.in1k': _cfg(
+    'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+        file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
-    'vit_base_patch16_rope_reg1_gap_256.in1k': _cfg(
+    'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
         #hf_hub_id='timm/',
-        #file='vit_base_gap1_rope-in1k-20230930-5.pth',
+        file='vit_base_gap1_rope-in1k-20230930-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
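For reference, the block un-commented above remaps a legacy 'cls_token' parameter to 'reg_token' while the new weights are imported. A minimal standalone sketch of that kind of key remapping (the checkpoint path in the usage comment is hypothetical, not from this commit):

import torch

def remap_cls_to_reg(state_dict):
    """Rename a legacy 'cls_token' entry to 'reg_token' (sketch of the filter logic above)."""
    out = {}
    for k, v in state_dict.items():
        if k == 'cls_token':
            k = 'reg_token'  # register token replaces the class token in these configs
        out[k] = v
    return out

# hypothetical usage with an older checkpoint file:
# state_dict = torch.load('old_checkpoint.pth', map_location='cpu')
# state_dict = remap_cls_to_reg(state_dict)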

timm/models/vision_transformer.py

Lines changed: 70 additions & 13 deletions
@@ -428,6 +428,7 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
+            repr_size=False,
     ) -> None:
         """
         Args:
@@ -536,6 +537,14 @@ def __init__(
             )
         else:
             self.attn_pool = None
+        if repr_size:
+            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
+            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
+            embed_dim = repr_size
+            print(self.repr)
+        else:
+            self.repr = nn.Identity()
+
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -752,6 +761,7 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
+        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
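The new repr_size option inserts a Linear + Tanh "representation" (pre-logits) layer between the pooled features and the classifier head, and forward_head now routes through it. A minimal sketch of that head path under the same wiring (standalone module, not the actual VisionTransformer; the head here consumes the repr width, which equals embed_dim when repr_size=True):

import torch
import torch.nn as nn

class HeadWithRepr(nn.Module):
    """Sketch of the head path: pooled features -> repr (Linear + Tanh) -> fc_norm -> head."""
    def __init__(self, embed_dim=384, num_classes=1000, repr_size=True):
        super().__init__()
        if repr_size:
            repr_size = embed_dim if isinstance(repr_size, bool) else repr_size
            self.repr = nn.Sequential(nn.Linear(embed_dim, repr_size), nn.Tanh())
            embed_dim = repr_size  # downstream layers see the repr width
        else:
            self.repr = nn.Identity()
        self.fc_norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        x = self.repr(x)
        x = self.fc_norm(x)
        return self.head(x)

pooled = torch.randn(2, 384)      # e.g. global-average-pooled tokens
logits = HeadWithRepr()(pooled)   # -> shape (2, 1000)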
@@ -1790,23 +1800,40 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
 
-    'vit_wee_patch16_reg1_gap_256': _cfg(
+    'vit_wee_patch16_reg1_gap_256.sbb_in1k': _cfg(
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_little_patch16_reg4_gap_256': _cfg(
-        #file='',
+    'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='./vit_pwee-in1k-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_little_patch16-in1k-8a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap1-in1k-20231118-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_medium_gap4-in1k-20231115-8.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_mp_patch16_reg4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg1_gap_256': _cfg(
-        #file='vit_medium_gap1-in1k-20231118-8.pth',
+    'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap1-in1k-20231121-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_medium_patch16_reg4_gap_256': _cfg(
-        #file='vit_medium_gap4-in1k-20231115-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
+        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg1_gap_256': _cfg(
-        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
+        file='vit_betwixt_gap4-in1k-20231106-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
-    'vit_betwixt_patch16_reg4_gap_256': _cfg(
-        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
+        file='vit_betwixt_gap4-in12k-8.pth',
+        num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
         input_size=(3, 256, 256)),
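These entries rename the in1k configs to the sbb_in1k / sbb_in12k pretrained tags and point them at local checkpoint files ahead of the weight push. A hedged usage sketch, assuming the tags from this commit; pretrained=False is used because the hub ids are still commented out at this point:

import timm
import torch

# the pretrained tag is the '.suffix' part of the name; weights must be reachable
# (local files only in this commit, hub upload pending)
model = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb_in1k', pretrained=False)
model.eval()

x = torch.randn(1, 3, 256, 256)  # these configs use 256x256 inputs, crop_pct=0.95
with torch.no_grad():
    logits = model(x)
print(logits.shape)              # torch.Size([1, 1000])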
@@ -1906,6 +1933,14 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ ViT-Small (ViT-S/16)
+    """
+    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
+    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
@@ -2755,10 +2790,21 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
 def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
         patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_pwee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=16, num_heads=4, init_values=1e-5, mlp_ratio=5,
         class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_pwee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2769,7 +2815,7 @@ def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
         class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
     )
     model = _create_vision_transformer(
-        'vit_medium_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
 
@@ -2795,6 +2841,17 @@ def vit_medium_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_mediumd_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=512, depth=20, num_heads=8, init_values=1e-5,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_mediumd_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_betwixt_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
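Alongside the wee/pwee and little fixups (which previously created the wrong variant name), this adds the mediumd (medium-deep) trunk: 512-dim, 20 blocks, 4 register tokens, GAP. A small sanity-check sketch, assuming this commit's registration and no pretrained weights; the parameter figure in the comment is a rough estimate, not from the commit:

import timm

m = timm.create_model('vit_mediumd_patch16_reg4_gap_256', pretrained=False, num_classes=1000)
n_params = sum(p.numel() for p in m.parameters())
print(f'{n_params / 1e6:.1f}M params')  # on the order of ~60M for this config (rough estimate)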

timm/version.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.0.0.dev0'
+__version__ = '1.0.1.dev0'
