@@ -42,7 +42,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            avg_token: bool = True,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -63,6 +63,11 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
+        else:
+            self.cls_token = None
+
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
             self.k = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -109,7 +114,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        if self.cls_token is None:
+            x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         if self.qkv is None:
             q = self.q(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
             k = self.k(x).reshape(B, N + 1, self.num_heads, self.head_dim).transpose(1, 2)
@@ -130,7 +138,6 @@ def forward(self, x, pre_logits: bool = False):
         attn = attn.softmax(dim=-1)
         x = attn @ v
         x = x.transpose(1, 2).reshape(B, N + 1, -1)
-        x = x[:, 0]
         x = self.drop(x)
         if pre_logits:
             x = self._pool(x, H, W)
@@ -162,7 +169,7 @@ def __init__(
             qkv_bias: bool = True,
             qkv_separate: bool = False,
             pool_type: str = 'token',
-            learned_token: bool = False,
+            class_token: bool = False,
             drop_rate: float = 0.,
     ):
         super().__init__()
@@ -184,10 +191,10 @@ def __init__(
         self.scale = self.head_dim ** -0.5
         self.fused_attn = use_fused_attn()
 
-        if learned_token:
-            self.token = nn.Parameter(torch.zeros(1, embed_dim))
+        if class_token:
+            self.cls_token = nn.Parameter(torch.zeros(1, embed_dim))
         else:
-            self.token = None
+            self.cls_token = None
 
         if qkv_separate:
             self.q = nn.Linear(in_features, embed_dim, bias=qkv_bias)
@@ -239,10 +246,10 @@ def forward(self, x, pre_logits: bool = False):
         B, _, H, W = x.shape
         N = H * W
         x = x.flatten(2).transpose(1, 2)
-        if self.token is not None:
-            x = torch.cat([self.token.expand(x.shape[0], -1, -1), x], dim=1)
-        else:
+        if self.cls_token is None:
             x = torch.cat([x.mean(1, keepdim=True), x], dim=1)
+        else:
+            x = torch.cat([self.cls_token.expand(x.shape[0], -1, -1), x], dim=1)
         pos_embed = resample_abs_pos_embed(self.pos_embed.unsqueeze(0), (H, W), num_prefix_tokens=1)
         x = x + pos_embed
 
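
Below is a minimal PyTorch sketch of the behaviour this diff introduces, written with plain tensors rather than the module under change: when `class_token=True`, a learned token (`self.cls_token`) is prepended to the flattened feature map; otherwise the mean of the spatial tokens is prepended, as before. The tensor sizes (batch 2, 768 channels, 7x7 feature map) are illustrative assumptions, not values from the diff.

```python
import torch
import torch.nn as nn

# Flatten a (B, C, H, W) feature map into (B, N, C) tokens, N = H * W.
B, C, H, W = 2, 768, 7, 7
x = torch.randn(B, C, H, W).flatten(2).transpose(1, 2)  # (2, 49, 768)

# Created only when class_token=True in the patched __init__.
cls_token = nn.Parameter(torch.zeros(1, C))

# Default path (cls_token is None): prefix with the mean of the spatial tokens.
x_mean = torch.cat([x.mean(1, keepdim=True), x], dim=1)

# class_token path: prefix with the learned token, broadcast across the batch.
x_cls = torch.cat([cls_token.expand(x.shape[0], -1, -1), x], dim=1)

print(x_mean.shape, x_cls.shape)  # both (2, 50, 768), i.e. (B, N + 1, C)
```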