Initialize weights of reg_token for ViT #2229

Merged · 1 commit merged into huggingface:main on Jul 18, 2024

Conversation

Promisery (Contributor)

The reg_tokens in VisionTransformer are not initialized in the init_weights() function. As a result, when setting reg_tokens>0 and no_embed_class=True, all reg_tokens start at 0 and remain identical throughout training, unless ROPE is used to break the symmetry. Consequently, all of the models from Searching for Better ViT Baselines trained with reg_tokens=4 have 4 identical reg_tokens, except vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k. To verify this, run:

import timm
models = [
    'vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k',
    'vit_medium_patch16_reg4_gap_256.sbb_in12k_ft_in1k',
    'vit_medium_patch16_reg4_gap_256.sbb_in1k',
    'vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k',
    'vit_betwixt_patch16_reg4_gap_256.sbb_in1k',
    'vit_little_patch16_reg4_gap_256.sbb_in1k',
    'vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k',
]
for m in models:
    model = timm.create_model(m, pretrained=True)
    # std over the token dimension is 0 iff all register tokens are identical
    print(m, (model.reg_token.std(dim=1) == 0).all())

Therefore, the reg_tokens should be randomly initialized in the init_weights() function.
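For reference, a minimal standalone sketch of what a random init could look like; the std value here is an assumption, chosen to mirror the small-std normal init used for cls_token, and this is not the merged change itself:

import torch
import torch.nn as nn

# Hypothetical sketch: build a register token parameter and give it a small
# random init instead of leaving it at zero.
embed_dim, num_reg_tokens = 768, 4
reg_token = nn.Parameter(torch.zeros(1, num_reg_tokens, embed_dim))
nn.init.normal_(reg_token, std=1e-6)  # std borrowed from the cls_token init (assumption)
print((reg_token.std(dim=1) == 0).all())  # expected: tensor(False), tokens now differ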

rwightman (Collaborator) commented Jul 13, 2024

EDIT: right, I see the point re: symmetry; will try a comparative training run.

Promisery (Contributor, Author) commented Jul 13, 2024

Sorry for the confusion. What I mean is that the reg_tokens are symmetric under zero init and therefore remain identical throughout training. For example:

import timm
model = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb_in1k', pretrained=True)
print(model.reg_token)

Parameter containing:
tensor([[[0.0019, 0.0179, 0.0019,  ..., 0.0019, 0.0020, 0.0018],
         [0.0019, 0.0179, 0.0019,  ..., 0.0019, 0.0020, 0.0018],
         [0.0019, 0.0179, 0.0019,  ..., 0.0019, 0.0020, 0.0018],
         [0.0019, 0.0179, 0.0019,  ..., 0.0019, 0.0020, 0.0018]]],
       requires_grad=True)

With ROPE, the relative positions between the reg_tokens break the symmetry, so they are no longer identical:

import timm
model = timm.create_model('vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k', pretrained=True)
print(model.reg_token)

Parameter containing:
tensor([[[ 1.7490e-03,  1.7675e-03,  1.9662e-03,  ...,  1.7398e-03,
           1.7855e-03,  1.7709e-03],
         [ 7.6676e-04,  1.5031e-03,  9.7344e-04,  ...,  1.3273e-03,
           1.2984e-03,  8.5229e-04],
         [ 2.8474e-03,  2.8378e-03,  2.9330e-03,  ...,  2.8548e-03,
           2.8779e-03,  2.8636e-03],
         [-1.4168e-04, -1.2385e-04, -8.3081e-05,  ..., -1.3852e-04,
          -1.3595e-04, -1.2391e-04]]], requires_grad=True)

I'm not sure whether identical reg_tokens hurt performance, but I don't think this is the expected behavior, right?
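To make the symmetry argument concrete, here is a minimal toy sketch (not timm code): two zero-initialized register tokens receive identical gradients through attention at every step, so they stay identical under SGD.

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8

# Two register tokens initialized to zero, as with reg_tokens>0 and no_embed_class=True.
reg = nn.Parameter(torch.zeros(1, 2, dim))
patches = torch.randn(1, 5, dim)  # stand-in for patch embeddings

attn = nn.MultiheadAttention(dim, num_heads=2, batch_first=True)
opt = torch.optim.SGD([reg, *attn.parameters()], lr=0.1)

for _ in range(20):
    x = torch.cat([reg, patches], dim=1)  # (batch, seq, dim); nothing distinguishes the two reg tokens
    out, _ = attn(x, x, x)
    loss = out.pow(2).mean()
    opt.zero_grad()
    loss.backward()
    opt.step()

# Identical gradients every step, so the tokens remain identical after training.
print(torch.allclose(reg[0, 0], reg[0, 1]))  # expected: True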

rwightman (Collaborator) commented Jul 18, 2024

@Promisery the first 'little' model finished with the updated init; it was behind in accuracy for most of training and only pulled ahead by the tiniest amount at the very end. So, not a clear win over the existing init, but the registers are now distinct, which may bring greater benefit for non-classification use (feature representation, etc.).

rwightman merged commit 7160af4 into huggingface:main on Jul 18, 2024
rwightman (Collaborator)

@Promisery I have some in12k & in12k + 1k fine-tuned reg4 weights with the init fix here:

https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k

Anything with '.sbb2' will have the fix in it. Not sure I'll get to all the reg > 1 models though. The difference in end accuracy on re-run is very, very minor (looks like a coin flip). The improvements on these runs come from other tweaks (longer pretrain epochs, some rejigging of fine-tune hparams, etc.).
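For anyone who wants to check, the same std test from the top of this thread can be run against a '.sbb2' checkpoint (the example below uses the model linked above); it should report distinct register tokens:

import timm

model = timm.create_model(
    'vit_mediumd_patch16_reg4_gap_384.sbb2_e200_in12k_ft_in1k', pretrained=True
)
# std across the 4 register tokens is nonzero when they are distinct
print((model.reg_token.std(dim=1) == 0).all())  # expected: tensor(False)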
