Commit f9b3d7e

Merge pull request #2507 from huggingface/more_naflex
Forgot to compact attention pool branches after verifying
2 parents 1299488 + 6c7ce45

File tree

2 files changed: +29 −21 lines changed

README.md

Lines changed: 19 additions & 0 deletions
```diff
@@ -12,6 +12,25 @@

 ## What's New

+## June 5, 2025
+* Initial NaFlexVit model code. NaFlexVit is a Vision Transformer with:
+  1. Encapsulated embedding and position encoding in a single module
+  2. Support for nn.Linear patch embedding on pre-patchified (dictionary) inputs
+  3. Support for NaFlex variable aspect, variable resolution (SigLip-2: https://arxiv.org/abs/2502.14786)
+  4. Support for FlexiViT variable patch size (https://arxiv.org/abs/2212.08013)
+  5. Support for NaViT fractional/factorized position embedding (https://arxiv.org/abs/2307.06304)
+* Existing ViT models in `vision_transformer.py` can be loaded into the NaFlexVit model by adding the `use_naflex=True` flag to `create_model`
+* Some native weights coming soon
+* A full NaFlex data pipeline is available that allows training / fine-tuning / evaluating with variable aspect / size images
+  * To enable in `train.py` and `validate.py`, add the `--naflex-loader` arg; it must be used with a NaFlexVit
+* To evaluate an existing (classic) ViT loaded into the NaFlexVit model w/ the NaFlex data pipe:
+  * `python validate.py /imagenet --amp -j 8 --model vit_base_patch16_224 --model-kwargs use_naflex=True --naflex-loader --naflex-max-seq-len 256`
+* Training has some extra args worth noting
+  * The `--naflex-train-seq-lens` argument specifies which sequence lengths to randomly pick from per batch during training
+  * The `--naflex-max-seq-len` argument sets the target sequence length for validation
+  * Adding `--model-kwargs enable_patch_interpolator=True --naflex-patch-sizes 12 16 24` will enable random patch size selection per batch w/ interpolation
+  * The `--naflex-loss-scale` arg changes the loss scaling mode per batch relative to the batch size; `timm` NaFlex loading changes the batch size for each seq len

 ## May 28, 2025
 * Add a number of small/fast models thanks to https://github.com/brianhou0208
 * SwiftFormer - [(ICCV2023) SwiftFormer: Efficient Additive Attention for Transformer-based Real-time Mobile Vision Applications](https://github.com/Amshaker/SwiftFormer)
```
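The core NaFlex idea above (variable aspect / variable resolution under a fixed token budget) can be sketched in plain Python. This is a hypothetical illustration, not timm's actual loader code: the function name `naflex_grid` and its exact scaling rule are invented for the example; the real pipeline lives behind `--naflex-loader`.

```python
# Hypothetical sketch of NaFlex-style patch-grid selection (NOT timm's
# implementation): pick a patch grid for a variable-aspect image, scaling
# it down to fit a fixed max_seq_len token budget, and build a validity
# mask that marks real patch tokens vs. padding.
import math

def naflex_grid(img_h, img_w, patch_size, max_seq_len):
    """Return (grid_h, grid_w, patch_valid) for one image."""
    grid_h = math.ceil(img_h / patch_size)
    grid_w = math.ceil(img_w / patch_size)
    n = grid_h * grid_w
    if n > max_seq_len:
        # Shrink both dims by the same factor to preserve aspect ratio
        # while fitting inside the sequence-length budget.
        scale = math.sqrt(max_seq_len / n)
        grid_h = max(1, int(grid_h * scale))
        grid_w = max(1, int(grid_w * scale))
        n = grid_h * grid_w
    # 1 marks a real patch token, 0 marks padding up to max_seq_len.
    patch_valid = [1] * n + [0] * (max_seq_len - n)
    return grid_h, grid_w, patch_valid

# A 384x256 image at patch size 16 would need 24x16 = 384 tokens; with
# --naflex-max-seq-len 256 the grid is scaled down to fit the budget.
gh, gw, valid = naflex_grid(384, 256, 16, 256)
```

The padded-but-masked sequence is why the attention pooling code in this commit threads a `patch_valid` mask through to the pool.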

timm/models/naflexvit.py

Lines changed: 10 additions & 21 deletions
```diff
@@ -1192,27 +1192,16 @@ def _pool(
             patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.attn_pool is not None:
-            # For attention pooling, we need to pass the mask for NaFlex models
-            if self.pool_include_prefix:
-                # Include all tokens in attention pooling - create mask for all tokens including prefix
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=self.num_prefix_tokens,
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x, attn_mask=attn_mask)
-            else:
-                # Exclude prefix tokens from attention pooling (default behavior)
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=0,  # No prefix tokens when we slice them off
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
+            # Attention pooling needs a validity mask for NaFlex models; prefix
+            # tokens are either covered by the mask or sliced off before pooling.
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
+            if not self.pool_include_prefix:
+                x = x[:, self.num_prefix_tokens:]
+            x = self.attn_pool(x, attn_mask=attn_mask)
             return x

         pool_type = self.global_pool if pool_type is None else pool_type
```
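The two branches collapse into one because the only real difference between them was the `num_prefix_tokens` argument and whether `x` is sliced first. A plain-Python sketch of the mask semantics (not timm's `create_attention_mask`, whose signature is only mirrored loosely here; `make_pool_mask` is an invented name) shows the equivalence:

```python
# Hedged sketch of the compacted logic: a non-symmetric, q_len=1 additive
# attention mask over prefix tokens + patch tokens. Attendable positions
# get 0.0; padded patch positions get -inf so softmax zeroes them out.
NEG_INF = float("-inf")

def make_pool_mask(patch_valid, num_prefix_tokens):
    # One query row (the pooling query) attends over prefix + patch keys.
    # Prefix tokens, when included, are always valid.
    keys = [True] * num_prefix_tokens + list(patch_valid)
    return [[0.0 if v else NEG_INF for v in keys]]

patch_valid = [True, True, False, False]  # 2 real patches, 2 padding slots

# pool_include_prefix=True: mask spans prefix tokens as well, x untouched.
with_prefix = make_pool_mask(patch_valid, num_prefix_tokens=1)

# pool_include_prefix=False: prefix tokens are sliced off x beforehand,
# so the mask covers patch positions only (num_prefix_tokens=0).
patches_only = make_pool_mask(patch_valid, num_prefix_tokens=0)
```

Either way the pool sees exactly one mask shape consistent with its input, which is what lets the refactor hoist the `self.attn_pool(...)` call out of both branches.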
