
Commit aa4d06a

sbb vit weights on hub, testing

1 parent 3582ca4

2 files changed (+32, -41 lines)

timm/models/eva.py

Lines changed: 12 additions & 13 deletions
@@ -717,11 +717,10 @@ def checkpoint_filter_fn(
             # fixed embedding no need to load buffer from checkpoint
             continue
 
-        # FIXME here while import new weights, to remove
-        if k == 'cls_token':
-            print('DEBUG: cls token -> reg')
-            k = 'reg_token'
-            #v = v + state_dict['pos_embed'][0, :]
+        # FIXME here while importing new weights, to remove
+        # if k == 'cls_token':
+        #     print('DEBUG: cls token -> reg')
+        #     k = 'reg_token'
 
         if 'patch_embed.proj.weight' in k:
             _, _, H, W = model.patch_embed.proj.weight.shape
@@ -952,25 +951,25 @@ def _cfg(url='', **kwargs):
     ),
 
     'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
-        #hf_hub_id='timm/',
-        file='vit_medium_gap1_rope-in1k-20230920-5.pth',
+        hf_hub_id='timm/',
+        #file='vit_medium_gap1_rope-in1k-20230920-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
     'vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
-        #hf_hub_id='timm/',
-        file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
+        hf_hub_id='timm/',
+        #file='vit_mediumd_gap1_rope-in1k-20230926-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
     'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k': _cfg(
-        #hf_hub_id='timm/',
-        file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
+        hf_hub_id='timm/',
+        #file='vit_betwixt_gap4_rope-in1k-20231005-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
     ),
     'vit_base_patch16_rope_reg1_gap_256.sbb_in1k': _cfg(
-        #hf_hub_id='timm/',
-        file='vit_base_gap1_rope-in1k-20230930-5.pth',
+        hf_hub_id='timm/',
+        #file='vit_base_gap1_rope-in1k-20230930-5.pth',
         input_size=(3, 256, 256), crop_pct=0.95,
         mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)
     ),
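The four rope variants above now resolve their pretrained weights from the Hugging Face Hub instead of local checkpoint files. A minimal loading sketch, assuming the weights are published under the `timm` organization as `hf_hub_id='timm/'` implies:

```python
import timm

# With hf_hub_id='timm/' the trailing slash means the model name is appended,
# i.e. weights resolve to the Hub repo 'timm/<model_name>', so pretrained=True
# now downloads the checkpoint rather than reading a local .pth file.
model = timm.create_model(
    'vit_medium_patch16_rope_reg1_gap_256.sbb_in1k',
    pretrained=True,
)
model.eval()
```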

timm/models/vision_transformer.py

Lines changed: 20 additions & 28 deletions
@@ -428,7 +428,6 @@ def __init__(
             act_layer: Optional[LayerType] = None,
             block_fn: Type[nn.Module] = Block,
             mlp_layer: Type[nn.Module] = Mlp,
-            repr_size = False,
     ) -> None:
         """
         Args:
@@ -537,14 +536,6 @@ def __init__(
             )
         else:
             self.attn_pool = None
-        if repr_size:
-            repr_size = self.embed_dim if isinstance(repr_size, bool) else repr_size
-            self.repr = nn.Sequential(nn.Linear(self.embed_dim, repr_size), nn.Tanh())
-            embed_dim = repr_size
-            print(self.repr)
-        else:
-            self.repr = nn.Identity()
-
         self.fc_norm = norm_layer(embed_dim) if use_fc_norm else nn.Identity()
         self.head_drop = nn.Dropout(drop_rate)
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@@ -761,7 +752,6 @@ def forward_head(self, x: torch.Tensor, pre_logits: bool = False) -> torch.Tensor:
             x = x[:, self.num_prefix_tokens:].mean(dim=1)
         elif self.global_pool:
             x = x[:, 0]  # class token
-        x = self.repr(x)
         x = self.fc_norm(x)
         x = self.head_drop(x)
         return x if pre_logits else self.head(x)
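For context, the removed `repr_size` path inserted a Linear + Tanh "pre-logits" projection between pooling and the final norm/head. A hypothetical standalone reconstruction of what that block did (the class name here is illustrative, not timm API):

```python
import torch
import torch.nn as nn

# Illustrative stand-in for the deleted repr block: a Linear + Tanh projection
# applied to the pooled token before fc_norm and the classifier head.
class PreLogits(nn.Module):
    def __init__(self, embed_dim: int, repr_size: int):
        super().__init__()
        self.repr = nn.Sequential(nn.Linear(embed_dim, repr_size), nn.Tanh())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.repr(x)

pooled = torch.randn(2, 384)              # e.g. pooled features of a ViT-S
print(PreLogits(384, 384)(pooled).shape)  # torch.Size([2, 384])
```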
@@ -1804,35 +1794,45 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         #file='',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_pwee_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='./vit_pwee-in1k-8.pth',
+        #file='./vit_pwee-in1k-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_little_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_little_patch16-in1k-8a.pth',
+        #file='vit_little_patch16-in1k-8a.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='vit_medium_gap1-in1k-20231118-8.pth',
+        #file='vit_medium_gap1-in1k-20231118-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_medium_gap4-in1k-20231115-8.pth',
+        #file='vit_medium_gap4-in1k-20231115-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
-        file='vit_mp_patch16_reg4-in1k-5a.pth',
+        #file='vit_mp_patch16_reg4-in1k-5a.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_mediumd_patch16_reg4_gap_256.sbb_in12k': _cfg(
-        file='vit_mp_patch16_reg4-in12k-8.pth',
+        #file='vit_mp_patch16_reg4-in12k-8.pth',
+        hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg1_gap_256.sbb_in1k': _cfg(
-        file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        #file='vit_betwixt_gap1-in1k-20231121-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k': _cfg(
-        file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        #file='vit_betwixt_patch16_reg4-ft-in1k-8b.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in1k': _cfg(
-        file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        #file='vit_betwixt_gap4-in1k-20231106-8.pth',
+        hf_hub_id='timm/',
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_betwixt_patch16_reg4_gap_256.sbb_in12k': _cfg(
-        file='vit_betwixt_gap4-in12k-8.pth',
+        #file='vit_betwixt_gap4-in12k-8.pth',
+        hf_hub_id='timm/',
         num_classes=11821,
         input_size=(3, 256, 256), crop_pct=0.95),
     'vit_base_patch16_reg4_gap_256': _cfg(
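These sbb configs likewise switch from local files to the Hub. Note the in12k checkpoints keep an 11821-class head; a usage sketch, assuming the weights exist under `timm/` on the Hub:

```python
import timm
import torch

# The in12k weights ship with num_classes=11821; pass num_classes=0 to drop
# the classifier and use the backbone as a feature extractor instead.
model = timm.create_model(
    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k',
    pretrained=True,
    num_classes=0,
)
features = model(torch.randn(1, 3, 256, 256))
```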
@@ -1933,14 +1933,6 @@ def vit_small_patch16_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     return model
 
 
-@register_model
-def vit_small_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
-    """ ViT-Small (ViT-S/16)
-    """
-    model_args = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6, global_pool='avg', class_token=False, repr_size=True)
-    model = _create_vision_transformer('vit_small_patch16_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
 @register_model
 def vit_small_patch16_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-Small (ViT-S/16)
