Commit cb3e5b7

Forgot to compact attention pool branches after verifying
1 parent 1299488 commit cb3e5b7

1 file changed

timm/models/naflexvit.py

Lines changed: 9 additions & 20 deletions
@@ -1192,27 +1192,16 @@ def _pool(
             patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.attn_pool is not None:
-            # For attention pooling, we need to pass the mask for NaFlex models
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             if self.pool_include_prefix:
-                # Include all tokens in attention pooling - create mask for all tokens including prefix
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=self.num_prefix_tokens,
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x, attn_mask=attn_mask)
-            else:
-                # Exclude prefix tokens from attention pooling (default behavior)
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=0,  # No prefix tokens when we slice them off
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
+                x = x[:, self.num_prefix_tokens:]
+            x = self.attn_pool(x, attn_mask=attn_mask)
             return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
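
The net effect of the compaction: the two removed branches built the same attention mask and differed only in the num_prefix_tokens argument, so the create_attention_mask call is hoisted above the branch and the differing argument becomes a conditional expression. Below is a minimal sketch of that equivalence, using a hypothetical make_mask stand-in rather than timm's real create_attention_mask; the stand-in's padding behavior (prefix slots treated as always valid, additive float mask) is an assumption for illustration only.

import torch

# Hypothetical stand-in for naflexvit's create_attention_mask, reduced to the
# piece the refactor touches: pad the patch-validity mask with
# num_prefix_tokens always-valid slots, then convert it to an additive float
# mask (0 = attend, -inf = ignore).
def make_mask(patch_valid: torch.Tensor, num_prefix_tokens: int, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    prefix_valid = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
    valid = torch.cat([prefix_valid, patch_valid], dim=1).bool()
    return torch.zeros(valid.shape, dtype=dtype).masked_fill(~valid, float('-inf'))

patch_valid = torch.tensor([[True, True, True, False],
                            [True, True, False, False]])
num_prefix_tokens = 1

for pool_include_prefix in (True, False):
    # Removed shape: two branches, each repeating the full call.
    if pool_include_prefix:
        old = make_mask(patch_valid, num_prefix_tokens=num_prefix_tokens)
    else:
        old = make_mask(patch_valid, num_prefix_tokens=0)
    # Compacted shape from this commit: one call, conditional argument.
    new = make_mask(patch_valid, num_prefix_tokens=num_prefix_tokens if pool_include_prefix else 0)
    assert torch.equal(old, new)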
