
Commit c22efb9

Add wee & little vits for some experiments
1 parent 67332fc

File tree: 1 file changed (+28 −0)

timm/models/vision_transformer.py

Lines changed: 28 additions & 0 deletions
@@ -1791,6 +1791,12 @@ def _cfg(url: str = '', **kwargs) -> Dict[str, Any]:
         license='mit',
         mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
 
+    'vit_wee_patch16_reg1_gap_256': _cfg(
+        file='',
+        input_size=(3, 256, 256), crop_pct=0.95),
+    'vit_little_patch16_reg4_gap_256': _cfg(
+        file='',
+        input_size=(3, 256, 256), crop_pct=0.95),
     'vit_medium_patch16_reg1_gap_256': _cfg(
         file='vit_medium_gap1-in1k-20231118-8.pth',
         input_size=(3, 256, 256), crop_pct=0.95),
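Both new _cfg entries leave file='' as a placeholder, so no pretrained checkpoint ships with this commit and the names only build with random weights. A minimal sketch of how these cfg values surface through the timm factory, assuming a timm build that includes this commit:

import timm

# file='' is an empty placeholder, so there are no weights to load;
# build with random init (pretrained=True would fail here).
model = timm.create_model('vit_wee_patch16_reg1_gap_256', pretrained=False)

# The _cfg values above are attached as the model's pretrained_cfg,
# including input_size=(3, 256, 256) and crop_pct=0.95.
print(model.pretrained_cfg)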
@@ -2746,6 +2752,28 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
 # return model
 
 
+@register_model
+def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=256, depth=14, num_heads=4, init_values=1e-5, mlp_ratio=5,
+        class_token=False, no_embed_class=True, reg_tokens=1, global_pool='avg', block_fn=ParallelScalingBlock,
+    )
+    model = _create_vision_transformer(
+        'vit_wee_patch16_reg1_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_little_patch16_reg4_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(
+        patch_size=16, embed_dim=320, depth=14, num_heads=5, init_values=1e-5, mlp_ratio=5.6,
+        class_token=False, no_embed_class=True, reg_tokens=4, global_pool='avg',
+    )
+    model = _create_vision_transformer(
+        'vit_little_patch16_reg4_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_medium_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
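Of the two registrations, vit_wee swaps in block_fn=ParallelScalingBlock (the parallel attention + MLP block from the ViT-22B scaling paper), while vit_little keeps the default sequential block at a slightly wider embed_dim with 4 register tokens. A quick smoke-test sketch for both names, assuming a timm build with this commit; neither cfg overrides num_classes, so the head defaults to 1000 classes:

import timm
import torch

for name in ('vit_wee_patch16_reg1_gap_256', 'vit_little_patch16_reg4_gap_256'):
    model = timm.create_model(name, pretrained=False).eval()
    n_params = sum(p.numel() for p in model.parameters())
    with torch.no_grad():
        # Forward at the 256x256 input size declared in the cfgs above.
        out = model(torch.randn(1, 3, 256, 256))
    print(f'{name}: {n_params / 1e6:.1f}M params, logits {tuple(out.shape)}')  # logits (1, 1000)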
