
Commit 2ca45a4

Merge remote-tracking branch 'upstream/main' into hiera
2 parents: c6db404 + 49de391


44 files changed: +2005, -472 lines

README.md

Lines changed: 25 additions & 0 deletions
@@ -26,6 +26,31 @@
 * The Hugging Face Hub (https://huggingface.co/timm) is now the primary source for `timm` weights. Model cards include links to papers, original source, license.
 * Previous 0.6.x can be cloned from [0.6.x](https://github.com/rwightman/pytorch-image-models/tree/0.6.x) branch or installed via pip with a pinned version.
 
+### May 11, 2024
+* `Searching for Better ViT Baselines (For the GPU Poor)` weights and vit variants released. Exploring model shapes between Tiny and Base.
+
+| model | top1 | top5 | param_count | img_size |
+| -------------------------------------------------- | ------ | ------ | ----------- | -------- |
+| [vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k](https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k_ft_in1k) | 86.202 | 97.874 | 64.11 | 256 |
+| [vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k](https://huggingface.co/timm/vit_betwixt_patch16_reg4_gap_256.sbb_in12k_ft_in1k) | 85.418 | 97.48 | 60.4 | 256 |
+| [vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_mediumd_patch16_rope_reg1_gap_256.sbb_in1k) | 84.322 | 96.812 | 63.95 | 256 |
+| [vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_betwixt_patch16_rope_reg4_gap_256.sbb_in1k) | 83.906 | 96.684 | 60.23 | 256 |
+| [vit_base_patch16_rope_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_base_patch16_rope_reg1_gap_256.sbb_in1k) | 83.866 | 96.67 | 86.43 | 256 |
+| [vit_medium_patch16_rope_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_rope_reg1_gap_256.sbb_in1k) | 83.81 | 96.824 | 38.74 | 256 |
+| [vit_betwixt_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_betwixt_patch16_reg4_gap_256.sbb_in1k) | 83.706 | 96.616 | 60.4 | 256 |
+| [vit_betwixt_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_betwixt_patch16_reg1_gap_256.sbb_in1k) | 83.628 | 96.544 | 60.4 | 256 |
+| [vit_medium_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg4_gap_256.sbb_in1k) | 83.47 | 96.622 | 38.88 | 256 |
+| [vit_medium_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_medium_patch16_reg1_gap_256.sbb_in1k) | 83.462 | 96.548 | 38.88 | 256 |
+| [vit_little_patch16_reg4_gap_256.sbb_in1k](https://huggingface.co/timm/vit_little_patch16_reg4_gap_256.sbb_in1k) | 82.514 | 96.262 | 22.52 | 256 |
+| [vit_pwee_patch16_reg1_gap_256.sbb_in1k](https://huggingface.co/timm/vit_pwee_patch16_reg1_gap_256.sbb_in1k) | 80.072 | 95.136 | 15.25 | 256 |
+| [vit_mediumd_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_mediumd_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 64.11 | 256 |
+| [vit_betwixt_patch16_reg4_gap_256.sbb_in12k](https://huggingface.co/timm/vit_betwixt_patch16_reg4_gap_256.sbb_in12k) | N/A | N/A | 60.4 | 256 |
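
These weights load like any other `timm` model. A minimal sketch, picking one model name from the table above (any of the listed names works the same way):

```py
>>> import torch, timm
>>> model = timm.create_model('vit_betwixt_patch16_reg4_gap_256.sbb_in1k', pretrained=True).eval()
>>> with torch.no_grad():
...     logits = model(torch.randn(1, 3, 256, 256))
>>> logits.shape
torch.Size([1, 1000])
```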
+
+* AttentionExtract helper added to extract attention maps from `timm` models. See example in https://github.com/huggingface/pytorch-image-models/discussions/1232#discussioncomment-9320949 (see also the sketch after this list)
+* `forward_intermediates()` API refined and added to more models, including some ConvNets that have other extraction methods.
+* 1017 of 1047 model architectures support `features_only=True` feature extraction. The remaining 30 architectures can be supported based on priority requests.
+* Removed `torch.jit.script` annotated functions, including old JIT activations. They conflict with dynamo, and dynamo does a much better job when used.
+
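As a rough illustration of the AttentionExtract helper mentioned above, a minimal sketch; the `method='fx'` argument and the printed node naming are assumptions here, so see the linked discussion for canonical usage:

```py
import torch
import timm
from timm.utils import AttentionExtract

model = timm.create_model('vit_base_patch16_224', pretrained=True).eval()
# wrap the model; 'fx' is assumed to trace the attention softmax outputs
# (a hook-based capture method may also be available)
extractor = AttentionExtract(model, method='fx')
with torch.no_grad():
    attn_maps = extractor(torch.randn(1, 3, 224, 224))
for name, attn in attn_maps.items():
    # e.g. 'blocks.0.attn.softmax' -> torch.Size([1, 12, 197, 197]) for this model (assumed naming)
    print(name, attn.shape)
```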
 ### April 11, 2024
 * Prepping for a long overdue 1.0 release, things have been stable for a while now.
 * A significant feature that's been missing for a while: `features_only=True` support for ViT models with flat hidden states or non-std module layouts (so far covering `'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*'`)

hfdocs/source/feature_extraction.mdx

Lines changed: 80 additions & 3 deletions
@@ -8,7 +8,7 @@ The features from the penultimate model layer can be obtained in several ways wi
 
 ### Unpooled
 
-There are three ways to obtain unpooled features.
+There are three ways to obtain unpooled features. The final, unpooled features are sometimes referred to as the last hidden state. In `timm` this is up to and including the final normalization layer (in, e.g., ViT-style models) but does not include pooling / class token selection or final post-pooling layers.
 
 Without modifying the network, one can call `model.forward_features(input)` on any model instead of the usual `model(input)`. This will bypass the head classifier and global pooling of the network.
 
@@ -69,6 +69,25 @@ Original shape: torch.Size([2, 1000])
 Unpooled shape: torch.Size([2, 1024, 7, 7])
 ```
 
+#### Chaining unpooled output to classifier
+
+The last hidden state can be fed back into the head of the model using the `forward_head()` function.
+
+```py
+>>> model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
+>>> output = model.forward_features(torch.randn(2,3,256,256))
+>>> print('Unpooled output shape:', output.shape)
+>>> classified = model.forward_head(output)
+>>> print('Classification output shape:', classified.shape)
+```
+
+Output:
+
+```text
+Unpooled output shape: torch.Size([2, 257, 512])
+Classification output shape: torch.Size([2, 1000])
+```
+
 ### Pooled
 
 To modify the network to return pooled features, one can use `forward_features()` and pool/flatten the result themselves, or modify the network like above but keep pooling intact.
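
A minimal sketch of the second approach (the model choice here is arbitrary): creating the model with `num_classes=0` removes the classifier head while keeping pooling intact, so calling the model returns pooled, pre-classifier features.

```py
>>> import torch, timm
>>> model = timm.create_model('resnet50', pretrained=True, num_classes=0)
>>> pooled = model(torch.randn(2, 3, 224, 224))
>>> pooled.shape
torch.Size([2, 2048])
```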
@@ -116,7 +135,7 @@ Object detection, segmentation, keypoint, and a variety of dense pixel tasks req
 
 `timm` allows a consistent interface for creating any of the included models as feature backbones that output feature maps for selected levels.
 
-A feature backbone can be created by adding the argument `features_only=True` to any `create_model` call. By default 5 strides will be output from most models (not all have that many), with the first starting at 2 (some start at 1 or 4).
+A feature backbone can be created by adding the argument `features_only=True` to any `create_model` call. By default, most models with a feature hierarchy will output up to 5 features, up to a reduction of 32. However, this varies per model: some models have fewer hierarchy levels, and some (like ViT) have a larger number of non-hierarchical feature maps and default to outputting the last 3. The `out_indices` arg can be passed to `create_model` to specify which features you want.
 
 ### Create a feature map extraction model
 
@@ -171,7 +190,13 @@ There are two additional creation arguments impacting the output features.
 * `out_indices` selects which indices to output
 * `output_stride` limits the feature output stride of the network (also works in classification mode BTW)
 
-`out_indices` is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check feature_info to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most models, index 0 is the stride 2 features, and index 4 is stride 32.
+#### Output index selection
+
+The `out_indices` argument is supported by all models, but not all models have the same index to feature stride mapping. Look at the code or check `feature_info` to compare. The out indices generally correspond to the `C(i+1)th` feature level (a `2^(i+1)` reduction). For most convnet models, index 0 is the stride 2 features, and index 4 is stride 32. For many ViT or ViT-Conv hybrids, many or all feature maps may have the same shape, or there may be a combination of hierarchical and non-hierarchical feature maps; it is best to look at the `feature_info` attribute to see the number of features, their corresponding channel counts, and reduction levels.
+
+`out_indices` supports negative indexing, which makes it easy to get the last, penultimate, etc. feature map; `out_indices=(-2,)` would return the penultimate feature map for any model.
+
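A minimal sketch of checking `feature_info` and using negative indexing; `resnet50` is chosen here for its standard 5-level hierarchy, and the printed values follow from that layout:

```py
>>> import torch, timm
>>> model = timm.create_model('resnet50', features_only=True, out_indices=(-2, -1))
>>> model.feature_info.channels(), model.feature_info.reduction()
([1024, 2048], [16, 32])
>>> [f.shape for f in model(torch.randn(2, 3, 224, 224))]
[torch.Size([2, 1024, 14, 14]), torch.Size([2, 2048, 7, 7])]
```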
+#### Output stride (feature map dilation)
 
 `output_stride` is achieved by converting layers to use dilated convolutions. Doing so is not always straightforward; some networks only support `output_stride=32`.
 
@@ -194,3 +219,55 @@ Feature reduction: [8, 8]
 torch.Size([2, 512, 40, 40])
 torch.Size([2, 2048, 40, 40])
 ```
+
+## Flexible intermediate feature map extraction
+
+In addition to using `features_only` with the model factory, many models support a `forward_intermediates()` method which provides a flexible mechanism for extracting both the intermediate feature maps and the last hidden state (which can be chained to the head). Additionally, this method supports some model-specific features such as returning class or distill prefix tokens for some models.
+
+Accompanying the `forward_intermediates()` function is a `prune_intermediate_layers()` function that allows one to prune layers from the model, including the head, final norm, and/or trailing blocks/stages that are not needed.
+
+An `indices` argument is used for both `forward_intermediates()` and `prune_intermediate_layers()` to select the features to return or layers to remove. As with `out_indices` in the `features_only` API, `indices` is model specific and selects which intermediates are returned.
+
+In non-hierarchical, block-based models such as ViT, the indices correspond to the blocks; in models with hierarchical stages, they usually correspond to the output of the stem plus each hierarchical stage. Both positive (from the start) and negative (relative to the end) indexing work, and `None` is used to return all intermediates.
+
+The `prune_intermediate_layers()` call returns an indices variable, as negative indices must be converted to absolute (positive) indices when the model is trimmed.
+
+```py
+model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
+output, intermediates = model.forward_intermediates(torch.randn(2, 3, 256, 256))
+for i, o in enumerate(intermediates):
+    print(f'Feat index: {i}, shape: {o.shape}')
+```
+
+```text
+Feat index: 0, shape: torch.Size([2, 512, 16, 16])
+Feat index: 1, shape: torch.Size([2, 512, 16, 16])
+Feat index: 2, shape: torch.Size([2, 512, 16, 16])
+Feat index: 3, shape: torch.Size([2, 512, 16, 16])
+Feat index: 4, shape: torch.Size([2, 512, 16, 16])
+Feat index: 5, shape: torch.Size([2, 512, 16, 16])
+Feat index: 6, shape: torch.Size([2, 512, 16, 16])
+Feat index: 7, shape: torch.Size([2, 512, 16, 16])
+Feat index: 8, shape: torch.Size([2, 512, 16, 16])
+Feat index: 9, shape: torch.Size([2, 512, 16, 16])
+Feat index: 10, shape: torch.Size([2, 512, 16, 16])
+Feat index: 11, shape: torch.Size([2, 512, 16, 16])
+```
+
+```py
+model = timm.create_model('vit_medium_patch16_reg1_gap_256', pretrained=True)
+print('Original params:', sum([p.numel() for p in model.parameters()]))
+
+indices = model.prune_intermediate_layers(indices=(-2,), prune_head=True, prune_norm=True)  # prune head, norm, last block
+print('Pruned params:', sum([p.numel() for p in model.parameters()]))
+
+intermediates = model.forward_intermediates(torch.randn(2, 3, 256, 256), indices=indices, intermediates_only=True)  # return penultimate intermediate
+for o in intermediates:
+    print(f'Feat shape: {o.shape}')
+```
+
+```text
+Original params: 38880232
+Pruned params: 35212800
+Feat shape: torch.Size([2, 512, 16, 16])
+```

tests/test_models.py

Lines changed: 12 additions & 9 deletions
@@ -49,10 +49,13 @@
 
 # models with forward_intermediates() and support for FeatureGetterNet features_only wrapper
 FEAT_INTER_FILTERS = [
-    'vit_*', 'twins_*', 'deit*', 'beit*', 'mvitv2*', 'eva*', 'samvit_*', 'flexivit*', 'hiera_*'
+    'vision_transformer', 'vision_transformer_sam', 'vision_transformer_hybrid', 'vision_transformer_relpos',
+    'beit', 'mvitv2', 'eva', 'cait', 'xcit', 'volo', 'twins', 'deit', 'swin_transformer', 'swin_transformer_v2',
+    'swin_transformer_v2_cr', 'maxxvit', 'efficientnet', 'mobilenetv3', 'levit', 'efficientformer', 'resnet',
+    'regnet', 'byobnet', 'byoanet', 'mlp_mixer', 'hiera',
 ]
 
-# transformer models don't support many of the spatial / feature based model functionalities
+# transformer / hybrid models don't support the full set of spatial / feature APIs and/or have spatial output.
 NON_STD_FILTERS = [
     'vit_*', 'tnt_*', 'pit_*', 'coat_*', 'cait_*', '*mixer_*', 'gmlp_*', 'resmlp_*', 'twins_*',
     'convit_*', 'levit*', 'visformer*', 'deit*', 'xcit_*', 'crossvit_*', 'beit*',
@@ -387,13 +390,12 @@ def test_model_forward_features(model_name, batch_size):
 
 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates_features(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""
-    model = create_model(model_name, pretrained=False, features_only=True)
+    model = create_model(model_name, pretrained=False, features_only=True, feature_cls='getter')
     model.eval()
-    print(model.feature_info.out_indices)
     expected_channels = model.feature_info.channels()
     expected_reduction = model.feature_info.reduction()
 
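Two details of the updated test above, sketched outside the harness: `list_models(module=...)` filters by the `timm.models` sub-module a model is defined in, and `feature_cls='getter'` selects the `FeatureGetterNet` wrapper, which drives `features_only` extraction through `forward_intermediates()`. The model name below is an arbitrary choice, and for these models the getter wrapper may already be the default:

```py
import torch
import timm

# list model names defined in a given sub-module (mirrors the parametrize change above)
print(timm.list_models(module='hiera')[:3])

# build a features_only model via the getter wrapper
model = timm.create_model('vit_small_patch16_224', features_only=True, feature_cls='getter')
model.eval()
print(model.feature_info.channels(), model.feature_info.reduction())
with torch.no_grad():
    features = model(torch.randn(1, 3, 224, 224))
print([f.shape for f in features])
```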
@@ -419,7 +421,7 @@ def test_model_forward_intermediates_features(model_name, batch_size):
 
 @pytest.mark.features
 @pytest.mark.timeout(120)
-@pytest.mark.parametrize('model_name', list_models(FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS))
+@pytest.mark.parametrize('model_name', list_models(module=FEAT_INTER_FILTERS, exclude_filters=EXCLUDE_FILTERS + ['*pruned*']))
 @pytest.mark.parametrize('batch_size', [1])
 def test_model_forward_intermediates(model_name, batch_size):
     """Run a single forward pass with each model in feature extraction mode"""
@@ -428,18 +430,19 @@ def test_model_forward_intermediates(model_name, batch_size):
     feature_info = timm.models.FeatureInfo(model.feature_info, len(model.feature_info))
     expected_channels = feature_info.channels()
     expected_reduction = feature_info.reduction()
-    assert len(expected_channels) >= 4  # all models here should have at least 4 feature levels by default, some 5 or 6
+    assert len(expected_channels) >= 3  # all models here should have at least 3 feature levels
 
     input_size = _get_input_size(model=model, target=TARGET_FFEAT_SIZE)
     if max(input_size) > MAX_FFEAT_SIZE:
         pytest.skip("Fixed input size model > limit.")
-    output_fmt = getattr(model, 'output_fmt', 'NCHW')
+    output_fmt = 'NCHW'  # NOTE output_fmt determined by forward_intermediates() arg, not model attribute
     feat_axis = get_channel_dim(output_fmt)
     spatial_axis = get_spatial_dim(output_fmt)
     import math
 
     output, intermediates = model.forward_intermediates(
         torch.randn((batch_size, *input_size)),
+        output_fmt=output_fmt,
     )
     assert len(expected_channels) == len(intermediates)
     spatial_size = input_size[-2:]
@@ -483,7 +486,7 @@ def _create_fx_model(model, train=False):
     return fx_model
 
 
-EXCLUDE_FX_FILTERS = ['vit_gi*']
+EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
 # not enough memory to run fx on more models than other tests
 if 'GITHUB_ACTIONS' in os.environ:
     EXCLUDE_FX_FILTERS += [

timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
 from .squeeze_excite import SEModule, SqueezeExcite, EffectiveSEModule, EffectiveSqueezeExcite
 from .selective_kernel import SelectiveKernel
 from .separable_conv import SeparableConv2d, SeparableConvNormAct
-from .space_to_depth import SpaceToDepthModule, SpaceToDepth, DepthToSpace
+from .space_to_depth import SpaceToDepth, DepthToSpace
 from .split_attn import SplitAttn
 from .split_batchnorm import SplitBatchNorm2d, convert_splitbn_model
 from .std_conv import StdConv2d, StdConv2dSame, ScaledStdConv2d, ScaledStdConv2dSame
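
Note for downstream code: this drops the `SpaceToDepthModule` export. A sketch of the assumed migration; whether `SpaceToDepth` is a drop-in replacement (and its default block size) should be verified against `timm/layers/space_to_depth.py`:

```py
# from timm.layers import SpaceToDepthModule  # removed in this commit
from timm.layers import SpaceToDepth  # still exported, assumed near-equivalent

s2d = SpaceToDepth()  # rearranges spatial blocks into the channel dimension
```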

timm/layers/activations_jit.py

Lines changed: 0 additions & 90 deletions
This file was deleted.
