Commit 57adc1a

Fix rotary embed version of attn pool. Bit of cleanup/naming
1 parent cdc7bce

2 files changed: +17 -14 lines


timm/layers/attention_pool2d.py

Lines changed: 17 additions & 10 deletions

@@ -42,7 +42,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            avg_token: bool = True,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -63,6 +63,11 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -109,7 +114,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         if self.qkv is None:
             q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
             k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -130,7 +138,6 @@ def forward(self, x, pre_logits: bool = False):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
             x = self._pool(x, H, W)
@@ -162,7 +169,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            learned_token: bool = False,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -184,10 +191,10 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()

-        if learned_token:
-            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
         else:
-            self.token = None
+            self.cls_token = None

         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -239,10 +246,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        if self.token is not None:
-            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
-        else:
+        if self.cls_token is None:
             x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
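
The substantive change is the same in both pool variants: the flattened spatial tokens are prefixed either with their mean or, when class_token=True, with a learned class token (the renamed cls_token attribute). A minimal standalone sketch of just that prefix logic; the PrefixToken name and the demo shapes are illustrative, not timm code:

import torch
import torch.nn as nn

class PrefixToken(nn.Module):
    # Sketch of the class_token / mean-token choice used by both pool variants.
    def __init__(self, embed_dim: int, class_token: bool = False):
        super().__init__()
        # Learned class token when requested, otherwise prefix with the mean token.
        self.cls_token = nn.Parameter(torch.zeros(1, embed_dim)) if class_token else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, N, C) spatial tokens, i.e. after x.flatten(2).transpose(1, 2).
        if self.cls_token is None:
            return torch.cat([x.mean(1, keepdim=True), x], dim=1)
        # (1, C) -> (B, 1, C): expand prepends the batch dim, as in the diff above.
        return torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)

pool = PrefixToken(64, class_token=True)
tokens = torch.randn(2, 49, 64)   # e.g. a flattened 7x7 feature map
print(pool(tokens).shape)         # torch.Size([2, 50, 64])

Either way the sequence gains exactly one prefix token, which is why the q/k/v reshapes above use N + 1. Dropping the hard-coded x = x[:, 0] leaves the final token-vs-average selection to self._pool, per pool_type.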

timm/models/byobnet.py

Lines changed: 0 additions & 4 deletions

@@ -1945,7 +1945,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=512,
     ),

     resnet50x4_clip=ByoModelCfg(
@@ -1962,7 +1961,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=640,
     ),

     resnet50x16_clip=ByoModelCfg(
@@ -1979,7 +1977,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=768,
     ),

     resnet50x64_clip=ByoModelCfg(
@@ -1996,7 +1993,6 @@ def _init_weights(module, name='', zero_init_last=False):
         downsample='avg',
         aa_layer='avg',
         head_type='attn_abs',
-        #head_hidden_size=1024,
     ),

     resnet50_mlp=ByoModelCfg(
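
These *_clip configs all use the absolute-position attention pool head (head_type='attn_abs'); the deletions only remove dead, commented-out head_hidden_size values. A quick smoke test, assuming these config names are registered as timm models in this revision (the model name and input size below are assumptions, and no pretrained weights are implied):

import timm
import torch

# Assumed: 'resnet50x4_clip' is registered under this name in this timm revision.
model = timm.create_model('resnet50x4_clip', pretrained=False)
out = model(torch.randn(1, 3, 288, 288))  # CLIP RN50x4 is commonly run at 288x288
print(out.shape)                          # torch.Size([1, 1000]) with the default classifier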
