1 file changed: +9 -20 lines changed

```diff
@@ -1192,27 +1192,16 @@ def _pool(
             patch_valid: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if self.attn_pool is not None:
-            # For attention pooling, we need to pass the mask for NaFlex models
+            attn_mask = create_attention_mask(
+                patch_valid,
+                num_prefix_tokens=self.num_prefix_tokens if self.pool_include_prefix else 0,
+                symmetric=False,
+                q_len=1,
+                dtype=x.dtype,
+            )
             if self.pool_include_prefix:
-                # Include all tokens in attention pooling - create mask for all tokens including prefix
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=self.num_prefix_tokens,
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x, attn_mask=attn_mask)
-            else:
-                # Exclude prefix tokens from attention pooling (default behavior)
-                attn_mask = create_attention_mask(
-                    patch_valid,
-                    num_prefix_tokens=0,  # No prefix tokens when we slice them off
-                    symmetric=False,
-                    q_len=1,
-                    dtype=x.dtype,
-                )
-                x = self.attn_pool(x[:, self.num_prefix_tokens:], attn_mask=attn_mask)
+                x = x[:, self.num_prefix_tokens:]
+            x = self.attn_pool(x, attn_mask=attn_mask)
             return x
 
         pool_type = self.global_pool if pool_type is None else pool_type
```
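For readers who want to see the consolidated path in isolation, below is a minimal, self-contained sketch of masked single-query attention pooling in the same spirit as the change above. The `create_attention_mask` and `AttnPool` definitions here are simplified stand-ins written for illustration (the additive `-inf` mask convention, the tensor shapes, and the helper signatures are assumptions), not the library's actual implementations.

```python
from typing import Optional

import torch


def create_attention_mask(
        patch_valid: Optional[torch.Tensor],
        num_prefix_tokens: int = 0,
        symmetric: bool = False,  # accepted only to mirror the call above; this sketch builds a one-sided mask
        q_len: int = 1,
        dtype: torch.dtype = torch.float32,
) -> Optional[torch.Tensor]:
    """Build an additive (B, 1, q_len, kv_len) mask; invalid patch tokens get -inf."""
    if patch_valid is None:
        return None
    if num_prefix_tokens:
        # Prefix tokens (class/register) are always treated as valid.
        prefix_valid = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
        patch_valid = torch.cat([prefix_valid, patch_valid], dim=1)
    mask = torch.zeros(patch_valid.shape, dtype=dtype)
    mask = mask.masked_fill(~patch_valid.bool(), float('-inf'))
    return mask[:, None, None, :].expand(-1, 1, q_len, -1)


class AttnPool(torch.nn.Module):
    """Single learned query attends over the (masked) token sequence."""

    def __init__(self, dim: int):
        super().__init__()
        self.query = torch.nn.Parameter(torch.randn(1, 1, dim) * 0.02)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        q = self.query.expand(x.shape[0], -1, -1)            # (B, 1, C)
        attn = q @ x.transpose(-2, -1) / x.shape[-1] ** 0.5  # (B, 1, N)
        if attn_mask is not None:
            attn = attn + attn_mask[:, 0]                    # broadcast additive mask over (B, q_len, N)
        attn = attn.softmax(dim=-1)                          # padded tokens receive ~0 weight
        return (attn @ x)[:, 0]                              # (B, C)


# Usage: no prefix tokens, so the mask covers exactly the patch sequence.
B, N, C = 2, 7, 16
x = torch.randn(B, N, C)
patch_valid = torch.ones(B, N, dtype=torch.bool)
patch_valid[0, 4:] = False  # last three patches of the first sample are padding

attn_mask = create_attention_mask(patch_valid, num_prefix_tokens=0, symmetric=False, q_len=1, dtype=x.dtype)
pooled = AttnPool(C)(x, attn_mask=attn_mask)
print(pooled.shape)  # torch.Size([2, 16])
```

With the mask applied, the padded patches in the first sample contribute zero weight to the pooled vector, which is the behavior the NaFlex-style mask is meant to guarantee.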