Skip to content

Commit 7a4e987

Browse files
committed
Hiera weights on hub
1 parent c838c42 commit 7a4e987

File tree

1 file changed

+27
-26
lines changed

1 file changed

+27
-26
lines changed

timm/models/hiera.py

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -803,62 +803,62 @@ def _cfg(url='', **kwargs):
803803

804804
default_cfgs = generate_default_cfgs({
805805
"hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
806-
url="https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth",
807-
#hf_hb='timm/',
806+
hf_hub_id='timm/',
807+
license='cc-by-nc-4.0',
808808
),
809809
"hiera_tiny_224.mae": _cfg(
810-
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth",
811-
#hf_hb='timm/',
810+
hf_hub_id='timm/',
811+
license='cc-by-nc-4.0',
812812
num_classes=0,
813813
),
814814

815815
"hiera_small_224.mae_in1k_ft_in1k": _cfg(
816-
url="https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth",
817-
#hf_hb='timm/',
816+
hf_hub_id='timm/',
817+
license='cc-by-nc-4.0',
818818
),
819819
"hiera_small_224.mae": _cfg(
820-
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth",
821-
#hf_hb='timm/',
820+
hf_hub_id='timm/',
821+
license='cc-by-nc-4.0',
822822
num_classes=0,
823823
),
824824

825825
"hiera_base_224.mae_in1k_ft_in1k": _cfg(
826-
url="https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
827-
#hf_hb='timm/',
826+
hf_hub_id='timm/',
827+
license='cc-by-nc-4.0',
828828
),
829829
"hiera_base_224.mae": _cfg(
830-
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
831-
#hf_hb='timm/',
830+
hf_hub_id='timm/',
831+
license='cc-by-nc-4.0',
832832
num_classes=0,
833833
),
834834

835835
"hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
836-
url="https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
837-
#hf_hb='timm/',
836+
hf_hub_id='timm/',
837+
license='cc-by-nc-4.0',
838838
),
839839
"hiera_base_plus_224.mae": _cfg(
840-
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
841-
#hf_hb='timm/',
840+
hf_hub_id='timm/',
841+
license='cc-by-nc-4.0',
842842
num_classes=0,
843843
),
844844

845845
"hiera_large_224.mae_in1k_ft_in1k": _cfg(
846-
url="https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
847-
#hf_hb='timm/',
846+
hf_hub_id='timm/',
847+
license='cc-by-nc-4.0',
848848
),
849849
"hiera_large_224.mae": _cfg(
850-
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
851-
#hf_hb='timm/',
850+
hf_hub_id='timm/',
851+
license='cc-by-nc-4.0',
852852
num_classes=0,
853853
),
854854

855855
"hiera_huge_224.mae_in1k_ft_in1k": _cfg(
856-
url="https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
857-
#hf_hb='timm/',
856+
hf_hub_id='timm/',
857+
license='cc-by-nc-4.0',
858858
),
859859
"hiera_huge_224.mae": _cfg(
860-
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
861-
#hf_hb='timm/',
860+
hf_hub_id='timm/',
861+
license='cc-by-nc-4.0',
862862
num_classes=0,
863863
),
864864
})
@@ -880,7 +880,9 @@ def checkpoint_filter_fn(state_dict, model=None):
880880
pass
881881
if 'head.projection.' in k:
882882
k = k.replace('head.projection.', 'head.fc.')
883-
if k.startswith('norm.'):
883+
if k.startswith('encoder_norm.'):
884+
k = k.replace('encoder_norm.', 'head.norm.')
885+
elif k.startswith('norm.'):
884886
k = k.replace('norm.', 'head.norm.')
885887
output[k] = v
886888
return output
@@ -893,7 +895,6 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
893895
Hiera,
894896
variant,
895897
pretrained,
896-
#pretrained_strict=False,
897898
pretrained_filter_fn=checkpoint_filter_fn,
898899
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
899900
**kwargs,

0 commit comments

Comments
 (0)